
[PyTorch] [torch.compile] Remove module reference from autograd function args#2791

Open
pggPL wants to merge 11 commits into NVIDIA:main from pggPL:remove_module_from_autograd_args

Conversation

@pggPL
Collaborator

@pggPL pggPL commented Mar 23, 2026

The torch.autograd.Function subclasses in TE modules take a module argument, which is used for the weight cache.
This does not work with torch.compile. This PR changes that: the workspace tensors are passed to and returned from the autograd function directly, and the cache update happens outside torch.autograd.Function.
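The change can be illustrated with a minimal sketch (hypothetical names; the real TE modules are far more involved). Instead of stashing `self` inside the autograd function, the workspace travels in and out as a plain tensor argument, and the module updates its own `_fp8_workspaces` cache after `apply` returns:

```python
import torch

class _LinearSketch(torch.autograd.Function):
    """Hypothetical stand-in for a TE autograd function.

    The workspace is a plain tensor input/output, so torch.compile can
    trace it; no nn.Module reference enters the graph.
    """

    @staticmethod
    def forward(ctx, weight, weight_workspace, inp):
        if weight_workspace is None:
            # Cache miss: build a fresh workspace and return it to the caller.
            new_workspace = weight.detach().clone()
        else:
            # Cache hit: refresh the existing workspace in place, return None.
            weight_workspace.copy_(weight.detach())
            new_workspace = None
        ctx.save_for_backward(weight, inp)
        out = inp @ weight.t()
        return out, new_workspace

    @staticmethod
    def backward(ctx, grad_out, _grad_workspace):
        weight, inp = ctx.saved_tensors
        # One gradient slot per forward input; the workspace slot gets None.
        return grad_out.t() @ inp, None, grad_out @ weight


class LinearModuleSketch(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        self._fp8_workspaces = {}  # cache lives on the module, outside autograd

    def forward(self, inp):
        workspace = self._fp8_workspaces.get("weight")
        out, new_workspace = _LinearSketch.apply(self.weight, workspace, inp)
        if new_workspace is not None:
            # Cache update happens here, after the autograd call returns.
            self._fp8_workspaces["weight"] = new_workspace.detach()
        return out
```

Because the only non-tensor state left is the cache dict on the module, the traced graph sees tensors in and tensors out.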

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL and others added 4 commits March 23, 2026 14:55
Extract weight quantization into standalone `quantize_weight()` function
in base.py, eliminating the need to pass `self` (nn.Module) into
autograd functions. Each op's autograd function now receives/returns
Optional[Tensor] weight workspaces instead, with cache management
handled by the nn.Module before/after the autograd call.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
…autograd_args

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor

# Conflicts:
#	transformer_engine/pytorch/module/base.py
No callers remain after the quantize_weight refactor.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
@greptile-apps
Contributor

greptile-apps bot commented Mar 23, 2026

Greptile Summary

This PR refactors the torch.autograd.Function subclasses in all TE linear modules to remove the module reference from non_tensor_args. Weight workspaces are now passed explicitly as tensor arguments and returned as additional output tensors; the _fp8_workspaces cache is updated outside the autograd function. A new quantize_weight helper in base.py centralises the cache-hit / cache-miss logic. The fp8_meta access in _LayerNormMLP._recompute is also corrected from ctx.other_args["module"].fp8_meta to ctx.other_args["fp8_meta"].
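The cache-hit / cache-miss contract of the new helper can be sketched as follows (hypothetical name and shapes; real quantization is replaced by a float16 cast for illustration): on a hit the existing workspace is refreshed in place and None is returned as the new workspace, on a miss a fresh workspace is built and returned for the module to store.

```python
import torch

def quantize_weight_sketch(weight, workspace, cache=True):
    """Hypothetical sketch of the cache-hit / cache-miss contract.

    Returns (workspace_to_use, new_workspace):
      * cache hit  -> (existing workspace, updated in place; None)
      * cache miss -> (fresh workspace, same fresh workspace)
    """
    if cache and workspace is not None and workspace.shape == weight.shape:
        # Hit: refresh the cached workspace in place; nothing new to store.
        workspace.copy_(weight.detach().to(workspace.dtype))
        return workspace, None
    # Miss: "quantize" (mocked as a cast) into a new workspace tensor.
    new_workspace = weight.detach().to(torch.float16)
    return new_workspace, new_workspace
```

The second return value is exactly what the autograd function forwards to the module as its extra output, so the module only writes to `_fp8_workspaces` when it is not None.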

Confidence Score: 5/5

Safe to merge; refactoring is mechanically correct across all four modules with no observable behavioural change.

All backward return-value counts match updated forward input counts. The recompute path in _LayerNormMLP correctly receives None workspaces and cache_weight=False. The fp8_meta reference fix prevents a crash under delayed-scaling + gradient checkpointing. No P0/P1 issues found; remaining observations are P2.

grouped_linear.py: returns new workspaces as a Python list rather than individual tensor outputs, which may warrant a follow-up for full torch.compile parity with the other modules.

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/base.py Adds _is_weight_workspace_valid and quantize_weight helpers; removes get_weight_workspace from TransformerEngineBaseModule. Logic is well-documented and correctly handles cache-hit, cache-miss, FSDP gather, and DebugQuantizer paths.
transformer_engine/pytorch/module/linear.py Adds weight_workspace as a second forward input; returns (out, new_weight_workspace). Backward correctly returns None for the workspace gradient slot. Module-level cache update moved outside the autograd function.
transformer_engine/pytorch/module/layernorm_linear.py Same workspace-as-tensor-arg pattern as _Linear; backward return tuple extended with None for weight_workspace; input/output counts match.
transformer_engine/pytorch/module/layernorm_mlp.py Two workspace tensor args added (fc1/fc2). _forward now always returns a fixed 4-tuple, avoiding the previous variable-arity return. _recompute correctly unpacks 7 saved tensors and passes None workspaces during recomputation. Critical fix: ctx.other_args["fp8_meta"] replaces the removed ctx.other_args["module"].fp8_meta.
transformer_engine/pytorch/module/grouped_linear.py Moves from in-place list mutation inside non_tensor_args to returning new_workspaces as a second output. Workspaces are still returned as a Python list rather than individual tensors; this differs from the other three modules and may have different torch.compile traceability.

Sequence Diagram

sequenceDiagram
    participant M as Module.forward()
    participant WS as _fp8_workspaces
    participant AF as AutogradFunction.forward()
    participant QW as quantize_weight()

    M->>WS: get(cache_name)
    WS-->>M: weight_workspace (or None)
    M->>AF: apply(weight, weight_workspace, inp, ...)
    AF->>QW: quantize_weight(tensor, quantizer, workspace, cache=True)
    alt Cache HIT
        QW-->>AF: (workspace_updated_inplace, None)
    else Cache MISS
        QW-->>AF: (new_workspace, new_workspace)
    end
    AF-->>M: (out, new_weight_workspace)
    alt new_weight_workspace is not None
        M->>M: new_weight_workspace.detach()
        M->>WS: store(cache_name, new_weight_workspace)
    end

Reviews (6). Last reviewed commit: "grouped linear fix"

pggPL and others added 5 commits March 23, 2026 15:31
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL
Collaborator Author

pggPL commented Mar 30, 2026

/te-ci pytorch

@pggPL
Collaborator Author

pggPL commented Apr 10, 2026

/te-ci pytorch

@greptile-apps
Contributor

greptile-apps bot commented Apr 10, 2026


Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL
Collaborator Author

pggPL commented Apr 10, 2026

/te-ci pytorch

