Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion bin/pytorch_inference/CCmdLineParser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ bool CCmdLineParser::parse(int argc,
std::size_t& cacheMemorylimitBytes,
bool& validElasticLicenseKeyConfirmed,
bool& lowPriority,
bool& useImmediateExecutor) {
bool& useImmediateExecutor,
bool& skipModelValidation) {
try {
boost::program_options::options_description desc(DESCRIPTION);
// clang-format off
Expand Down Expand Up @@ -75,6 +76,7 @@ bool CCmdLineParser::parse(int argc,
("lowPriority", "Execute process in low priority")
("useImmediateExecutor", "Execute requests on the main thread. This mode should only used for "
"benchmarking purposes to ensure requests are processed in order)")
("skipModelValidation", "Skip TorchScript model graph validation. WARNING: disables security checks on model operations.")
;
// clang-format on

Expand Down Expand Up @@ -148,6 +150,9 @@ bool CCmdLineParser::parse(int argc,
return false;
}
}
if (vm.count("skipModelValidation") > 0) {
skipModelValidation = true;
}
} catch (std::exception& e) {
std::cerr << "Error processing command line: " << e.what() << std::endl;
return false;
Expand Down
3 changes: 2 additions & 1 deletion bin/pytorch_inference/CCmdLineParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ class CCmdLineParser {
std::size_t& cacheMemorylimitBytes,
bool& validElasticLicenseKeyConfirmed,
bool& lowPriority,
bool& useImmediateExecutor);
bool& useImmediateExecutor,
bool& skipModelValidation);

private:
static const std::string DESCRIPTION;
Expand Down
16 changes: 11 additions & 5 deletions bin/pytorch_inference/Main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,14 @@ int main(int argc, char** argv) {
bool validElasticLicenseKeyConfirmed{false};
bool lowPriority{false};
bool useImmediateExecutor{false};
bool skipModelValidation{false};

if (ml::torch::CCmdLineParser::parse(
argc, argv, modelId, namedPipeConnectTimeout, inputFileName,
isInputFileNamedPipe, outputFileName, isOutputFileNamedPipe,
restoreFileName, isRestoreFileNamedPipe, logFileName, logProperties,
numThreadsPerAllocation, numAllocations, cacheMemorylimitBytes,
validElasticLicenseKeyConfirmed, lowPriority, useImmediateExecutor) == false) {
isInputFileNamedPipe, outputFileName, isOutputFileNamedPipe, restoreFileName,
isRestoreFileNamedPipe, logFileName, logProperties, numThreadsPerAllocation,
numAllocations, cacheMemorylimitBytes, validElasticLicenseKeyConfirmed,
lowPriority, useImmediateExecutor, skipModelValidation) == false) {
return EXIT_FAILURE;
}

Expand Down Expand Up @@ -315,7 +316,12 @@ int main(int argc, char** argv) {
return EXIT_FAILURE;
}
module_ = torch::jit::load(std::move(readAdapter));
verifySafeModel(module_);
if (skipModelValidation) {
LOG_WARN(<< "Model graph validation SKIPPED — --skipModelValidation flag is set. "
<< "This disables security checks on model operations.");
} else {
verifySafeModel(module_);
}
module_.eval();

LOG_DEBUG(<< "model loaded");
Expand Down
Loading