Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion bin/pytorch_inference/CSupportedOperations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERA
// deepset/tinyroberta-squad2, typeform/squeezebert-mnli,
// facebook/bart-large-mnli, valhalla/distilbart-mnli-12-6,
// distilbert-base-uncased-finetuned-sst-2-english,
// sentence-transformers/all-distilroberta-v1,
// jinaai/jina-embeddings-v5-text-nano (EuroBERT + LoRA).
// Eland-deployed variants of the above models (with pooling/normalization layers).
// Additional ops from Elasticsearch integration test models
// (PyTorchModelIT, TextExpansionQueryIT, TextEmbeddingQueryIT).
Expand All @@ -68,6 +69,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI
"aten::clone"sv,
"aten::contiguous"sv,
"aten::copy_"sv,
"aten::cos"sv,
"aten::cumsum"sv,
"aten::detach"sv,
"aten::div"sv,
Expand Down Expand Up @@ -117,10 +119,13 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI
"aten::relu"sv,
"aten::repeat"sv,
"aten::reshape"sv,
"aten::rsqrt"sv,
"aten::rsub"sv,
"aten::scaled_dot_product_attention"sv,
"aten::select"sv,
"aten::sign"sv,
"aten::silu"sv,
"aten::sin"sv,
"aten::size"sv,
"aten::slice"sv,
"aten::softmax"sv,
Expand Down
40 changes: 22 additions & 18 deletions bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -259,25 +259,25 @@ BOOST_AUTO_TEST_CASE(testValidModuleWithAllowedOps) {
}

// An op that is neither allowlisted nor forbidden must make the module
// invalid and be reported in s_UnrecognisedOps (not in s_ForbiddenOps).
BOOST_AUTO_TEST_CASE(testModuleWithUnrecognisedOps) {
    // torch.logit is not in the transformer allowlist. torch.sin cannot be
    // used as the probe here because aten::sin is an allowed operation.
    ::torch::jit::Module m("__torch__.UnknownOps");
    m.define(R"(
        def forward(self, x: Tensor) -> Tensor:
            return torch.logit(x)
    )");

    auto result = CModelGraphValidator::validate(m);

    BOOST_REQUIRE(result.s_IsValid == false);
    BOOST_REQUIRE(result.s_ForbiddenOps.empty());
    BOOST_REQUIRE(result.s_UnrecognisedOps.empty() == false);

    // Scan the reported ops for the one this module is known to contain.
    bool foundLogit = false;
    for (const auto& op : result.s_UnrecognisedOps) {
        if (op == "aten::logit") {
            foundLogit = true;
        }
    }
    BOOST_REQUIRE(foundLogit);
}

BOOST_AUTO_TEST_CASE(testModuleNodeCountPopulated) {
Expand All @@ -301,7 +301,7 @@ BOOST_AUTO_TEST_CASE(testModuleWithSubmoduleInlines) {
::torch::jit::Module child("__torch__.Child");
child.define(R"(
def forward(self, x: Tensor) -> Tensor:
return torch.sin(x)
return torch.logit(x)
)");

::torch::jit::Module parent("__torch__.Parent");
Expand All @@ -314,19 +314,19 @@ BOOST_AUTO_TEST_CASE(testModuleWithSubmoduleInlines) {
auto result = CModelGraphValidator::validate(parent);

BOOST_REQUIRE(result.s_IsValid == false);
bool foundSin = false;
bool foundLogit = false;
for (const auto& op : result.s_UnrecognisedOps) {
if (op == "aten::sin") {
foundSin = true;
if (op == "aten::logit") {
foundLogit = true;
}
}
BOOST_REQUIRE(foundSin);
BOOST_REQUIRE(foundLogit);
}

// --- Integration tests with malicious .pt model fixtures ---
//
// These load real TorchScript models that simulate attack vectors.
// The .pt files are generated by dev-tools/generate_malicious_models.py.

namespace {
bool hasForbiddenOp(const CModelGraphValidator::SResult& result, const std::string& op) {
Expand Down Expand Up @@ -363,34 +363,38 @@ BOOST_AUTO_TEST_CASE(testMaliciousMixedFileReader) {
BOOST_AUTO_TEST_CASE(testMaliciousHiddenInSubmodule) {
    // Unrecognised ops buried three levels deep in nested submodules.
    // The validator must inline through all submodules to find them.
    // The leaf uses aten::logit (still unrecognised) so the fixture stays
    // invalid when aten::sin is allowed for EuroBERT/Jina v5.
    auto module = ::torch::jit::load("testfiles/malicious_models/malicious_hidden_in_submodule.pt");
    auto result = CModelGraphValidator::validate(module);

    BOOST_REQUIRE(result.s_IsValid == false);
    BOOST_REQUIRE(result.s_ForbiddenOps.empty());
    BOOST_REQUIRE(result.s_UnrecognisedOps.empty() == false);
    // Pin the specific op hidden in the leaf: a bare "non-empty" check would
    // still pass if the validator stopped inlining and some unrelated op in
    // an outer module happened to be unrecognised.
    BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::logit"));
}

BOOST_AUTO_TEST_CASE(testMaliciousConditionalBranch) {
    // An unrecognised op hidden inside a conditional branch. The
    // validator must recurse into prim::If blocks to detect it.
    // The model uses aten::sin which is now allowed, but also contains
    // other ops that remain unrecognised, so aten::sin itself must not be
    // asserted as unrecognised here.
    auto module = ::torch::jit::load("testfiles/malicious_models/malicious_conditional.pt");
    auto result = CModelGraphValidator::validate(module);

    BOOST_REQUIRE(result.s_IsValid == false);
    BOOST_REQUIRE(result.s_UnrecognisedOps.empty() == false);
}

BOOST_AUTO_TEST_CASE(testMaliciousManyUnrecognisedOps) {
    // A model using many different ops (sin, cos, tan, exp).
    // sin and cos are now allowed (EuroBERT/Jina v5), but tan and exp
    // remain unrecognised.
    auto module = ::torch::jit::load("testfiles/malicious_models/malicious_many_unrecognised.pt");
    auto result = CModelGraphValidator::validate(module);

    BOOST_REQUIRE(result.s_IsValid == false);
    BOOST_REQUIRE(result.s_ForbiddenOps.empty());
    BOOST_REQUIRE(result.s_UnrecognisedOps.size() >= 2);
    BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::tan"));
    BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::exp"));
    // Allowlisted ops must no longer be reported as unrecognised.
    BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::sin") == false);
    BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::cos") == false);
}
Expand Down
Binary file not shown.
42 changes: 42 additions & 0 deletions bin/pytorch_inference/unittest/testfiles/reference_model_ops.json
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,48 @@
"prim::NumToTensor"
]
},
"jina-embeddings-v5-text-nano": {
"model_id": "jinaai/jina-embeddings-v5-text-nano",
"quantized": false,
"ops": [
"aten::Int",
"aten::add",
"aten::arange",
"aten::cat",
"aten::contiguous",
"aten::cos",
"aten::detach",
"aten::dropout",
"aten::embedding",
"aten::expand",
"aten::floor_divide",
"aten::linear",
"aten::masked_fill",
"aten::matmul",
"aten::mean",
"aten::mul",
"aten::neg",
"aten::pow",
"aten::reshape",
"aten::rsqrt",
"aten::scaled_dot_product_attention",
"aten::silu",
"aten::sin",
"aten::size",
"aten::slice",
"aten::sub",
"aten::to",
"aten::transpose",
"aten::unsqueeze",
"aten::view",
"prim::Constant",
"prim::GetAttr",
"prim::ListConstruct",
"prim::NumToTensor",
"prim::TupleConstruct",
"prim::TupleUnpack"
]
},
"qa-tinyroberta-squad2": {
"model_id": "deepset/tinyroberta-squad2",
"quantized": false,
Expand Down
2 changes: 2 additions & 0 deletions dev-tools/extract_model_ops/reference_models.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
"elastic-eis-elser-v2-quantized": {"model_id": "elastic/eis-elser-v2", "quantized": true},
"elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantized": true},

"jina-embeddings-v5-text-nano": "jinaai/jina-embeddings-v5-text-nano",

"_comment:qa-models": "Models from the Appex QA pytorch_tests suite. BART models require auto_class and config_overrides to trace correctly.",
"qa-tinyroberta-squad2": {"model_id": "deepset/tinyroberta-squad2", "auto_class": "AutoModelForQuestionAnswering"},
"qa-squeezebert-mnli": "typeform/squeezebert-mnli",
Expand Down
9 changes: 6 additions & 3 deletions dev-tools/extract_model_ops/torchscript_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,14 @@ def load_and_trace_hf_model(model_name: str, quantize: bool = False,
overrides = config_overrides or {}

try:
tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
tokenizer = AutoTokenizer.from_pretrained(
model_name, token=token, trust_remote_code=True)
config = AutoConfig.from_pretrained(
model_name, torchscript=True, token=token, **overrides)
model_name, torchscript=True, token=token,
trust_remote_code=True, **overrides)
model = model_cls.from_pretrained(
model_name, config=config, token=token)
model_name, config=config, token=token,
trust_remote_code=True)
model.eval()
except Exception as exc:
print(f" LOAD ERROR: {exc}", file=sys.stderr)
Expand Down
2 changes: 2 additions & 0 deletions dev-tools/extract_model_ops/validation_models.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
"es-cross-encoder-ms-marco": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"es-dpr-question-encoder": "facebook/dpr-question_encoder-single-nq-base",

"jina-embeddings-v5-text-nano": "jinaai/jina-embeddings-v5-text-nano",

"_comment:qa-models": "Models from the Appex QA pytorch_tests suite. BART models require auto_class and config_overrides to trace correctly.",
"qa-tinyroberta-squad2": {"model_id": "deepset/tinyroberta-squad2", "auto_class": "AutoModelForQuestionAnswering"},
"qa-squeezebert-mnli": "typeform/squeezebert-mnli",
Expand Down
8 changes: 6 additions & 2 deletions dev-tools/generate_malicious_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ def forward(self, x: Tensor) -> Tensor:


class HiddenInSubmodule(torch.nn.Module):
"""Hides aten::sin (unrecognised) three levels deep in submodules."""
"""Hides aten::logit (unrecognised) three levels deep in submodules.

Uses logit+clamp instead of sin so the fixture stays invalid when
aten::sin is added to the allowlist for transformer models (e.g. EuroBERT).
"""
def __init__(self):
super().__init__()
self.inner = _Inner()
Expand All @@ -69,7 +73,7 @@ def forward(self, x: Tensor) -> Tensor:

class _Leaf(torch.nn.Module):
    """Innermost submodule: emits aten::logit on a clamped input."""
    def forward(self, x: Tensor) -> Tensor:
        # Clamp into the open interval (0, 1) first so logit stays finite
        # for arbitrary inputs.
        return torch.logit(torch.clamp(x, 1e-6, 1.0 - 1e-6))


class ConditionalMalicious(torch.nn.Module):
Expand Down