Skip to content

Fix local_sqrt_sqr rewrite logic bug#1952

Open
WHOIM1205 wants to merge 3 commits intopymc-devs:mainfrom
WHOIM1205:fix-local-sqrt-sqr-math-rewrite
Open

Fix local_sqrt_sqr rewrite logic bug#1952
WHOIM1205 wants to merge 3 commits intopymc-devs:mainfrom
WHOIM1205:fix-local-sqrt-sqr-math-rewrite

Conversation

@WHOIM1205
Copy link
Contributor

Fix swapped conditions in local_sqrt_sqr rewrite

The rewrite rule local_sqrt_sqr had the conditions for sqrt(sqr(x))
and sqr(sqrt(x)) reversed. Since prev_op represents the inner
operation and node_op the outer operation, the isinstance checks
were matching the wrong patterns.

Because of this, sqrt(sqr(x)) was rewritten to
switch(x >= 0, x, nan) instead of abs(x), which caused negative
inputs to return NaN. This silently breaks the mathematical identity
sqrt(x^2) = |x| and produces incorrect gradients.

This PR swaps the two isinstance checks so the rewrites match the
correct patterns:

  • sqrt(sqr(x)) -> abs(x)
  • sqr(sqrt(x)) -> switch(x >= 0, x, nan)

Tests were also updated to reflect the correct behavior and now include
numerical checks with negative inputs.

@WHOIM1205
Copy link
Contributor Author

pre-commit.ci autofix

@WHOIM1205
Copy link
Contributor Author

heyy @ricardoV94 is there anything i can improve in this pr

@ricardoV94 ricardoV94 force-pushed the fix-local-sqrt-sqr-math-rewrite branch from 0caa51e to 2e4175f Compare March 12, 2026 12:17
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rebased and removed a redundant test.

Thanks for the bugfix

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants