Skip to content

Comments

Fix incorrect gradient in Solve for structured assume_a (sym/pos/her)#1887

Open
WHOIM1205 wants to merge 4 commits intopymc-devs:mainfrom
WHOIM1205:fix-solve-symmetric-gradient
Open

Fix incorrect gradient in Solve for structured assume_a (sym/pos/her)#1887
WHOIM1205 wants to merge 4 commits intopymc-devs:mainfrom
WHOIM1205:fix-solve-symmetric-gradient

Conversation

@WHOIM1205
Copy link
Contributor

Fix gradient handling in Solve for structured assume_a cases


Summary

SolveBase.L_op computes gradients assuming all entries of A are independent. This is correct for assume_a="gen".

However, when using structured assumptions ("sym", "her", "pos"), the solver only reads one triangle of A. The backward pass did not account for this, resulting in incorrect gradients when a pre-structured matrix was passed directly into pt.linalg.solve.

Existing tests did not catch this because they wrapped the input matrix with a symmetrization transform, which masked the issue via the chain rule.


Fix

Updated pytensor/tensor/slinalg.py:

python

Before

Inherited from SolveBase

A_bar = -outer(b_bar, c)

  • all test cases are passed locally
image

Signed-off-by: WHOIM1205 <rathourprateek8@gmail.com>
@jessegrabowski
Copy link
Member

Is this related to #1230 ?

@WHOIM1205
Copy link
Contributor Author

Is this related to #1230 ?

This PR is not directly related to #1230.
#1230 concerns solve_triangular gradients (specifically unit_diag and trans handling), whereas this pr addresses gradient handling in Solve when assume_a is structured (sym/pos/her).
They affect different operators and different gradient paths.

@ricardoV94
Copy link
Member

I think the question was whether it's the same nature of issue


analytic = f(A_val, b_val)

# Numerical gradient: perturb only the read triangle
Copy link
Member

@ricardoV94 ricardoV94 Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we concot a graph that allows us to use verify_grad instead and still works as regression test?

Something whose input is just the triangular entries? I'm assuming they were being half counted?

You can still verify they came out as zeros on an explicit grad fn

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we concot a graph that allows us to use verify_grad instead and still works as regression test?

Something whose input is just the triangular entries? I'm assuming they were being half counted?

You can still verify they came out as zeros on an explicit grad fn

Thanks for the suggestion i've updated the test accordingly

The manual finite-difference loop has been replaced with utt.verify_grad using a triangular parameterization of the structured entries the symmetric matrix is reconstructed inside the graph and i’ve also added an explicit assertion that the unread triangle has zero gradients
all parametrized cases are passing locally

@ricardoV94
Copy link
Member

hey @ricardoV94
This fixes an error-path imbalance in post_open_standalone() where mutex_ghost was not released on connect() failure. Since the mutex is shared across restore tasks, this could lead to a cross-process deadlock during Unix socket restore. The change is limited to the failure path and does not affect the success flow.

I suppose this is a mistake comment from some other work

@WHOIM1205
Copy link
Contributor Author

WHOIM1205 commented Feb 12, 2026

I think the question was whether it's the same nature of issue

Ah thanks for clarifying
Yes it’s similar in nature in the sense that both issues stem from the gradient not fully respecting structural assumptions made in the forward solve
However they affect different operators and different logic paths:
#1230 concerns solve_triangular (specifically unit_diag and trans handling)
This PR addresses Solve when assume_a is structured (sym/pos/her)
So conceptually similar (structure-aware gradient handling) but technically independent fixes

@WHOIM1205
Copy link
Contributor Author

hey @ricardoV94
This fixes an error-path imbalance in post_open_standalone() where mutex_ghost was not released on connect() failure. Since the mutex is shared across restore tasks, this could lead to a cross-process deadlock during Unix socket restore. The change is limited to the failure path and does not affect the success flow.

I suppose this is a mistake comment from some other work

Apologies that was clearly pasted from another PR by mistake.
Please ignore that comment.
This PR only concerns the gradient handling in Solve for structured assume_a cases

Replace manual finite-difference loop with verify_grad using triangular
parameterization. Add explicit zero-gradient assertion for unread triangle.

Signed-off-by: WHOIM1205 <rathourprateek8@gmail.com>
@WHOIM1205
Copy link
Contributor Author

pre-commit.ci autofix

# triangle contribute to both (i,j) and (j,i) of the effective matrix,
# so we must accumulate the symmetric contribution and zero the unread triangle.
if self.lower:
res[0] = ptb.tril(A_bar) + ptb.tril(A_bar.mT, -1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
res[0] = ptb.tril(A_bar) + ptb.tril(A_bar.mT, -1)
res[0] = ptb.tril(A_bar) + ptb.tril(A_bar.conj().mT, -1)

Otherwise the hermetian case is wrong

@jessegrabowski
Copy link
Member

I think it would be good to have a simple closed-form test. Given:

$$ A = \begin{bmatrix} 1 & \rho \\ \rho & 1 \end{bmatrix} $$

The inverse is analytically computable:

$$ A^{-1} = \frac{1}{1 - \rho^2} \begin{bmatrix} 1 & -\rho \\ -\rho & 1 \end{bmatrix} $$

Given loss = solve(A, b)[1, 0] with b = np.array([[1.], [0.]]), then the loss is:

$$ \mathcal{L} = \frac{\rho}{\rho^2 - 1} $$

And the gradient with respect to $\rho$ is:

$$ \frac{\partial \mathcal{L}}{\partial \rho} = -\frac{1 + \rho^2}{(1 - \rho^2)^2} $$

Checking pytensor:

import pytensor 
import pytensor.tensor as pt

rho = pt.dscalar('rho')
A = pt.stacklists([[1., 0.], [rho, 1.]])
b = pt.stacklists(([[1.], [0]]))
x = pt.linalg.solve(A, b, assume_a='pos')
loss = x[1, 0]

dL_dx = pt.grad(loss, rho)
fn = pytensor.function([rho], dL_dx)
fn(0.88) # array(-19.64815653)

# Analytical:
def expected_grad(rho):
    return -(1 + rho ** 2) / (1 - rho ** 2) ** 2
expected_grad(0.88) # -34.86368894924802

Note though that this only occurs because we passed a non-PSD matrix to solve. It is a numerical quirk that the underlying algorithm is able to treat this matrix as PSD, notwithstanding that it is not. If you instead pass A = pt.stacklists([[1., rho], [rho, 1.]]), you will get the right answer.

@jessegrabowski
Copy link
Member

I've been thinking about this PR more. Do you have a specific use-case that motivated you to work on this? My current view is that our current behavior is not a bug, because:

  • To trigger the issue, you need to "lie" to solve. A lower-triangular matrix is not PSD, you are relying on (undocumented) LAPACK behavior
  • The solution requires at least one more copy of the input matrix to be made (from the tril op)
  • Checking reference implementations (like JAX), our behavior matches theirs.

I am trying to think of a case where you would want to exploit the LAPACK behavior to get some kind of benefit. Simply passing in a lower-triangular matrix is no memory benefit, because you still have to allocate the zeros. You would need a situation like LU, where you have two matrices packed into a single matrix. I have no idea in what context this would naturally arise.

@ricardoV94
Copy link
Member

you are relying on (undocumented) LAPACK behavior

I disagree with this. Scipy clearly says the other part is ignored, and that's why the lower flag exists in solve.

Why do it? I think it's not so much memory savings but avoiding the work of writing the other half of the same values.

I'm fine with the gradient pretending these existed but we should perhaps document it and even show the recipe with tril to combine the contributions?

If this was actually a problem @WHOIM1205 (or anybody else) we can think of adding a special flag to control the behavior but I wouldn't go there until we know the need exists.

Correctly handle the chain rule through the symmetric completion
when assume_a is sym/her/pos. The solver reads only one triangle
and symmetrizes internally, so off-diagonal gradients from the
unread triangle must be folded back into the read triangle.

For Hermitian matrices, the folding uses conjugate transpose
since A_ji = conj(A_ij). For symmetric/pos-def, plain transpose
is used.

Signed-off-by: WHOIM1205 <rathourprateek8@gmail.com>
@WHOIM1205 WHOIM1205 force-pushed the fix-solve-symmetric-gradient branch from 5aa8eda to 7d04a40 Compare February 15, 2026 22:44
@WHOIM1205
Copy link
Contributor Author

pre-commit.ci autofix

@WHOIM1205
Copy link
Contributor Author

@ricardoV94 , @jessegrabowski
Thanks for the thoughtful discussion i went back and carefully re-derived the gradient to make sure the reasoning is sound

The folding of the gradient does not depend on LAPACK internals it follows from the documented behavior of SciPy when assume_a is structured: only one triangle of A is read, and the solver treats the matrix as symmetric or Hermitian based on that triangle. Mathematically, the forward operator is effectively solving with the symmetric (or Hermitian) completion of the read triangle.

Because of that, the gradient needs to apply the chain rule through this completion mapping. Off-diagonal entries in the read triangle control both (i, j) and (j, i) of the effective matrix, while entries in the unread triangle are not free parameters. The folding simply accounts for that dependency.

I have also updated the Hermitian case to use the conjugate transpose when folding, since A_ji = conj(A_ij) in that setting.

For assume_a="pos", I agree that behavior on non-PSD inputs is undefined at the forward level, so the gradient is only meaningful for valid PSD matrices. The change is intended to make the gradient consistent with the structured forward semantics for valid inputs.

I’ll also extend the tests to explicitly cover the Hermitian case with complex inputs so that the conjugation logic is verified.

@jessegrabowski
Copy link
Member

I disagree with this. Scipy clearly says the other part is ignored, and that's why the lower flag exists in solve.

This is a highly technical detail, and non-obvious behavior. Again, this "bug" will never be triggered unless users specifically pass oddly-structured inputs to the function. I await a practical situation where we should pay the extra cost in our gradients. Our gradients which, I remind everyone, are slower than JAX's. And JAX does not apply this "fix".

@WHOIM1205
Copy link
Contributor Author

I disagree with this. Scipy clearly says the other part is ignored, and that's why the lower flag exists in solve.

This is a highly technical detail, and non-obvious behavior. Again, this "bug" will never be triggered unless users specifically pass oddly-structured inputs to the function. I await a practical situation where we should pay the extra cost in our gradients. Our gradients which, I remind everyone, are slower than JAX's. And JAX does not apply this "fix".

Thanks for the detailed feedback I understand the concern.

My intention wasn’t to optimize for oddly structured inputs, but to make the gradient reflect the effective forward semantics when assume_a is structured and only one triangle is actually read. That said i agree this is a fairly subtle case and may not come up often in practice.
If the extra gradient cost isn’t justified without a clear real-world need i’m fine narrowing the scope. At minimum i think the Hermitian case should use the conjugate transpose for correctness. For the symmetric and positive-definite cases
i’m happy to either revert that part or revisit it later if a concrete use case shows up.
Let me know which direction you’d prefer and I’ll adjust the PR accordingly.

@WHOIM1205
Copy link
Contributor Author

@ricardoV94 is there anything i can change in this pr

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants