Add MXFP8 attention #2719
Open — cyanguwa wants to merge 197 commits into NVIDIA:main from cyanguwa:add_mxfp8 (base: main)

Commits (197, showing changes from all commits)
e0ae107
initial implementation for mxfp8
cyanguwa 23434b5
semi-working FP8; broken F16
cyanguwa dbb68b8
clean up last commit
cyanguwa c627231
comment out F16 pass
cyanguwa d27a267
Merge branch 'NVIDIA:main' into mxfp8_fwd
cyanguwa 3f3b9e6
pull in grouped_quantize for MXFP8
cyanguwa 850b16e
grouped tensor - pytorch
cyanguwa 46f2eb1
quantize mxfp8
cyanguwa e86207c
fix shapes/strides
cyanguwa 4e854d5
fix unfused; clean up
cyanguwa cd06398
split d to d_qk/d_v; attempt at bwd
cyanguwa d2a63a1
merge main
cyanguwa 730a472
fix last merge
cyanguwa d9ff566
update FE
cyanguwa 2b264d7
attempt at SWA/MLA
cyanguwa 2008bed
remove prints
cyanguwa 239f58a
remove leftover prints
cyanguwa f44a775
Revert "update FE"
cyanguwa 965572b
update FE
cyanguwa 91025c7
fix MLA O strides; add bottom_right_diagonal
cyanguwa d655e7e
attempt at bwd
cyanguwa a4ab691
fix get_quantizers; attempt at bwd
cyanguwa a85070d
fix fprop; add o_format
cyanguwa 8909b35
attempt at bwd with o_format/d_out_format/dqkv_layout
cyanguwa 90a636c
fix dtype/o_format/etc in bwd calls
cyanguwa 8c72dea
fix generateMatrixStridesWithFormats and _v1; fix padding for mxfp8
cyanguwa 5f23edd
fix upon last commit for paddedsizes
cyanguwa 18c5580
add mxfp8 env var
cyanguwa 6847645
disable FA for mxfp8
cyanguwa c5a98d5
add mha test
cyanguwa 7e61ecd
attempt at bwd; force determinism; fix shapes
cyanguwa 6d468da
remove prints
cyanguwa 9f8e856
update FE
cyanguwa facef79
update FE from pre-merge branch to post-merge develop
cyanguwa fd33cca
allow MXFP8 linear + f16 attn
cyanguwa 5079d55
test cp a2a
cyanguwa 06b7d49
remove prints temporarily
cyanguwa 7fbe399
test cp p2p
cyanguwa aa05a2a
minor fixes for mla
cyanguwa 00e6693
open up a2a for mla
cyanguwa b8d28ce
test ag
cyanguwa d6ecadc
tweaks for last commit
cyanguwa 3ac48cd
enable mla ag
cyanguwa 169ae8a
merge main
cyanguwa 5d4fa5e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 81c18fa
fix merge
cyanguwa 1f14f2f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] ccebe77
fix merge
cyanguwa c52c5f4
revert to main grouped tensor impl
cyanguwa 5b776ec
minor tweaks to return to main
cyanguwa 4eee2bc
remove prints
cyanguwa 8500121
fix combine_and_quantize for f16
cyanguwa 0c2c466
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 6744aee
minor tweaks
cyanguwa 4cec878
tweak tests
cyanguwa 5c8e939
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 7b6b364
fix ds descale_o
cyanguwa 462eb4f
Revert "fix ds descale_o"
cyanguwa 77995d2
minor fixes for p2p and ag
cyanguwa 586b698
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 1e7cd70
tweak cp test skips
cyanguwa 6d7766a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 6d33db8
update FE
cyanguwa 92e6aac
fix bwd KV tensors
cyanguwa 3cb6f0e
tweak recipe control and backend selection
cyanguwa c57ece4
tweak quantizer logic
cyanguwa 87a7e1e
minor fixes after last two commits
cyanguwa 3b015f3
improve generate strides
cyanguwa 6717e1a
minor fixes for previous commit
cyanguwa c918b9d
fix bwd for current/delayed
cyanguwa af60216
tweak test configs
cyanguwa 6ac41d2
fix dO/dO_f16 strides
cyanguwa 0a0722f
fix tests: SWA logic/test configs
cyanguwa 89b44f8
fix ag
cyanguwa 7c0ba7f
add fp8 sink attn
cyanguwa e68f785
fix a2a comm for F16
cyanguwa ae53980
remove nan/inf print in test
cyanguwa 4b314e7
fix fa a2a
cyanguwa 4b5d623
fix fa a2a+p2p f16
cyanguwa fdab7db
update FE to include new fixes
cyanguwa 39b57e9
fix thd for bwd
cyanguwa dc49479
refactor a2a for fu/fa
cyanguwa dea59e4
update FE to fix d64
cyanguwa 9da8ec9
refactor ag
cyanguwa a250b20
refactor p2p/a2a+p2p; mostly regarding shapes
cyanguwa 630545e
add shadow f16 fwd
cyanguwa a78ea9a
update FE to fix SWA/BRCM
cyanguwa 59eff74
switch to GH FE temporarily
cyanguwa 6472e66
merge main
cyanguwa 1691747
switch back to GL FE
cyanguwa d41eca3
update FE to latest commit
cyanguwa e0b65a5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] e51ec9f
update group tensor usage after merge main
cyanguwa 7bb40d5
env vars for qdq(q,k), o_f16 tests
cyanguwa 29c2f4b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] c10f05c
allow other recipes than mxfp8
cyanguwa 773c678
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 0ef408b
fix grouped tensor for MLA
cyanguwa 4429e58
change cp test configs
cyanguwa 08af36b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 4dd1418
add shadow f16 bwd
cyanguwa ad4d4da
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] f2266f4
fix a2a+p2p for sbhd
cyanguwa 1674b0f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 712d4f9
fix last commit and causal flag for fa
cyanguwa f9463e2
enable fp8 sink and disable fp8_mha
cyanguwa 299bc63
minor cleanup for cp/non-cp
cyanguwa ed62903
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 94ae209
update FE for FP8 sink
cyanguwa a9028b2
fix TE for FP8 sink
cyanguwa a6f56e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 706095f
temporary: random sink/print sink
cyanguwa 4c004ee
Revert "temporary: random sink/print sink"
cyanguwa e023d3b
replace d_out_format with do_format
cyanguwa 7577919
fix compare_and_assert for None cases
cyanguwa f0b1e2a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] ee388e5
remove logic for b and simplify logic for dqkv types
cyanguwa cacc59d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] de82fe1
minor fix for ndim_q/kv
cyanguwa 706012a
add explanation of fp8_output/grad in MHA
cyanguwa 746010e
tidy up FP8 checks for bhsd/learnable
cyanguwa 2283081
remove leading underscores in nvte_convert_qkv_format
cyanguwa e693e6f
simplify logic in generateMatrixStridesWithLayout
cyanguwa edf1b2a
clean up strides/ifelse-recipe logic
cyanguwa 09b21ee
tweak checks in utils.py
cyanguwa 49a54c0
tweak UnfusedDPA
cyanguwa e5d49d2
enable testing for ag+swa and disable fp8_mha
cyanguwa 2c63d83
tweak FusedAttn, fp8/f16 tensor naming/docstring
cyanguwa 7f62b98
replace d_out_format with do_format
cyanguwa 4b9240c
fix lint
cyanguwa 2a21a3a
clean up a2a
cyanguwa a18cd7c
clean up ag
cyanguwa a19ccb3
clean up p2p/a2a+p2p
cyanguwa 4ba2ef5
tweak test configs
cyanguwa 875931c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] f0bf680
qdq dO in bwd shadow f16 path
cyanguwa 2d80d38
tweak qdq dO logic
cyanguwa 0cf9738
remove prints in shadow paths
cyanguwa 813d39d
update FE to allow non-determinism
cyanguwa bdc0c47
fuse qkv transposes; first pass
cyanguwa e69a06a
remap parallelism to grid(bh, splits, 3) block(s/splits x d); use nve…
cyanguwa aab8856
allocate contiguous block for qkv
cyanguwa 78055e4
fix grouped tensor row/col scale_inv offsets
cyanguwa d8f9ac9
use fused permute kernels
cyanguwa ca53769
quantize row/col as needed in fwd/bwd, non-cp/cp
cyanguwa f19e852
Revert "quantize row/col as needed in fwd/bwd, non-cp/cp"
cyanguwa 2d403f9
Reapply "quantize row/col as needed in fwd/bwd, non-cp/cp"
cyanguwa f9e4e20
fix v_col format when row is quantized
cyanguwa fde366a
add back necessary bwd quants for shadow paths/cp a2a
cyanguwa 81f723d
remove ZInv for all layouts except T3HD
cyanguwa 89daa49
fix cp p2p with zinv
cyanguwa 60740fa
temporarily switch to GH FE main
cyanguwa 7fdf269
Merge branch 'main' into add_mxfp8
cyanguwa a7ff000
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] b0db79e
switch back to GL FE
cyanguwa f662a4a
fix ag after merge main
cyanguwa cbf6edd
add condition for qdq(do) to not affect other tests
cyanguwa 0642251
fix custom_mha_fp8 test
cyanguwa e6ffc6b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] fd9a750
fix amax dqkv
cyanguwa 4f2e4f4
fix fp8_recipe in DPA utils
cyanguwa 3869145
remove use of amax for mxfp8
cyanguwa 641c05c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 59db112
add o_format/do_format/dqkv_layout to cache indicators for fp8 and f16
cyanguwa f1d1809
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] c491908
enable sink attn + FP8 in CP
cyanguwa 6af3105
update FE to GH v1.22.0
cyanguwa 508044b
fix for inconsistent kwarg name in permute to grouped tensor
cyanguwa 2532a50
add TMA permute
cyanguwa d7c27f6
Revert "add TMA permute"
cyanguwa ba411a2
TMA load for bhsd transposes
cyanguwa 5ada28d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 4a47e4d
Merge branch 'main' into add_mxfp8
cyanguwa 6911aba
fix some lint
cyanguwa a27e30d
temp: quant+perm+swizzle, rope, perm_fused
cyanguwa e87102b
remove mla_rope for now; clean up quant+permute+pad_swizzle; create m…
cyanguwa e440291
fix last commit
cyanguwa 18af952
implement narrow-m for col swizzle; reorder to pad+perm+swizzle
cyanguwa 4da30d8
fused pad into perm; remove at::zeros as zeros done in perm kernels
cyanguwa 0853858
remove shadow code
cyanguwa 12a5687
minor fix for permute shapes
cyanguwa c268df3
check smem size before entering narrow-k/m kernels
cyanguwa 4fe6089
expand permute to multi_tensor_
cyanguwa cfe09e8
refactor qkv/do quant; create a fast_path call
cyanguwa 88155ea
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 3e517ea
Merge branch 'main' into add_mxfp8
cyanguwa 13ff957
fix lint
cyanguwa b26ddd9
Merge branch 'main' into add_mxfp8
cyanguwa 8b733ef
cleanup grouped tensor fix
cyanguwa 4047a1b
remove _with_amax for create_unquantized_tensor
cyanguwa 881b037
fix last commit
cyanguwa b13bfa5
reimplement inplace_multi_tensor_swizzle
cyanguwa d05fcb8
fix last commit; set swizzled flag in python
cyanguwa 182f4b5
remove permute_to_grouped_tensor_bwd; clean up fwd
cyanguwa f53755b
add doxygen for multi_tensor_swizzle
cyanguwa 779da99
clean up nvte_convert_qkv_format
cyanguwa f73370a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot]

Files changed (89 files; submodule cudnn-frontend updated)
```diff
@@ -1803,20 +1803,45 @@ def get_model(dtype, config):
     return outputs


-attn_mask_type = "causal"
 model_configs_fp8_vs_f16 = {
     # test: ModelConfig(b, sq, hq, dqk)
-    "fp8_9": ModelConfig(2, 2048, 16, 128),
-    "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),
-    "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
-    "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"),
-    "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"),
-    "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
-    "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"),
-    "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"),
-    "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
-    "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),
-    "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"),
-    "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
+    "fp8_9": ModelConfig(
+        2,
+        4096,
+        128,
+        192,
+        head_dim_v=128,
+    ),
+    "fp8_10": ModelConfig(
+        1,
+        4096,
+        128,
+        192,
+        head_dim_v=128,
+        attn_mask_type="causal",
+    ),
+    "fp8_11": ModelConfig(
+        2,
+        4096,
+        128,
+        192,
+        head_dim_v=128,
+        attn_mask_type="causal_bottom_right",
+    ),
+    "fp8_12": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
+    "fp8_13": ModelConfig(2, 8192, 32, 128, attn_mask_type="causal", window_size=(128, 0)),
+    "fp8_14": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal"),
+    "fp8_15": ModelConfig(2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)),
+    "fp8_16": ModelConfig(
+        2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
+    ),
+    "fp8_17": ModelConfig(
+        2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable"
+    ),
+    "fp8_18": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
+    "fp8_19": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),
+    "fp8_20": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"),
 }

 param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
```

Comment on lines +1809 to +1831 (Member): Is it precommit doing something strange? The previous form was way easier to read.
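The reviewer's question about pre-commit most likely comes down to black's "magic trailing comma" (assuming black is the formatter behind the hook, which is a guess): once a call is written with a trailing comma, black keeps it exploded one argument per line even when it would fit on one line. A minimal runnable illustration, using a hypothetical stand-in for the test's `ModelConfig`:

```python
# Hypothetical stand-in so the snippet runs without the test harness.
def ModelConfig(*args, **kwargs):
    return (args, kwargs)

# With a magic trailing comma, black formats the call one argument per line:
cfg_exploded = ModelConfig(
    2,
    4096,
    128,
    192,
    head_dim_v=128,
)

# Without the trailing comma, black collapses the call when it fits on one line:
cfg_compact = ModelConfig(2, 4096, 128, 192, head_dim_v=128)

# Both spellings build the same configuration.
assert cfg_exploded == cfg_compact
```

Removing the trailing comma (or marking the dict with `# fmt: off`/`# fmt: on`) would restore the compact one-line form the reviewer prefers.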
```diff
@@ -1833,7 +1858,7 @@ def get_model(dtype, config):
 @pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
 @pytest.mark.parametrize("RoPE", [True, False])
 @pytest.mark.parametrize("is_training", [True, False])
-@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
+@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"])
 def test_mha_fp8_vs_f16(
     dtype,
     model,
```
```diff
@@ -1864,6 +1889,12 @@ def get_model(dtype, config):
             fp8_dpa=True,
             fp8_mha=True,
         )
+    elif scaling_mode == "mxfp8":
+        fp8_recipe = recipe.MXFP8BlockScaling(
+            fp8_format=recipe.Format.E4M3,
+            fp8_dpa=True,
+            fp8_mha=False,
+        )
     fp8_meta = {}
     fp8_meta["recipe"] = fp8_recipe
     available_backends, _, _ = get_available_attention_backends(
```
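The `MXFP8BlockScaling` recipe added above quantizes tensors in 32-element blocks, each sharing a power-of-two (E8M0) scale, with the values stored in FP8 E4M3 (largest finite value 448). A pure-Python sketch of that block-scaling scheme, independent of Transformer Engine (the E4M3 rounding step itself is elided, only the scale selection and clamping are shown):

```python
import math

E4M3_MAX = 448.0   # largest finite FP8 E4M3 value
BLOCK = 32         # MX block size per the OCP MX spec

def quantize_mx_block(block):
    """Quantize one block with a shared power-of-two (E8M0-style) scale."""
    amax = max(abs(x) for x in block)
    if amax == 0.0:
        return 0, [0.0] * len(block)
    # Choose the smallest power-of-two scale that maps amax into E4M3 range.
    exp = math.ceil(math.log2(amax / E4M3_MAX))
    scale = 2.0 ** exp
    # Round-to-nearest onto the E4M3 grid is elided here; clamp only.
    q = [max(-E4M3_MAX, min(E4M3_MAX, x / scale)) for x in block]
    return exp, q

def dequantize_mx_block(exp, q):
    scale = 2.0 ** exp
    return [x * scale for x in q]

vals = [float(i) - 16.0 for i in range(BLOCK)]
exp, q = quantize_mx_block(vals)
restored = dequantize_mx_block(exp, q)
```

Because every scale is a power of two, dequantization is exact for values already representable after scaling, which is the property the recipe relies on.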
```diff
@@ -2083,7 +2114,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
 @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
 @pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
 @pytest.mark.parametrize("is_training", [True, False])
-@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
+@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"])
 def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode):
     """Test DotProductAttention module in FP8"""
     config = model_configs_fp8_vs_f16[model]
```
```diff
@@ -2115,6 +2146,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
             fp8_format=recipe.Format.HYBRID,
             fp8_dpa=True,
         )
+    elif scaling_mode == "mxfp8":
+        fp8_recipe = recipe.MXFP8BlockScaling(
+            fp8_format=recipe.Format.E4M3,
+            fp8_dpa=True,
+            fp8_mha=False,
+        )
     fp8_meta = {}
     fp8_meta["recipe"] = fp8_recipe
     available_backends, _, _ = get_available_attention_backends(
```
```diff
@@ -2186,7 +2223,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
     atol = 5e-1
     rtol = 5e-2
     rmse_tol = 0.11
-    bwd_names = ["dq", "dk", "dv"]
+    bwd_names = ["dq", "dk", "dv", "d_softmax_offset"]
     if flash_attn_supported and fused_attn_supported_f16:
         logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
         logging.debug("========== {:^25s} ==========".format("forward output"))
```
```diff
@@ -2275,7 +2312,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
     with quantized_model_init(enabled=fp8_dpa):
         dpa = DotProductAttention(
             config.num_heads,
-            config.head_dim_qk,
+            (config.head_dim_qk, config.head_dim_v),
             num_gqa_groups=config.num_gqa_groups,
             attention_dropout=config.dropout_p,
             sequence_parallel=False,
```
```diff
@@ -2285,6 +2322,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
             layer_number=1,
             attention_type="self",
             qkv_format=qkv_format,
+            softmax_type=config.softmax_type,
         ).to(dtype=dtype, device="cuda")
     if not is_training:
         dpa = dpa.eval()
```
```diff
@@ -2320,7 +2358,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
         "skv": config.max_seqlen_kv,
         "h": config.num_heads,
         "hg": config.num_gqa_groups,
-        "d": config.head_dim_qk,
+        "dqk": config.head_dim_qk,
+        "dv": config.head_dim_v,
         "t": cu_seqlens_q[-1],
         "tg": cu_seqlens_kv[-1],
         "3": 3,
```
```diff
@@ -2336,6 +2375,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
             layout = layout.replace("s", "skv")
             layout = layout.replace("h", "hg")
             layout = layout.replace("t", "tg")
+        if i == 2:
+            layout = layout.replace("d", "dv")
+        else:
+            layout = layout.replace("d", "dqk")
         tensor_shape = [dim_to_num[j] for j in layout.split("_")]
         if config.dropout_p == 0.0:
             tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda")
```
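The hunk above derives per-tensor shapes from a layout string by token substitution before looking the tokens up in `dim_to_num`: K/V switch to the KV sequence and GQA group dims, and now V additionally picks up `head_dim_v` while Q/K use `head_dim_qk`. A standalone sketch of the same substitution logic (function name and the exact Q-side substitution are assumptions, mirroring the test):

```python
def tensor_shape_from_layout(layout, dim_to_num, tensor_idx):
    """Map a layout like "b_s_h_d" to a shape; tensor_idx: 0=Q, 1=K, 2=V."""
    if tensor_idx in (1, 2):
        # K and V use the KV sequence length and GQA group count.
        layout = layout.replace("s", "skv").replace("h", "hg").replace("t", "tg")
    else:
        layout = layout.replace("s", "sq")
    # Only V carries head_dim_v; Q and K use head_dim_qk.
    layout = layout.replace("d", "dv" if tensor_idx == 2 else "dqk")
    return [dim_to_num[tok] for tok in layout.split("_")]

# Dims matching the new MLA-style configs (dqk=192, dv=128).
dims = {"b": 2, "sq": 4096, "skv": 4096, "h": 128, "hg": 128, "dqk": 192, "dv": 128}
q_shape = tensor_shape_from_layout("b_s_h_d", dims, 0)  # [2, 4096, 128, 192]
v_shape = tensor_shape_from_layout("b_s_h_d", dims, 2)  # [2, 4096, 128, 128]
```

This is why the diff splits the single `"d"` entry of `dim_to_num` into `"dqk"` and `"dv"`: with MLA-style configs the two head dims differ and a single token can no longer describe all three tensors.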
```diff
@@ -2360,6 +2403,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
     qkv_format_kv = "_".join(qkv_format)
     qkv_format_kv = qkv_format_kv.replace("s", "sq")
+    qkv_format_kv = qkv_format_kv.replace("d", "dv")
     out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
     out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
     out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")
```
```diff
@@ -2370,21 +2414,24 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
         inp[1],
         inp[2],
         qkv_format=qkv_format,
+        window_size=config.window_size,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_kv=cu_seqlens_kv,
         max_seqlen_q=config.max_seqlen_q,
         max_seqlen_kv=config.max_seqlen_kv,
         attn_mask_type=config.attn_mask_type,
         checkpoint_core_attention=False,
         core_attention_bias_type=config.attn_bias_type,
+        fp8_output=fp8_dpa,
     )
     if is_training:
         out.backward(out_grad)
+    d_softmax_offset = None
+    if is_training and config.softmax_type != "vanilla":
+        d_softmax_offset = dpa.softmax_offset.grad

     if is_training:
-        return out, (inp[0].grad, inp[1].grad, inp[2].grad)
-    return out, (None, None, None)
+        return out, (inp[0].grad, inp[1].grad, inp[2].grad, d_softmax_offset)
+    return out, (None, None, None, d_softmax_offset)
```
model_configs_fp8 = {
```diff
@@ -2636,6 +2683,8 @@ def forward(
             quantization_params=qkv_quantizer,
             use_split_accumulator=_2X_ACC_FPROP,
         )
+        qkv_layout = "bs3hd" if cudnn_frontend_version == 1 else "t3hd"
+        o_format = "bshd" if cudnn_frontend_version == 1 else "thd"
         qkv = qkv.view(-1, 3, h, d)
         qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous()
         torch.save(qkv_fp16, "qkv.pt")
```
```diff
@@ -2664,7 +2713,8 @@ def forward(
             attn_scale=None,
             dropout=p_dropout,
             fast_zero_fill=fast_zero_fill,
-            qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
+            qkv_layout=qkv_layout,
+            o_format=o_format,
             attn_bias_type="no_bias",
             attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
             rng_gen=None,
```
```diff
@@ -2687,6 +2737,8 @@ def forward(
         ctx.num_heads = num_heads
         ctx.mask_type = mask_type
         ctx.dtype = inp.dtype
+        ctx.qkv_layout = qkv_layout
+        ctx.o_format = o_format

         ctx.dQKV_quantizer = dQKV_quantizer
         ctx.dO_quantizer = dO_quantizer
```
```diff
@@ -2704,7 +2756,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
         (q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_func_ctx(ctx)

         proj_dgrad = ctx.dO_quantizer(grad_output)
-        fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)

         dq, dk, dv, *rest = fused_attn_bwd(
             ctx.max_s,
```
```diff
@@ -2717,7 +2768,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             out,
             proj_dgrad.view_as(out),
             ctx.qkv_dtype,
-            fp8_dtype_backward,
             ctx.aux_ctx_tensors,
             FusedAttnBackend["FP8"],
             None,
```
```diff
@@ -2728,7 +2778,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             attn_scale=None,
             dropout=ctx.p_dropout,
             fast_zero_fill=ctx.fast_zero_fill,
-            qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
+            qkv_layout=ctx.qkv_layout,
+            o_format=ctx.o_format,
+            do_format=ctx.o_format,
+            dqkv_layout=ctx.qkv_layout,
             attn_bias_type="no_bias",
             attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
         )
```
Reviewer comment (Member): Since this is a test, shouldn't we just skip the case that has the scaling mode set to mxfp8 together with fp8_mha? Otherwise we would just run the same test (without fp8_mha) again. We should also make sure that the actual attention code disables fp8_mha when the mxfp8 recipe is chosen, and log that choice in the debug log.
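The reviewer's suggestion — skipping the redundant mxfp8 + fp8_mha parametrization rather than silently downgrading it — could be expressed as a small guard at the top of the test. The helper below is a sketch (its name and placement are assumptions, not part of the PR); in the test it would drive a `pytest.skip(...)` call:

```python
def should_skip(scaling_mode, fp8_mha):
    """Return True when the parametrized combination is redundant.

    MXFP8 attention runs with fp8_mha disabled, so mxfp8 + fp8_mha=True
    would duplicate the mxfp8 + fp8_mha=False case; the test should skip it
    (e.g. via pytest.skip) instead of re-running the same configuration.
    """
    return scaling_mode == "mxfp8" and fp8_mha
```

Pairing this with a debug-log line in the attention code whenever fp8_mha is forced off under the MXFP8 recipe would address both halves of the comment.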