Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/cpp/operator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ add_executable(test_operator
test_cast_nvfp4_transpose.cu
test_cast_float8blockwise.cu
test_dequantize_mxfp8.cu
test_dequantize_nvfp4.cu
test_transpose.cu
test_cast_transpose.cu
test_cast_transpose_current_scaling.cu
Expand Down
206 changes: 206 additions & 0 deletions tests/cpp/operator/test_dequantize_mxfp8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <transformer_engine/cast.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/swizzle.h>
#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"

Expand Down Expand Up @@ -369,7 +370,137 @@ void performTest_x2(const size_t rows,
compareResults("output_colwise", output, ref_output_colwise.get(), false, atol, rtol);
}

// Dequantize with GEMM-swizzled scales (single dimension)
template <typename InputType, typename OutputType>
void performTest_x1_swizzled(const size_t rows,
const size_t cols,
const bool rowwise,
const bool colwise)
{
using namespace test;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;

const size_t block_size_rows = rowwise ? 1 : 32;
const size_t block_size_cols = colwise ? 1 : 32;

const size_t unpadded_blocks_Y_rowwise = rows;
const size_t unpadded_blocks_X_rowwise = divide_round_up(cols, block_size_cols);
const size_t unpadded_blocks_Y_colwise = divide_round_up(rows, block_size_rows);
const size_t unpadded_blocks_X_colwise = cols;

const size_t blocks_Y_rowwise = round_up_to_nearest_multiple(unpadded_blocks_Y_rowwise,
scale_tensor_alignment_Y_rowwise);
const size_t blocks_X_rowwise = round_up_to_nearest_multiple(unpadded_blocks_X_rowwise,
scale_tensor_alignment_X_rowwise);
const size_t blocks_Y_colwise = round_up_to_nearest_multiple(unpadded_blocks_Y_colwise,
scale_tensor_alignment_Y_colwise);
const size_t blocks_X_colwise = round_up_to_nearest_multiple(unpadded_blocks_X_colwise,
scale_tensor_alignment_X_colwise);

const size_t blocks_num_rowwise = blocks_Y_rowwise * blocks_X_rowwise;
const size_t blocks_num_colwise = blocks_Y_colwise * blocks_X_colwise;

const size_t blocks_num = rowwise ? blocks_num_rowwise : blocks_num_colwise;
const size_t scales_stride = rowwise ? blocks_X_rowwise : blocks_X_colwise;

Tensor input_compact_scales("input_compact_scales", std::vector<size_t>{ rows, cols }, itype,
rowwise, colwise, NVTE_MXFP8_1D_SCALING);

Tensor input_swizzled_scales("input_swizzled_scales", std::vector<size_t>{ rows, cols }, itype,
rowwise, colwise, NVTE_MXFP8_1D_SCALING);
input_swizzled_scales.set_with_gemm_swizzled_scales(true);

Tensor output("output", std::vector<size_t>{ rows, cols }, otype, true, false);

std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(rows * cols);
std::unique_ptr<fp8e8m0[]> scales = std::make_unique<fp8e8m0[]>(blocks_num);

fill_tensor_data<InputType>(input_compact_scales, scales.get(), scales.get(), rowwise, colwise,
rows, cols, blocks_num_rowwise, blocks_num_colwise);

const size_t data_bytes = rows * cols * sizeof(InputType);
if (rowwise && data_bytes > 0) {
cudaMemcpy(input_swizzled_scales.rowwise_dptr(), input_compact_scales.rowwise_dptr(),
data_bytes, cudaMemcpyDeviceToDevice);
}
if (colwise && data_bytes > 0) {
cudaMemcpy(input_swizzled_scales.columnwise_dptr(), input_compact_scales.columnwise_dptr(),
data_bytes, cudaMemcpyDeviceToDevice);
}

if (data_bytes > 0) {
nvte_swizzle_scaling_factors(input_compact_scales.data(), input_swizzled_scales.data(), 0);
}

nvte_dequantize(input_swizzled_scales.data(), output.data(), 0);

cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

InputType *data_ptr = rowwise
? input_compact_scales.rowwise_cpu_dptr<InputType>()
: input_compact_scales.columnwise_cpu_dptr<InputType>();

compute_ref_x1<InputType, OutputType>(data_ptr,
ref_output.get(),
scales.get(),
rows,
cols,
block_size_rows,
block_size_cols,
scales_stride);

auto [atol, rtol] = getTolerances(otype);
compareResults("output_swizzled", output, ref_output.get(), true, atol, rtol);
}

// Quantize with swizzled scales, then dequantize — round-trip test
template <typename InputType, typename IntermediateType>
void performTest_quantize_then_dequantize_swizzled(const size_t rows,
const size_t cols,
const bool rowwise,
const bool colwise)
{
using namespace test;
using EncodingType = fp32;
DType in_type = TypeInfo<InputType>::dtype;
DType intermed_type = TypeInfo<IntermediateType>::dtype;
DType out_type = TypeInfo<InputType>::dtype;

std::unique_ptr<InputType[]> output_cpu = std::make_unique<InputType[]>(rows * cols);

Tensor input("input", std::vector<size_t>{ rows, cols }, in_type);
Tensor quantized("quantized", std::vector<size_t>{ rows, cols }, intermed_type,
rowwise, colwise, NVTE_MXFP8_1D_SCALING);
quantized.set_with_gemm_swizzled_scales(true);

Tensor output("output", std::vector<size_t>{ rows, cols }, out_type, true, false);

fillCase<EncodingType>(&input, InputsFillCase::uniform);

if (rows > 0 && cols > 0) {
nvte_quantize(input.data(), quantized.data(), 0);
cudaDeviceSynchronize();
}

nvte_dequantize(quantized.data(), output.data(), 0);
cudaDeviceSynchronize();

const size_t copy_size = sizeof(InputType) * rows * cols;
cudaMemcpy(output_cpu.get(), output.rowwise_dptr(), copy_size, cudaMemcpyDeviceToHost);

auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

auto [atol, rtol] = getTolerances(intermed_type);
compareResults("Quantize-Dequantize-Swizzled", input, output_cpu.get(), true, atol, rtol);
}

std::vector<std::pair<size_t, size_t>> tensor_dims = {
{0, 128},
{0, 256},
{1, 16},
{16, 48},
{65, 96},
Expand Down Expand Up @@ -470,3 +601,78 @@ INSTANTIATE_TEST_SUITE_P(
return name;
}
);

/*****************************************************************************
* Swizzled-scale dequantization tests
*****************************************************************************/

class DequantizeMXFP8SwizzledTestSuite : public ::testing::TestWithParam
<std::tuple<std::pair<size_t, size_t>,
std::pair<size_t, size_t>,
transformer_engine::DType,
transformer_engine::DType,
bool>> {};

TEST_P(DequantizeMXFP8SwizzledTestSuite, TestDequantizeMXFP8Swizzled)
{
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}

using namespace transformer_engine;
using namespace test;

const auto tensor_size = std::get<0>(GetParam());
const auto block_size = std::get<1>(GetParam());
const DType input_type = std::get<2>(GetParam());
const DType output_type = std::get<3>(GetParam());
const bool quantize_then_dequantize = std::get<4>(GetParam());

const bool rowwise = block_size.second != 1;
const bool colwise = block_size.first != 1;

if (rowwise && colwise) {
GTEST_SKIP();
}

if (rowwise && tensor_size.second % 32 != 0) {
GTEST_SKIP();
}
if (colwise && tensor_size.first % 32 != 0) {
GTEST_SKIP();
}

TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType,
if (quantize_then_dequantize) {
performTest_quantize_then_dequantize_swizzled<OutputType, InputType>(
tensor_size.first, tensor_size.second, rowwise, colwise);
} else {
performTest_x1_swizzled<InputType, OutputType>(
tensor_size.first, tensor_size.second, rowwise, colwise);
}
);
);
}

INSTANTIATE_TEST_SUITE_P(
OperatorTest,
DequantizeMXFP8SwizzledTestSuite,
::testing::Combine(
::testing::ValuesIn(tensor_dims),
::testing::ValuesIn(block_sizes),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(false)),
[](const testing::TestParamInfo<DequantizeMXFP8SwizzledTestSuite::ParamType>& info)
{
std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
std::to_string(std::get<0>(info.param).second) + "X" +
std::to_string(std::get<1>(info.param).first) + "X" +
std::to_string(std::get<1>(info.param).second) + "X" +
test::typeName(std::get<2>(info.param)) + "X" +
test::typeName(std::get<3>(info.param)) + "X" +
(std::get<4>(info.param) ? "QD_Swizzled" : "D_Swizzled");
return name;
}
);
Loading