From bb3d2f5be6267792ba14fa989002a5d2f79a0d51 Mon Sep 17 00:00:00 2001 From: zhimding Date: Fri, 28 Nov 2025 08:07:30 +0000 Subject: [PATCH 01/28] update --- include/ck_tile/core/tensor/tile_scatter_gather.hpp | 12 ++++++------ ...ixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 6 +++--- ...ec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index 4b04fd513db..e6adc7d40b4 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -125,7 +125,7 @@ struct tile_scatter_gather static constexpr auto get_space_filling_curve() { - constexpr auto tile_dstr = TileDstr{}; + [[maybe_unused]] constexpr auto tile_dstr = TileDstr{}; constexpr auto thread_tensor_lengths_ys = to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths()); @@ -309,7 +309,7 @@ struct tile_scatter_gather CK_TILE_DEVICE auto load(number = {}, bool_constant = {}) const { - constexpr auto tile_dstr = TileDstr{}; + [[maybe_unused]] constexpr auto tile_dstr = TileDstr{}; auto dst_tensor = make_static_distributed_tensor(tile_dstr); load(dst_tensor, number{}, bool_constant{}); return dst_tensor; @@ -326,7 +326,7 @@ struct tile_scatter_gather using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + [[maybe_unused]] constexpr auto tile_dstr = TileDstr{}; // loop over thread tensor space [y0, y1, ...] static_for<0, NumCoord, 1>{}([&](auto iCoord) { @@ -418,7 +418,7 @@ struct tile_scatter_gather using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + [[maybe_unused]] constexpr auto tile_dstr = TileDstr{}; // Precompute invariant values outside loops const auto window_origin = lds_tile.get_window_origin(); @@ -614,7 +614,7 @@ struct tile_scatter_gather using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + [[maybe_unused]] constexpr auto tile_dstr = TileDstr{}; static_for<0, NumCoord, 1>{}([&](auto iCoord) { auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; @@ -696,7 +696,7 @@ struct tile_scatter_gather using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + [[maybe_unused]] constexpr auto tile_dstr = TileDstr{}; // printf("off %d\n", page_idx_[I0]); // loop over thread tensor space [y0, y1, ...] 
static_for<0, NumCoord, 1>{}([&](auto iCoord) { diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 17c88e4f08f..159ed4d4c79 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -444,7 +444,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 typename BFlatBlockWindowTmp, typename DequantBFlatWindow> CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindowTmp a_copy_dram_window_, - const AElementFunction& a_element_func, + [[maybe_unused]] const AElementFunction& a_element_func, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, const DequantBFlatWindow& scale_b_flat_window, const index_t num_loop, @@ -606,7 +606,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 scale_b_warp_tensor_pong; using ABlockTile = decltype(load_tile(a_copy_dram_window)); - ABlockTile a_block_tile; + [[maybe_unused]] ABlockTile a_block_tile; enum { @@ -621,7 +621,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 if constexpr(prefill_location & PrefillAfterGemm) async_load_tile(lds_tile_a, dram_tile_a); }; - auto prefill_lds_a_stage2 = [&](auto lds_tile_a) { + auto prefill_lds_a_stage2 = [&]([[maybe_unused]] auto lds_tile_a) { // async_load_fence(); // __builtin_amdgcn_s_waitcnt(0x03fc); // data has been stored in lds, no need more operation. diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index f34c682b0f1..f5954c29abf 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -7,7 +7,7 @@ namespace ck_tile { -#define CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE 0 +#define CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE 1 #if defined(__gfx950__) #define CKTILE_FLATMM_ARCH_SUPPORT_BUFFER_LOAD_LDS_DWORDx4 1 From b99c48da2eceda79b263beee1150434352fe3050 Mon Sep 17 00:00:00 2001 From: yadaish Date: Tue, 18 Nov 2025 10:12:11 +0000 Subject: [PATCH 02/28] mixed-prec flatmm pipeline improve --- ...ec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 77 +++++++++++++++++-- 1 file changed, 69 insertions(+), 8 deletions(-) diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 159ed4d4c79..7387ae11e00 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -71,7 +71,11 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 using WG = remove_cvref_t())>; static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2 +#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS + static constexpr index_t DsReadPreload = 16; // default 8, if using lds, register pressure is alleviated, improve preload +#else static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read +#endif static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t WaveSize = get_warp_size(); @@ -186,11 +190,6 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 CK_TILE_HOST_DEVICE static 
constexpr auto SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM) { -#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS - // GFX950 use BUFFER_LOAD_LDS to fill lds_buffer_A. - // There is no separate DS_WRITE instruction at all. - dswrite_perM = 0; -#endif // Init inst order index_t max_data_inst = dsread_perM > load_perM ? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM) @@ -360,7 +359,36 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 // Calculate ds_read number per M dsread_perM = dsread_per_wg; +#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS == 0 + // Calculate ds_write number per M + if(mIter == 0) + { + dswrite_perM = + (dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0 + ? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep + : 0; + } + else if(mIter >= MIterPerWarp - DsWritePreIssue + 1) + { + dswrite_perM = 0; + } + else + { + dswrite_perM = (dswrite_num_perK - + (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0 + ? dswrite_rep + : 0; + } + // Add ds write when ds write data > needed + if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter)) + { + if(mIter == MIterPerWarp - 1 - dswrite_mIter) + dswrite_perM = 1; + } +#endif + // Calculate buffer_load number per M +#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS == 0 if(mIter < HalfMIter) { load_perM = @@ -375,10 +403,17 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 ? Aload_rep : 0; } - if((kIter % KPerScaleLoad == 0) && (mIter == 0)) - { - load_perM = load_perM + 1; +#else + if ((kIter * MIterPerWarp + mIter) >= + (KIterPerWarp * MIterPerWarp - m_preload)) { + load_perM = 1; } +#endif + // if((kIter % KPerScaleLoad == 0) && (mIter == 0)) + // { + // load_perM = load_perM + 1; + // } + // SchedulerPerM(dsread_perM, dswrite_perM, load_perM); SchedulerPerM(dsread_perM, dswrite_perM, load_perM); } } @@ -821,6 +856,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 }); }); + __builtin_amdgcn_sched_barrier(0); + // Prefill A(2i+1) prefill_lds_a_stage2(a_copy_lds_window_pong); @@ -866,12 +903,23 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 load_tile(a_warp_windows_ping(number{})(number{})); } + // yadai comments out the following + /* // barrier if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) { __builtin_amdgcn_s_waitcnt(Bload_total_num); block_sync_lds(); } + */ + + // sync shouble made as early as possible + if constexpr((kIter * MIterPerWarp + mIter) == + (KIterPerWarp * MIterPerWarp - m_preload)) + { + __builtin_amdgcn_s_waitcnt(Bload_total_num); + block_sync_lds(); + } }); }); prefill_lds_a_stage1( @@ -928,6 +976,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 }); }); + __builtin_amdgcn_sched_barrier(0); + // Prefill A(2i+2) prefill_lds_a_stage2(a_copy_lds_window_ping); @@ -973,12 +1023,23 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 load_tile(a_warp_windows_pong(number{})(number{})); } + // yadai comments out the following + /* // barrier if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) { __builtin_amdgcn_s_waitcnt(Bload_total_num); block_sync_lds(); } + */ + + // sync shouble made as early as possible + if constexpr((kIter * MIterPerWarp + mIter) == + (KIterPerWarp * MIterPerWarp - m_preload)) + { + __builtin_amdgcn_s_waitcnt(Bload_total_num); + block_sync_lds(); + } }); }); prefill_lds_a_stage1( From 151acb30275f9d3197386ace1784c228557e2a31 Mon Sep 17 00:00:00 2001 From: yadaish Date: Wed, 19 Nov 2025 04:13:41 +0000 Subject: [PATCH 03/28] support a16_wint4 moe --- .../18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp 
| 19 ++- include/ck_tile/core/numeric/pk_int4.hpp | 18 +++ .../ops/flatmm/kernel/moe_flatmm_kernel.hpp | 2 +- ...ec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 147 +++++++++++++++++- 4 files changed, 179 insertions(+), 7 deletions(-) diff --git a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp index 62fb6bbcb29..709c772f6a1 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp @@ -86,7 +86,8 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config FlatmmConfig::NumWaveGroups, true>; // Preshuffle_ - constexpr bool MXFP4_Pipeline = std::is_same_v; + // TODO(yadai): rename to W4_Pipeline + constexpr bool MXFP4_Pipeline = std::is_same_v | std::is_same_v; if constexpr(!MXFP4_Pipeline && moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up) { @@ -444,6 +445,22 @@ int run_a16w4_moe_flatmm_example(int argc, char* argv[]) FlatmmConfig, ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{}); } + else if(mixed_prec == "fp16xint4") + { + return run_a16w4_moe_gemm_example_with_layouts< + ck_tile::half_t, + ck_tile::pk_int4_t, + FlatmmConfig, + ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{}); + } + else if(mixed_prec == "bf16xint4") + { + return run_a16w4_moe_gemm_example_with_layouts< + ck_tile::bfloat16_t, + ck_tile::pk_int4_t, + FlatmmConfig, + ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{}); + } else { throw std::runtime_error("Unsupported precision type for gemm1_gate_up!"); diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index fc1caf13ff9..088407b40c7 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -151,6 +151,16 @@ CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x) return pk_add_f16(bit_cast(lo), bit_cast(SUB)); } + + +CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x, float scale) +{ + auto float_vec2 = pk_int4_t_to_fp32x2_t(x); + float_vec2.x = float_vec2.x * scale; + float_vec2.y = float_vec2.y * scale; + return fp16x2_t{type_convert(float_vec2.x), type_convert(float_vec2.y)}; +} + CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x) { uint8_t x_u8 = ck_tile::bit_cast(x); @@ -166,6 +176,14 @@ CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x) return res; } +CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x, float scale) +{ + auto float_vec2 = pk_int4_t_to_fp32x2_t(x); + float_vec2.x = float_vec2.x * scale; + float_vec2.y = float_vec2.y * scale; + return bf16x2_t{type_convert(float_vec2.x), type_convert(float_vec2.y)}; +} + CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t& x) { uint8_t x_u8 = ck_tile::bit_cast(x); diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index 411cfe81edf..0d5433f6082 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -241,7 +241,7 @@ struct MoeFlatmmKernel IsGateUp ? 
TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock; // MXF4_Pipeline only has the of scale B and granularityK is 32 - static constexpr bool MXFP4_Pipeline = std::is_same_v; + static constexpr bool MXFP4_Pipeline = std::is_same_v || std::is_same_v; static constexpr int MXFP4N_Pack = 2; static constexpr int MXFP4K_Pack = 2; diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 7387ae11e00..b42db7434c0 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -187,6 +187,135 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 // For the basic gemm pipelien DoubleSmemBuffer set to be false naturally. static constexpr bool DoubleSmemBuffer = false; + struct DequantizeMxFP4 { + + CK_TILE_DEVICE auto operator()(statically_indexed_array& dequant_B_n, + const auto& quant_weight_tensor, + const auto& scale_tensor, + auto xdl_nIter, + auto xdl_kIter) { + + auto quant_idx_k = xdl_kIter % number{}; + + auto scale_idx_n = xdl_nIter % number{}; + auto scale_idx_k = (xdl_kIter % number{}) / number{}; + auto scale_offset = scale_idx_n + scale_idx_k * number{}; + + auto scale = scale_tensor.get_thread_buffer()[scale_offset]; + + constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size(); + constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize; + constexpr int float_mantissa = 23; + + uint32_t uscale = uint32_t(scale.data) << float_mantissa; + + using ComputeV2Type = + std::conditional_t, fp16x2_t, bf16x2_t>; + +#if defined(__gfx950__) + auto pk_mxfp4x4_to_compute_v2 = [](auto pk_mxfp4x4, float fscale, auto byte_idx) { + if constexpr(std::is_same_v) + { + return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4( + pk_mxfp4x4, fscale, int(byte_idx)); + } + else if constexpr(std::is_same_v) + { + return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4( + pk_mxfp4x4, fscale, int(byte_idx)); + } + else + { + static_assert(sizeof(pk_mxfp4x4) == 0, "unsupported compute type"); + } + }; + static_for<0, PackedCnt, 1>{}([&](auto i) { + dequant_B_n[xdl_nIter].get_thread_buffer().template set_as( + i, + pk_mxfp4x4_to_compute_v2( + quant_weight_tensor[quant_idx_k], bit_cast(uscale), i)); + }); +#else + auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) { + if constexpr(std::is_same_v) + { + return pk_fp4_to_fp16x2(pk_mxfp4, fscale); + } + else if constexpr(std::is_same_v) + { + return pk_fp4_to_bf16x2(pk_mxfp4, fscale); + } + else + { + static_assert(sizeof(pk_mxfp4) == 0, "unsupported compute type"); + } + }; + static_for<0, PackedCnt, 1>{}([&](auto i) { + dequant_B_n[xdl_nIter].get_thread_buffer().template set_as( + i, + pk_mxfp4_to_compute_v2( + bit_cast>(quant_weight_tensor[quant_idx_k]) + .at(i), + bit_cast(uscale))); + }); +#endif + return 0; + } + }; + + struct DequantizeINT4 { + + CK_TILE_DEVICE auto operator()(statically_indexed_array& dequant_B_n, + const auto& quant_weight_tensor, + const auto& scale_tensor, + auto xdl_nIter, + auto xdl_kIter) { + + auto quant_idx_k = xdl_kIter % number{}; + + auto scale_idx_n = xdl_nIter % number{}; + auto scale_idx_k = (xdl_kIter % number{}) / number{}; + auto scale_offset = scale_idx_n + scale_idx_k * number{}; + + auto scale = scale_tensor.get_thread_buffer()[scale_offset]; + + constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size(); + constexpr int PackedCnt = ScalarCnt / 
MXFP4PackedSize; + constexpr int float_mantissa = 23; + + uint32_t uscale = uint32_t(scale.data) << float_mantissa; + + using ComputeV2Type = + std::conditional_t, fp16x2_t, bf16x2_t>; + + auto pk_int4_to_compute_v2 = [](auto pk_int4, float fscale) { + if constexpr(std::is_same_v) + { + return pk_int4_t_to_halfx2_t(pk_int4, fscale); + } + else if constexpr(std::is_same_v) + { + return pk_int4_t_to_bfloat16x2_t(pk_int4, fscale); + } + else + { + static_assert(sizeof(pk_int4) == 0, "unsupported compute type"); + } + }; + static_for<0, PackedCnt, 1>{}([&](auto i) { + dequant_B_n[xdl_nIter].get_thread_buffer().template set_as( + i, + pk_int4_to_compute_v2( + bit_cast>(quant_weight_tensor[quant_idx_k]) + .at(i), + bit_cast(uscale))); + }); + return 0; + } + }; + + using DequantOp = typename std::conditional, DequantizeMxFP4, DequantizeINT4>::type; + CK_TILE_HOST_DEVICE static constexpr auto SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM) { @@ -747,6 +876,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 statically_indexed_array dequant_B_n; + + /* auto dequant_mxfp4 = [&](const auto& quant_weight_tensor, const auto& scale_tensor, auto xdl_nIter, @@ -816,6 +947,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 }); #endif }; + */ // MAIN LOOP index_t iCounter = (num_loop - 1) / 2; @@ -877,7 +1009,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); if constexpr(mIter == 0) - dequant_mxfp4( + DequantOp{}( + dequant_B_n, b_warp_tensor_ping(nIter)(kIter / number{}), scale_b_warp_tensor_ping(nIter / number{})( kIter / number{}), @@ -997,7 +1130,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); if constexpr(mIter == 0) - dequant_mxfp4( + DequantOp{}( + dequant_B_n, b_warp_tensor_pong(nIter)(kIter / number{}), scale_b_warp_tensor_pong(nIter / number{})( kIter / number{}), @@ -1124,7 +1258,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); if constexpr(mIter == 0) - dequant_mxfp4( + DequantOp{}( + dequant_B_n, b_warp_tensor_ping(nIter)(kIter / number{}), scale_b_warp_tensor_ping(nIter / number{})( kIter / number{}), @@ -1185,7 +1320,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); if constexpr(mIter == 0) - dequant_mxfp4( + DequantOp{}( + dequant_B_n, b_warp_tensor_pong(nIter)(kIter / number{}), scale_b_warp_tensor_pong(nIter / number{})( kIter / number{}), @@ -1236,7 +1372,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); if constexpr(mIter == 0) - dequant_mxfp4( + DequantOp{}( + dequant_B_n, b_warp_tensor_ping(nIter)(kIter / number{}), scale_b_warp_tensor_ping(nIter / number{})( kIter / number{}), From 1d7e3a5d9975ec8dcaeb50503f2f23eb72b3db56 Mon Sep 17 00:00:00 2001 From: yadaish Date: Mon, 24 Nov 2025 09:32:44 +0000 Subject: [PATCH 04/28] fix out of lds --- example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp index 458e7ba6434..7e482989d43 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp +++ b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp @@ -13,7 +13,11 @@ // GEMM config with 16x16 warp tile struct A16W4_FlatmmConfig16 { - static constexpr ck_tile::index_t M_Tile = 128; +#if 
defined(__gfx950__) + static constexpr ck_tile::index_t M_Tile = 256; +#else + static constexpr ck_tile::index_t M_Tile = 64; +#endif static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 256; From 33f41e5ff71af404457672973c408a7d0f3d322c Mon Sep 17 00:00:00 2001 From: yadaish Date: Mon, 24 Nov 2025 09:57:31 +0000 Subject: [PATCH 05/28] update --- include/ck_tile/core/numeric/pk_int4.hpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index 088407b40c7..aa67019ead2 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -155,7 +155,9 @@ CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x) CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x, float scale) { - auto float_vec2 = pk_int4_t_to_fp32x2_t(x); + TODO(yadai): confirm quanzation algorithm + // auto float_vec2 = pk_int4_t_to_fp32x2_t(x); + auto float_vec2 = pk_int4_t_to_fp32x2_t_signed_conversion(x); float_vec2.x = float_vec2.x * scale; float_vec2.y = float_vec2.y * scale; return fp16x2_t{type_convert(float_vec2.x), type_convert(float_vec2.y)}; From 7b2e154c3cf97d73e184f93ca7ff88a2b256b008 Mon Sep 17 00:00:00 2001 From: yadaish Date: Mon, 24 Nov 2025 10:45:56 +0000 Subject: [PATCH 06/28] update --- include/ck_tile/core/numeric/pk_int4.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index aa67019ead2..f0a681cb17b 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -155,7 +155,7 @@ CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x) CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x, float scale) { - TODO(yadai): confirm quanzation algorithm + // TODO(yadai): confirm quanzation algorithm // auto float_vec2 = pk_int4_t_to_fp32x2_t(x); auto float_vec2 = pk_int4_t_to_fp32x2_t_signed_conversion(x); float_vec2.x = float_vec2.x * scale; From da42b5af91ef3a94f0bf44ed15e4317ca6dfd043 Mon Sep 17 00:00:00 2001 From: yadaish Date: Mon, 24 Nov 2025 10:59:09 +0000 Subject: [PATCH 07/28] update --- include/ck_tile/core/numeric/pk_int4.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index f0a681cb17b..47297b1ef2d 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -180,7 +180,7 @@ CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x) CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x, float scale) { - auto float_vec2 = pk_int4_t_to_fp32x2_t(x); + auto float_vec2 = pk_int4_t_to_fp32x2_t_signed_conversion(x); float_vec2.x = float_vec2.x * scale; float_vec2.y = float_vec2.y * scale; return bf16x2_t{type_convert(float_vec2.x), type_convert(float_vec2.y)}; From 125b7997d5730823d8a228e1a632144d5b159d61 Mon Sep 17 00:00:00 2001 From: yadaish Date: Tue, 25 Nov 2025 05:23:47 +0000 Subject: [PATCH 08/28] change endian seems working --- include/ck_tile/core/numeric/pk_int4.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index 47297b1ef2d..3e9df1806fe 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -127,11 
+127,14 @@ CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t_signed_conversion(const pk_in x_l = x_l > 7 ? x_l - 16 : x_l; x_h = x_h > 7 ? x_h - 16 : x_h; + /* #ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE fp32x2_t res = {x_h, x_l}; #elif fp32x2_t res = {x_l, x_h}; #endif + */ + fp32x2_t res = {x_l, x_h}; return res; } From d321b3486eaa6c1cc47b7324c0dee5ba50f479f6 Mon Sep 17 00:00:00 2001 From: Mohsen Saffari Date: Wed, 19 Nov 2025 14:02:24 +0000 Subject: [PATCH 09/28] Add validity checks for MoE FlatMM scatter and enable bf16 hardware atomic --- example/ck_tile/18_flatmm/run_moe_flatmm_example.inc | 2 +- include/ck_tile/core/arch/generic_memory_space_atomic.hpp | 4 ++++ include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp | 8 ++++++-- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc index 9e0cbda0c00..fa00d024acf 100644 --- a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc @@ -304,7 +304,7 @@ int run_moe_gemm_example_with_layouts(int argc, const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol( + [[maybe_unused]] const auto rtol_atol = calculate_rtol_atol( K, 1 /*kbatch*/, max_accumulated_value); c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data()); diff --git a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp index e56bcadcba2..0ff97bb9a79 100644 --- a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp +++ b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp @@ -102,6 +102,9 @@ CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x); template <> CK_TILE_DEVICE void atomic_add(bf16x2_t* p_dst, const bf16x2_t& x) { +#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN + __builtin_amdgcn_global_atomic_fadd_v2bf16(c_style_pointer_cast(p_dst), x); +#else union U32BF162_ADDR { uint32_t* u32_a; @@ -128,6 +131,7 @@ CK_TILE_DEVICE void atomic_add(bf16x2_t* p_dst, const bf16x2_t& x) new_v = new_.u32; cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v); } while(cur_v.u32 != old_v); +#endif } template <> diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index 0d5433f6082..d9a96bf75a6 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -623,7 +623,7 @@ struct MoeFlatmmKernel { return make_naive_tensor_view( e_ptr, - make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumToken, + make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumTokens, IsGateUp ? 
kargs.N / 2 : kargs.N), make_tuple(1, kargs.stride_C), number<1>{}, @@ -1250,6 +1250,8 @@ struct MoeFlatmmKernel constexpr int MPerThread = TileEncodingPattern::Y2; statically_indexed_array, NumMEpiTile> c_scatter_offsets; + statically_indexed_array, NumMEpiTile> + c_scatter_valids; auto c_coord = dram_tile_distribution.calculate_index(); static_for<0, NumMEpiTile, 1>{}([&](auto mIter) { static_for<0, MPerThread, 1>{}([&](auto m0) { @@ -1262,6 +1264,7 @@ struct MoeFlatmmKernel scatter_token_id = scatter_token_id * kargs.TopK + (fused_token >> token_id_offset); c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C; + c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens); }); }); @@ -1302,7 +1305,8 @@ struct MoeFlatmmKernel c_block_window.get_window_lengths(), c_block_window.get_window_origin(), dram_tile_distribution, - c_scatter_offsets[mIter]); + c_scatter_offsets[mIter], + c_scatter_valids[mIter]); if constexpr(!IsInputGemm || EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add) From 94ef537be1edda0b1cf05a64bcb201356fdf6787 Mon Sep 17 00:00:00 2001 From: mohsen saffari Date: Wed, 19 Nov 2025 15:32:39 +0100 Subject: [PATCH 10/28] correct clang-format --- example/ck_tile/18_flatmm/run_moe_flatmm_example.inc | 5 +++-- include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc index fa00d024acf..7a52c30c130 100644 --- a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc @@ -304,8 +304,9 @@ int run_moe_gemm_example_with_layouts(int argc, const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - [[maybe_unused]] const auto rtol_atol = calculate_rtol_atol( - K, 1 /*kbatch*/, max_accumulated_value); + [[maybe_unused]] const auto rtol_atol = + calculate_rtol_atol( + K, 1 /*kbatch*/, max_accumulated_value); c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data()); const float rtol = std::is_same_v && IsInputGemm ? 
1e-3 : 1e-2; diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index d9a96bf75a6..15c4c21c86b 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -1264,7 +1264,7 @@ struct MoeFlatmmKernel scatter_token_id = scatter_token_id * kargs.TopK + (fused_token >> token_id_offset); c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C; - c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens); + c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens); }); }); From 79d7583b1d8757ae686032a7e7cebe1c4c7d85e4 Mon Sep 17 00:00:00 2001 From: mohsen saffari Date: Fri, 21 Nov 2025 14:36:24 +0100 Subject: [PATCH 11/28] removed unused rtol_atol variable from example code --- example/ck_tile/18_flatmm/run_moe_flatmm_example.inc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc index 7a52c30c130..0bc5b11e40e 100644 --- a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc @@ -304,9 +304,7 @@ int run_moe_gemm_example_with_layouts(int argc, const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - [[maybe_unused]] const auto rtol_atol = - calculate_rtol_atol( - K, 1 /*kbatch*/, max_accumulated_value); + c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data()); const float rtol = std::is_same_v && IsInputGemm ? 1e-3 : 1e-2; From 10392447796ad3d0a4da8c54b6bd7857c1d0c305 Mon Sep 17 00:00:00 2001 From: mohsen saffari Date: Fri, 21 Nov 2025 14:39:32 +0100 Subject: [PATCH 12/28] clang format correction --- example/ck_tile/18_flatmm/run_moe_flatmm_example.inc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc index 0bc5b11e40e..e969fd8a113 100644 --- a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc @@ -304,7 +304,7 @@ int run_moe_gemm_example_with_layouts(int argc, const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - + c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data()); const float rtol = std::is_same_v && IsInputGemm ? 1e-3 : 1e-2; From fe690d88fdd916315b86c6af39e1b2ab564060ba Mon Sep 17 00:00:00 2001 From: mohsen saffari Date: Fri, 21 Nov 2025 14:51:39 +0100 Subject: [PATCH 13/28] remove unused varable max_accumulated_value from example --- example/ck_tile/18_flatmm/run_moe_flatmm_example.inc | 3 --- 1 file changed, 3 deletions(-) diff --git a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc index e969fd8a113..053c039cd53 100644 --- a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc @@ -302,9 +302,6 @@ int run_moe_gemm_example_with_layouts(int argc, static_cast(per_token_scale_dev_buf.GetDeviceBuffer()), static_cast(per_channel_scale_dev_buf.GetDeviceBuffer())); - const float max_accumulated_value = - *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data()); const float rtol = std::is_same_v && IsInputGemm ? 
1e-3 : 1e-2; From 412b42cd1b9a4c954252988457e81f566465cffe Mon Sep 17 00:00:00 2001 From: yadaish Date: Tue, 25 Nov 2025 16:21:37 +0000 Subject: [PATCH 14/28] update --- include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index 15c4c21c86b..946b81c146b 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -1264,7 +1264,7 @@ struct MoeFlatmmKernel scatter_token_id = scatter_token_id * kargs.TopK + (fused_token >> token_id_offset); c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C; - c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens); + c_scatter_valids[mIter][m0] = (scatter_token_id < (kargs.NumTokens * (IsInputGemm? kargs.TopK : 1))); }); }); From 0642090655f120703e8035f40699f4c4fbe4e1a3 Mon Sep 17 00:00:00 2001 From: yadaish Date: Thu, 27 Nov 2025 14:47:51 +0000 Subject: [PATCH 15/28] update --- include/ck_tile/core/numeric/pk_int4.hpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index 3e9df1806fe..d72647395a8 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -109,11 +109,14 @@ CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x) float x_l = ((x_u8 & 0x0f) >> 0) - 8.f; float x_h = ((x_u8 & 0xf0) >> 4) - 8.f; + /* #ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE fp32x2_t res = {x_h, x_l}; #elif fp32x2_t res = {x_l, x_h}; #endif + */ + fp32x2_t res = {x_l, x_h}; return res; } @@ -159,8 +162,8 @@ CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x) CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x, float scale) { // TODO(yadai): confirm quanzation algorithm - // auto float_vec2 = pk_int4_t_to_fp32x2_t(x); - auto float_vec2 = pk_int4_t_to_fp32x2_t_signed_conversion(x); + // auto float_vec2 = pk_int4_t_to_fp32x2_t_signed_conversion(x); + auto float_vec2 = pk_int4_t_to_fp32x2_t(x); float_vec2.x = float_vec2.x * scale; float_vec2.y = float_vec2.y * scale; return fp16x2_t{type_convert(float_vec2.x), type_convert(float_vec2.y)}; @@ -183,7 +186,7 @@ CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x) CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x, float scale) { - auto float_vec2 = pk_int4_t_to_fp32x2_t_signed_conversion(x); + auto float_vec2 = pk_int4_t_to_fp32x2_t(x); float_vec2.x = float_vec2.x * scale; float_vec2.y = float_vec2.y * scale; return bf16x2_t{type_convert(float_vec2.x), type_convert(float_vec2.y)}; From 4aa2407cf02d5eebcf2a751a275d84f10653240e Mon Sep 17 00:00:00 2001 From: yadaish Date: Fri, 28 Nov 2025 11:31:25 +0000 Subject: [PATCH 16/28] update --- .../18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc b/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc index f236332d620..476dc70ecda 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc @@ -36,7 +36,7 @@ float invoke_a16w4_moe_gemm(int n_warmup, int n_repeat, const MoeHostArgs& args) std::size_t flop = std::size_t(2) * 
args.M * args.N * args.K; std::size_t num_byte = sizeof(ADataType) * args.M * args.K + - sizeof(BDataType) * args.N * args.K / PackedSize + + sizeof(BDataType) * args.N * args.K * std::min(args.experts, args.NumTokens * args.TopK) / PackedSize + sizeof(CDataType) * args.M * args.N; float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time; From 26626a68399dabf96e1c5bcd72991b63a88b06f1 Mon Sep 17 00:00:00 2001 From: yadaish Date: Fri, 28 Nov 2025 16:28:15 +0000 Subject: [PATCH 17/28] update --- example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp | 4 ++-- .../18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp index 7e482989d43..6a5b8e9fb77 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp +++ b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp @@ -14,11 +14,11 @@ struct A16W4_FlatmmConfig16 { #if defined(__gfx950__) - static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t M_Tile = 16; #else static constexpr ck_tile::index_t M_Tile = 64; #endif - static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 128; static constexpr ck_tile::index_t K_Tile = 256; static constexpr ck_tile::index_t M_Warp = 1; diff --git a/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc b/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc index 476dc70ecda..38d685d4ab3 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc @@ -36,7 +36,7 @@ float invoke_a16w4_moe_gemm(int n_warmup, int n_repeat, const MoeHostArgs& args) std::size_t flop = std::size_t(2) * args.M * args.N * args.K; std::size_t num_byte = sizeof(ADataType) * args.M * args.K + - sizeof(BDataType) * args.N * args.K * std::min(args.experts, args.NumTokens * args.TopK) / PackedSize + + sizeof(BDataType) * args.N * args.K * std::min(args.NumExperts, args.NumTokens * args.TopK) / PackedSize + sizeof(CDataType) * args.M * args.N; float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time; @@ -188,7 +188,7 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc, int tile_off = i % MPerBlock; if(tile_off < token_per_tile && tokenid < num_tokens * topk) { - sorted_token_ids.mData[i] = (tokenid % num_tokens) | ((tokenid / num_tokens) << 24); + sorted_token_ids.mData[i] = (tokenid / experts) | ((tokenid % experts) << 24); tokenid++; } else From 2d7a35de3e6fa3b81b78b5eabf6a91c00d177937 Mon Sep 17 00:00:00 2001 From: yadaish Date: Sat, 29 Nov 2025 16:32:33 +0000 Subject: [PATCH 18/28] update --- .../18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp | 12 +++++++ .../18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp | 31 ++++++++++++++++--- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp index 709c772f6a1..18607399492 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp @@ -515,6 +515,18 @@ int main(int argc, char* argv[]) { return !run_a16w4_moe_flatmm_example(argc, argv); } + else if (warp_tile == 1) { + return !run_a16w4_moe_flatmm_example(argc, argv); + } + else if 
(warp_tile == 2) { + return !run_a16w4_moe_flatmm_example(argc, argv); + } + else if (warp_tile == 3) { + return !run_a16w4_moe_flatmm_example(argc, argv); + } + else if (warp_tile == 4) { + return !run_a16w4_moe_flatmm_example(argc, argv); + } // else if(warp_tile == 1) // { // return !run_a16w4_moe_flatmm_example(argc, argv); diff --git a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp index 6a5b8e9fb77..e3cce789f28 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp +++ b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp @@ -13,11 +13,7 @@ // GEMM config with 16x16 warp tile struct A16W4_FlatmmConfig16 { -#if defined(__gfx950__) - static constexpr ck_tile::index_t M_Tile = 16; -#else static constexpr ck_tile::index_t M_Tile = 64; -#endif static constexpr ck_tile::index_t N_Tile = 128; static constexpr ck_tile::index_t K_Tile = 256; @@ -47,6 +43,33 @@ struct A16W4_FlatmmConfig16 static constexpr bool TiledMMAPermuteN = false; }; +struct A16W4_FlatmmConfig16_M16 : public A16W4_FlatmmConfig16 +{ + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t kBlockPerCu = 2; +}; + +struct A16W4_FlatmmConfig16_M32 : public A16W4_FlatmmConfig16 +{ + static constexpr ck_tile::index_t M_Tile = 32; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr ck_tile::index_t kBlockPerCu = 2; +}; + +struct A16W4_FlatmmConfig16_M64 : public A16W4_FlatmmConfig16 +{ + static constexpr ck_tile::index_t M_Tile = 64; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr ck_tile::index_t kBlockPerCu = 2; +}; + +struct A16W4_FlatmmConfig16_M128 : public A16W4_FlatmmConfig16 +{ + static constexpr ck_tile::index_t M_Tile = 128; +}; + struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16 { static constexpr ck_tile::index_t N_Tile = 128; From 2182364ebb238c3d5c7073ba3f02cfa4b2aea187 Mon Sep 17 00:00:00 2001 From: yadaish Date: Mon, 1 Dec 2025 05:30:02 +0000 Subject: [PATCH 19/28] scale bf16 --- .../ops/flatmm/kernel/moe_flatmm_kernel.hpp | 3 ++- ...ec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 19 ++++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index 946b81c146b..8ccd2c1bcb2 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -638,7 +638,8 @@ struct MoeFlatmmKernel index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1); index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1); - using ScaleType = std::conditional_t; + // using ScaleType = std::conditional_t; + using ScaleType = std::conditional_t; const auto scale_b_flat_view = make_naive_tensor_view( reinterpret_cast(scale_n.ptr) + expert_id * kargs.N * scale_k, diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index b42db7434c0..9925cc06915 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -189,12 +189,12 @@ struct 
F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 struct DequantizeMxFP4 { - CK_TILE_DEVICE auto operator()(statically_indexed_array& dequant_B_n, - const auto& quant_weight_tensor, - const auto& scale_tensor, - auto xdl_nIter, - auto xdl_kIter) { - + CK_TILE_DEVICE auto operator()([[maybe_unused]] statically_indexed_array& dequant_B_n, + [[maybe_unused]] const auto& quant_weight_tensor, + [[maybe_unused]] const auto& scale_tensor, + [[maybe_unused]] auto xdl_nIter, + [[maybe_unused]] auto xdl_kIter) { +#if 0 auto quant_idx_k = xdl_kIter % number{}; auto scale_idx_n = xdl_nIter % number{}; @@ -258,6 +258,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 .at(i), bit_cast(uscale))); }); +#endif #endif return 0; } @@ -281,9 +282,13 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size(); constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize; + /* constexpr int float_mantissa = 23; uint32_t uscale = uint32_t(scale.data) << float_mantissa; + */ + + float scale_f32 = type_cast(scale.data); using ComputeV2Type = std::conditional_t, fp16x2_t, bf16x2_t>; @@ -308,7 +313,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 pk_int4_to_compute_v2( bit_cast>(quant_weight_tensor[quant_idx_k]) .at(i), - bit_cast(uscale))); + scale_f32)); }); return 0; } From b2b6fa1aa969d65e74eaa34b172fa15b762a4618 Mon Sep 17 00:00:00 2001 From: yadaish Date: Mon, 1 Dec 2025 05:44:21 +0000 Subject: [PATCH 20/28] update --- .../pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 9925cc06915..25f0de72be1 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -288,7 +288,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 uint32_t uscale = uint32_t(scale.data) << float_mantissa; */ - float scale_f32 = type_cast(scale.data); + float scale_f32 = type_convert(scale.data); using ComputeV2Type = std::conditional_t, fp16x2_t, bf16x2_t>; From b36c7d76a0dcbeefedb3a0240ef810970dfcde5d Mon Sep 17 00:00:00 2001 From: yadaish Date: Mon, 1 Dec 2025 06:59:48 +0000 Subject: [PATCH 21/28] update --- .../mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 25f0de72be1..66946483e1b 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -288,7 +288,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 uint32_t uscale = uint32_t(scale.data) << float_mantissa; */ - float scale_f32 = type_convert(scale.data); + // float scale_f32 = type_convert(scale.data); + float scale_f32 = type_convert(scale); using ComputeV2Type = std::conditional_t, fp16x2_t, bf16x2_t>; @@ -303,6 +304,9 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 return pk_int4_t_to_bfloat16x2_t(pk_int4, fscale); } else + + + { static_assert(sizeof(pk_int4) == 0, "unsupported compute type"); } From 70596134045d47dbd93aacf32e6b9bebc461c113 Mon Sep 17 00:00:00 2001 
From: yadaish Date: Thu, 4 Dec 2025 10:38:45 +0000 Subject: [PATCH 22/28] update --- include/ck_tile/core/numeric/pk_int4.hpp | 8 -------- 1 file changed, 8 deletions(-) diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index d72647395a8..088407b40c7 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -109,14 +109,11 @@ CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x) float x_l = ((x_u8 & 0x0f) >> 0) - 8.f; float x_h = ((x_u8 & 0xf0) >> 4) - 8.f; - /* #ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE fp32x2_t res = {x_h, x_l}; #elif fp32x2_t res = {x_l, x_h}; #endif - */ - fp32x2_t res = {x_l, x_h}; return res; } @@ -130,14 +127,11 @@ CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t_signed_conversion(const pk_in x_l = x_l > 7 ? x_l - 16 : x_l; x_h = x_h > 7 ? x_h - 16 : x_h; - /* #ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE fp32x2_t res = {x_h, x_l}; #elif fp32x2_t res = {x_l, x_h}; #endif - */ - fp32x2_t res = {x_l, x_h}; return res; } @@ -161,8 +155,6 @@ CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x) CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x, float scale) { - // TODO(yadai): confirm quanzation algorithm - // auto float_vec2 = pk_int4_t_to_fp32x2_t_signed_conversion(x); auto float_vec2 = pk_int4_t_to_fp32x2_t(x); float_vec2.x = float_vec2.x * scale; float_vec2.y = float_vec2.y * scale; From bd6897d4323dbbd3b13130111063c93d7d6d640d Mon Sep 17 00:00:00 2001 From: yadaish Date: Thu, 4 Dec 2025 18:28:17 +0000 Subject: [PATCH 23/28] update --- include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index 8ccd2c1bcb2..e6f3364bcb6 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -639,7 +639,8 @@ struct MoeFlatmmKernel index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1); // using ScaleType = std::conditional_t; - using ScaleType = std::conditional_t; + static constexpr bool IsInt4 = std::is_same_v; + using ScaleType = std::conditional_t, float>; const auto scale_b_flat_view = make_naive_tensor_view( reinterpret_cast(scale_n.ptr) + expert_id * kargs.N * scale_k, From ef037eb2593f3e2e67261effbb4b40b53ad472d2 Mon Sep 17 00:00:00 2001 From: yadaish Date: Thu, 4 Dec 2025 18:30:09 +0000 Subject: [PATCH 24/28] update --- .../mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 66946483e1b..a523d09dc2e 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -194,7 +194,6 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 [[maybe_unused]] const auto& scale_tensor, [[maybe_unused]] auto xdl_nIter, [[maybe_unused]] auto xdl_kIter) { -#if 0 auto quant_idx_k = xdl_kIter % number{}; auto scale_idx_n = xdl_nIter % number{}; @@ -207,7 +206,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize; constexpr int float_mantissa = 23; - uint32_t uscale = 
uint32_t(scale.data) << float_mantissa; + uint32_t uscale = uint32_t(scale) << float_mantissa; using ComputeV2Type = std::conditional_t, fp16x2_t, bf16x2_t>; @@ -258,7 +257,6 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 .at(i), bit_cast(uscale))); }); -#endif #endif return 0; } From e37dbc7e499ddb128b6615163e953e4f2f8f60af Mon Sep 17 00:00:00 2001 From: yadaish Date: Thu, 4 Dec 2025 18:35:00 +0000 Subject: [PATCH 25/28] update --- include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index e6f3364bcb6..88b1ab8c4fc 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -640,7 +640,7 @@ struct MoeFlatmmKernel // using ScaleType = std::conditional_t; static constexpr bool IsInt4 = std::is_same_v; - using ScaleType = std::conditional_t, float>; + using ScaleType = std::conditional_t, float>; const auto scale_b_flat_view = make_naive_tensor_view( reinterpret_cast(scale_n.ptr) + expert_id * kargs.N * scale_k, From 12d764e999bd223432c10e751e7846bdc8e2f07e Mon Sep 17 00:00:00 2001 From: yadaish Date: Thu, 4 Dec 2025 18:47:53 +0000 Subject: [PATCH 26/28] update --- .../pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index a523d09dc2e..5dbc71d6ea6 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -206,7 +206,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize; constexpr int float_mantissa = 23; - uint32_t uscale = uint32_t(scale) << float_mantissa; + uint32_t uscale = uint32_t(bit_cast(scale)) << float_mantissa; using ComputeV2Type = std::conditional_t, fp16x2_t, bf16x2_t>; From 971ed7da51e171230460460642704c8a2ed7151e Mon Sep 17 00:00:00 2001 From: yadaish Date: Fri, 5 Dec 2025 10:12:13 +0000 Subject: [PATCH 27/28] update --- .../ops/flatmm/kernel/moe_flatmm_kernel.hpp | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index 88b1ab8c4fc..62a69a16676 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -723,6 +723,7 @@ struct MoeFlatmmKernel constexpr bool isNonInterleaveGateUp = !IsGateUp || MXFP4_Pipeline; + /* const auto& b_flat_block_window = make_tile_window(b_flat_pad_view, make_tuple(number{}, @@ -730,6 +731,63 @@ struct MoeFlatmmKernel {static_cast(coord_n / BlockGemmShape::WarpTile::at(I1) / (isNonInterleaveGateUp ? 1 : 2)), 0}); + */ + const auto& b_flat_block_window = [&]() { + // GateUp needs to shuffle weight + if constexpr(IsGateUp) + { + // 1. Get Dimensions + const auto N = b_flat_pad_view.get_tensor_descriptor().get_length(I0); + const auto K = b_flat_pad_view.get_tensor_descriptor().get_length(I1); + + // 2. 
View Linear N as (2, N/2) -> effectively separating Gate (0) and Up (1) blocks + // Layout becomes: (BlockIdx, RowInBlock, K) + auto v_split = transform_tensor_view( + b_flat_pad_view, + make_tuple(make_unmerge_transform(make_tuple(number<2>{}, N / 2)), + make_pass_through_transform(K)), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + // 3. Permute to (N/2, 2, K) -> (RowInBlock, BlockIdx, K) + // This puts Gate(i) and Up(i) adjacent in the view + auto v_permute = transform_tensor_view( + v_split, + make_tuple(make_pass_through_transform(N / 2), + make_pass_through_transform(number<2>{}), + make_pass_through_transform(K)), + make_tuple(sequence<1>{}, sequence<0>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + // 4. Merge back to (N, K) -> effectively Interleaved View + auto b_interleaved_view = transform_tensor_view( + v_permute, + make_tuple(make_merge_transform(make_tuple(N / 2, number<2>{})), + make_pass_through_transform(K)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + // 5. Create Window on the transformed view + return make_tile_window( + b_interleaved_view, + make_tuple(number{}, + number{}), + {static_cast(coord_n / BlockGemmShape::WarpTile::at(I1) / + (isNonInterleaveGateUp ? 1 : 2)), + 0}); + } + else + { + // Default behavior for Interleaved or non-GateUp + return make_tile_window( + b_flat_pad_view, + make_tuple(number{}, + number{}), + {static_cast(coord_n / BlockGemmShape::WarpTile::at(I1) / + (isNonInterleaveGateUp ? 1 : 2)), + 0}); + } + }(); const int output_N_offset = IsGateUp ? coord_n / 2 : coord_n; From aee38fdbf35c3bbf6f4badf0f05bf6146f2a0b27 Mon Sep 17 00:00:00 2001 From: yadaish Date: Tue, 9 Dec 2025 00:30:52 +0000 Subject: [PATCH 28/28] update --- .../18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc b/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc index 38d685d4ab3..116c55f333a 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc @@ -72,7 +72,8 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc, using CDataType = PrecActType; using AccDataType = float; - using ScaleType = ck_tile::e8m0_t; + static constexpr bool IsInt4 = std::is_same_v; + using ScaleType = std::conditional_t; constexpr int ScaleGranularityN = 1; constexpr int ScaleGranularityK = 32;