Fix incorrect gradient in Solve for structured assume_a (sym/pos/her)#1887
WHOIM1205 wants to merge 4 commits into pymc-devs:main
Conversation
Signed-off-by: WHOIM1205 <rathourprateek8@gmail.com>
Is this related to #1230?
This PR is not directly related to #1230.
I think the question was whether it's the same nature of issue.
tests/tensor/test_slinalg.py
analytic = f(A_val, b_val)
# Numerical gradient: perturb only the read triangle
Can we concoct a graph that allows us to use verify_grad instead and still works as a regression test?
Something whose input is just the triangular entries? I'm assuming they were being half counted?
You can still verify they came out as zeros on an explicit grad fn.
Thanks for the suggestion, I've updated the test accordingly.
The manual finite-difference loop has been replaced with utt.verify_grad using a triangular parameterization of the structured entries; the symmetric matrix is reconstructed inside the graph, and I've also added an explicit assertion that the unread triangle has zero gradients.
All parametrized cases are passing locally.
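For readers following along, here is a minimal sketch of the kind of graph this describes; the helper name, shapes, values, and the `utt` import path are illustrative assumptions, and the actual test in tests/tensor/test_slinalg.py may differ:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from tests import unittest_tools as utt  # assumed import path for PyTensor's test helpers

rng = np.random.default_rng(42)
n = 4

def solve_from_lower_triangle(L, b):
    # Only the lower triangle carries free parameters; assume_a="sym" with lower=True
    # makes the solver symmetrize it internally, so verify_grad's finite differences
    # pick up both the (i, j) and (j, i) contributions of each perturbed entry.
    return pt.linalg.solve(pt.tril(L), b, assume_a="sym", lower=True)

L_val = np.tril(rng.normal(size=(n, n))) + 2 * n * np.eye(n)  # well-conditioned lower triangle
b_val = rng.normal(size=(n, 1))
utt.verify_grad(solve_from_lower_triangle, [L_val, b_val], rng=rng)

# The unread (upper) triangle should receive exactly zero gradient.
A = pt.dmatrix("A")
b = pt.dmatrix("b")
g = pytensor.grad(pt.linalg.solve(A, b, assume_a="sym", lower=True).sum(), A)
g_val = g.eval({A: L_val, b: b_val})
assert np.allclose(np.triu(g_val, 1), 0.0)
```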
I suppose this is a mistake comment from some other work.
Ah, thanks for clarifying.
Apologies, that was clearly pasted from another PR by mistake.
Replace manual finite-difference loop with verify_grad using triangular parameterization. Add explicit zero-gradient assertion for unread triangle. Signed-off-by: WHOIM1205 <rathourprateek8@gmail.com>
pre-commit.ci autofix
pytensor/tensor/slinalg.py
# triangle contribute to both (i,j) and (j,i) of the effective matrix,
# so we must accumulate the symmetric contribution and zero the unread triangle.
if self.lower:
    res[0] = ptb.tril(A_bar) + ptb.tril(A_bar.mT, -1)
Suggested change:
- res[0] = ptb.tril(A_bar) + ptb.tril(A_bar.mT, -1)
+ res[0] = ptb.tril(A_bar) + ptb.tril(A_bar.conj().mT, -1)
Otherwise the Hermitian case is wrong.
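A quick NumPy illustration (the values here are made up) of why the conj() only shows up for the Hermitian case: the two foldings coincide for real-valued cotangents and diverge once the cotangent is complex.

```python
import numpy as np

rng = np.random.default_rng(0)
G_real = rng.normal(size=(3, 3))                # a real cotangent
G_cplx = G_real + 1j * rng.normal(size=(3, 3))  # a complex cotangent

def fold(G, conjugate):
    # Fold the unread triangle's contribution back into the read (lower) triangle.
    Gt = G.conj().T if conjugate else G.T
    return np.tril(G) + np.tril(Gt, -1)

print(np.allclose(fold(G_real, False), fold(G_real, True)))  # True: identical for real G
print(np.allclose(fold(G_cplx, False), fold(G_cplx, True)))  # False: the conj() matters
```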
I think it would be good to have a simple closed-form test. Given:

$$A = \begin{bmatrix} 1 & \rho \\ \rho & 1 \end{bmatrix}$$

The inverse is analytically computable:

$$A^{-1} = \frac{1}{1 - \rho^2} \begin{bmatrix} 1 & -\rho \\ -\rho & 1 \end{bmatrix}$$

Given

$$b = \begin{bmatrix} 1 \\ 0 \end{bmatrix}, \qquad x = A^{-1} b = \frac{1}{1 - \rho^2} \begin{bmatrix} 1 \\ -\rho \end{bmatrix}$$

And the gradient with respect to $\rho$ of the loss $x_2$:

$$\frac{\partial x_2}{\partial \rho} = -\frac{1 + \rho^2}{(1 - \rho^2)^2}$$

Checking pytensor:

```python
import pytensor
import pytensor.tensor as pt

rho = pt.dscalar('rho')
A = pt.stacklists([[1., 0.], [rho, 1.]])
b = pt.stacklists([[1.], [0]])
x = pt.linalg.solve(A, b, assume_a='pos')
loss = x[1, 0]
dL_dx = pt.grad(loss, rho)
fn = pytensor.function([rho], dL_dx)

fn(0.88)  # array(-19.64815653)

# Analytical:
def expected_grad(rho):
    return -(1 + rho ** 2) / (1 - rho ** 2) ** 2

expected_grad(0.88)  # -34.86368894924802
```

Note though that this only occurs because we passed a non-PSD matrix to solve. It is a numerical quirk that the underlying algorithm is able to treat this matrix as PSD, notwithstanding that it is not. If you instead pass
I've been thinking about this PR more. Do you have a specific use-case that motivated you to work on this? My current view is that our current behavior is not a bug, because:
I am trying to think of a case where you would want to exploit the LAPACK behavior to get some kind of benefit. Simply passing in a lower-triangular matrix provides no memory benefit, because you still have to allocate the zeros. You would need a situation like LU, where you have two matrices packed into a single matrix. I have no idea in what context this would naturally arise.
I disagree with this. SciPy clearly says the other part is ignored, and that's why the lower flag exists in solve. Why do it? I think it's not so much memory savings as avoiding the work of writing the other half of the same values. I'm fine with the gradient pretending these existed, but we should perhaps document it and even show the recipe with tril to combine the contributions? If this was actually a problem, @WHOIM1205 (or anybody else), we can think of adding a special flag to control the behavior, but I wouldn't go there until we know the need exists.
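For illustration, a sketch of the tril recipe alluded to above (nothing here is existing documentation; L and b are placeholder names): keep the free parameters in one triangle and let the chain rule combine both contributions.

```python
import pytensor.tensor as pt

L = pt.dmatrix("L")  # only the lower triangle of L holds free parameters
b = pt.dmatrix("b")

# Explicit symmetric completion of the lower triangle inside the graph:
A = pt.tril(L) + pt.tril(L, -1).T
x = pt.linalg.solve(A, b, assume_a="sym")
# Any gradient taken through x with respect to L now automatically accumulates
# both the (i, j) and (j, i) contributions of each off-diagonal entry.
```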
Correctly handle the chain rule through the symmetric completion when assume_a is sym/her/pos. The solver reads only one triangle and symmetrizes internally, so off-diagonal gradients from the unread triangle must be folded back into the read triangle. For Hermitian matrices, the folding uses conjugate transpose since A_ji = conj(A_ij). For symmetric/pos-def, plain transpose is used. Signed-off-by: WHOIM1205 <rathourprateek8@gmail.com>
Force-pushed from 5aa8eda to 7d04a40
pre-commit.ci autofix
for more information, see https://pre-commit.ci
@ricardoV94, @jessegrabowski The folding of the gradient does not depend on LAPACK internals; it follows from the documented behavior of SciPy when assume_a is structured: only one triangle of A is read, and the solver treats the matrix as symmetric or Hermitian based on that triangle. Mathematically, the forward operator effectively solves with the symmetric (or Hermitian) completion of the read triangle. Because of that, the gradient needs to apply the chain rule through this completion mapping. Off-diagonal entries in the read triangle control both (i, j) and (j, i) of the effective matrix, while entries in the unread triangle are not free parameters. The folding simply accounts for that dependency.

I have also updated the Hermitian case to use the conjugate transpose when folding, since A_ji = conj(A_ij) in that setting.

For assume_a="pos", I agree that behavior on non-PSD inputs is undefined at the forward level, so the gradient is only meaningful for valid PSD matrices. The change is intended to make the gradient consistent with the structured forward semantics for valid inputs. I'll also extend the tests to explicitly cover the Hermitian case with complex inputs so that the conjugation logic is verified.
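In symbols, a sketch of the same completion/adjoint argument for the lower-triangle case (notation is mine: $L$ holds the read triangle, $\bar{A}$ is the cotangent of the effective matrix): the completion applied by the structured solver and the folding used for the gradient are adjoint maps,

$$
S(L) = \operatorname{tril}(L) + \operatorname{tril}(L, -1)^{\mathsf T},
\qquad
S^{*}(\bar{A}) = \operatorname{tril}(\bar{A}) + \operatorname{tril}(\bar{A}^{\mathsf T}, -1),
$$

with the plain transpose replaced by the conjugate transpose in the Hermitian case.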
This is a highly technical detail, and non-obvious behavior. Again, this "bug" will never be triggered unless users specifically pass oddly-structured inputs to the function. I await a practical situation where we should pay the extra cost in our gradients. Our gradients are, I remind everyone, slower than JAX's, and JAX does not apply this "fix".
Thanks for the detailed feedback; I understand the concern. My intention wasn't to optimize for oddly structured inputs, but to make the gradient reflect the effective forward semantics when assume_a is structured and only one triangle is actually read. That said, I agree this is a fairly subtle case and may not come up often in practice.
@ricardoV94, is there anything I can change in this PR?
Fix gradient handling in `Solve` for structured `assume_a` cases

Summary

`SolveBase.L_op` computes gradients assuming all entries of `A` are independent. This is correct for `assume_a="gen"`. However, when using structured assumptions (`"sym"`, `"her"`, `"pos"`), the solver only reads one triangle of `A`. The backward pass did not account for this, resulting in incorrect gradients when a pre-structured matrix was passed directly into `pt.linalg.solve`. Existing tests did not catch this because they wrapped the input matrix with a symmetrization transform, which masked the issue via the chain rule.
Fix
Updated `pytensor/tensor/slinalg.py`:
Before
Inherited from SolveBase
A_bar = -outer(b_bar, c)