diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 7b9b711c22..97f6cb3b88 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 7b9b711c22b6823e87150213ecd8449260db8610 +Subproject commit 97f6cb3b88cacff507cca1280db5650a457d92b3 diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 3d6e3a0aac..627ee55e99 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -339,7 +339,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) || // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 && - cudnn_runtime_version >= 91100)) && + cudnn_runtime_version >= 91100) || + // 9.20: any head_dim + Blackwell + fprop/bprop + non_paged + any sq + (sm_arch_ >= 100 && cudnn_runtime_version >= 92000 && + layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD)) && // 9.11+ bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA // Conditional to temporarily use blanket cudnn_runtime_version >= 9.11 until fixed (!((cudnn_runtime_version >= 91100) && is_training && sm_arch_ == 90 &&