[PyTorch] [torch.compile] Remove module reference from autograd function args #2791
pggPL wants to merge 11 commits into NVIDIA:main
Conversation
Extract weight quantization into a standalone `quantize_weight()` function in base.py, eliminating the need to pass `self` (an `nn.Module`) into autograd functions. Each op's autograd function now receives and returns `Optional[Tensor]` weight workspaces instead, with cache management handled by the `nn.Module` before/after the autograd call. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
for more information, see https://pre-commit.ci
No callers remain after the quantize_weight refactor. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
Greptile Summary: This PR refactors the weight quantization cache out of autograd function arguments. Confidence Score: 5/5. Safe to merge; the refactoring is mechanically correct across all four modules with no observable behavioural change. All backward return-value counts match the updated forward input counts. The recompute path in grouped_linear.py returns new workspaces as a Python list rather than individual tensor outputs, which may warrant a follow-up for full …
Sequence Diagram

```mermaid
sequenceDiagram
    participant M as Module.forward()
    participant WS as _fp8_workspaces
    participant AF as AutogradFunction.forward()
    participant QW as quantize_weight()
    M->>WS: get(cache_name)
    WS-->>M: weight_workspace (or None)
    M->>AF: apply(weight, weight_workspace, inp, ...)
    AF->>QW: quantize_weight(tensor, quantizer, workspace, cache=True)
    alt Cache HIT
        QW-->>AF: (workspace_updated_inplace, None)
    else Cache MISS
        QW-->>AF: (new_workspace, new_workspace)
    end
    AF-->>M: (out, new_weight_workspace)
    alt new_weight_workspace is not None
        M->>M: new_weight_workspace.detach()
        M->>WS: store(cache_name, new_weight_workspace)
    end
```
Reviews (6): Last reviewed commit: "grouped linear fix"
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
/te-ci pytorch

/te-ci pytorch
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
/te-ci pytorch
The `torch.autograd.Function`s in TE modules have a `module` argument, which is used for the weight cache. This does not work with `torch.compile`. This PR changes that: tensors are passed to and returned from the operator directly, and the cache update happens outside the `torch.autograd.Function`.

Type of change
Checklist: