diff --git a/src/passes/InstrumentBranchHints.cpp b/src/passes/InstrumentBranchHints.cpp index c552ddae891..9a80598877f 100644 --- a/src/passes/InstrumentBranchHints.cpp +++ b/src/passes/InstrumentBranchHints.cpp @@ -98,6 +98,7 @@ #include "ir/effects.h" #include "ir/names.h" +#include "ir/properties.h" #include "ir/utils.h" #include "pass.h" #include "support/string.h" @@ -231,18 +232,6 @@ struct InstrumentationProcessor : public WalkerPass> { Super::doWalkModule(module); } - - // Helpers - - // Check if an expression's condition is instrumented, and return the - // instrumentation call if so. Otherwise return null. - Call* getInstrumentation(Expression* condition) { - auto* call = condition->dynCast(); - if (!call || call->target != logBranch) { - return nullptr; - } - return call; - } }; struct DeleteBranchHints : public InstrumentationProcessor { @@ -251,15 +240,27 @@ struct DeleteBranchHints : public InstrumentationProcessor { // The set of IDs to delete. std::unordered_set idsToDelete; + std::optional getBranchID(Expression* condition, + const PassOptions& passOptions, + Module& wasm) { + auto* call = + Properties::getFallthrough(condition, getPassOptions(), *getModule()) + ->dynCast(); + if (!call || call->target != logBranch || call->operands.size() != 3) { + return std::nullopt; + } + auto* c = call->operands[2]->dynCast(); + if (!c || c->type != Type::i32) { + return std::nullopt; + } + return c->value.geti32(); + } + template void processCondition(T* curr) { - if (auto* call = getInstrumentation(curr->condition)) { - if (auto* c = call->operands[2]->template dynCast()) { - auto id = c->value.geti32(); - if (idsToDelete.contains(id)) { - // Remove the branch hint. - getFunction()->codeAnnotations[curr].branchLikely = {}; - } - } + if (auto id = getBranchID(curr->condition, getPassOptions(), *getModule()); + id && idsToDelete.contains(*id)) { + // Remove the branch hint. + getFunction()->codeAnnotations[curr].branchLikely = std::nullopt; } } diff --git a/test/lit/passes/delete-branch-hints.wast b/test/lit/passes/delete-branch-hints.wast index bbc59604a53..c31220fc246 100644 --- a/test/lit/passes/delete-branch-hints.wast +++ b/test/lit/passes/delete-branch-hints.wast @@ -7,6 +7,8 @@ ;; CHECK: (type $1 (func (param i32 i32 i32) (result i32))) + ;; CHECK: (type $2 (func (param i32) (result i32))) + ;; CHECK: (import "fuzzing-support" "log-branch" (func $log (type $1) (param i32 i32 i32) (result i32))) (import "fuzzing-support" "log-branch" (func $log (param i32 i32 i32) (result i32))) @@ -118,4 +120,41 @@ ) ) ) + + ;; CHECK: (func $stacky (type $2) (param $c i32) (result i32) + ;; CHECK-NEXT: (block $l (result i32) + ;; CHECK-NEXT: (br_if $l + ;; CHECK-NEXT: (i32.const 42) + ;; CHECK-NEXT: (block (result i32) + ;; CHECK-NEXT: (nop) + ;; CHECK-NEXT: (call $log + ;; CHECK-NEXT: (local.get $c) + ;; CHECK-NEXT: (i32.const 1) + ;; CHECK-NEXT: (i32.const 10) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $stacky (param $c i32) (result i32) + block $l (result i32) + i32.const 42 + ;; Because the parser greedily pulls previous none-typed expressions into + ;; block, this condition will be parsed as this: + ;; + ;; (block + ;; (nop) + ;; (call $log-branch ...)) + ;; ) + ;; + ;; We must be able to find and handle this pattern to remove the hint. + nop + local.get $c + i32.const 1 + i32.const 10 + call $log + (@metadata.code.branch_hint "\01") + br_if $l + end + ) )