diff --git a/tests/cpp_distributed/test_comm_gemm.cu b/tests/cpp_distributed/test_comm_gemm.cu index cdd6f9cf14..cc0d760a39 100644 --- a/tests/cpp_distributed/test_comm_gemm.cu +++ b/tests/cpp_distributed/test_comm_gemm.cu @@ -204,7 +204,7 @@ class CommGemmFixure : public ::testing::TestWithParam { std::vector bdata(k * n); std::generate(bdata.begin(), bdata.end(), [&rng, &dist, b_scale] { return static_cast(dist(rng) * b_scale); }); - std::vector biasdata(m * n); + std::vector biasdata(m); std::generate(biasdata.begin(), biasdata.end(), [&rng, &dist, bias_scale] { return static_cast(dist(rng) * bias_scale); }); @@ -213,7 +213,7 @@ class CommGemmFixure : public ::testing::TestWithParam { : MakeFromData(adata, 0, 0, m, k, m, a_scale); auto gb = transb ? MakeFromData(bdata, 0, 0, n, k, n, b_scale) : MakeFromData(bdata, 0, 0, k, n, k, b_scale); - auto gbias = MakeFromData(biasdata, 0, 0, m, n, m, bias_scale); + auto gbias = MakeFromData(biasdata, 0, 0, m, 1, m, bias_scale); auto gd = Make(m, n, d_scale); auto gaux = Make(m, n, d_scale); @@ -226,8 +226,8 @@ class CommGemmFixure : public ::testing::TestWithParam { dims.b_cols_num, dims.b_rows_num, n, b_scale) : MakeFromData(bdata, dims.b_rows_start, dims.b_cols_start, dims.b_rows_num, dims.b_cols_num, k, b_scale); - auto bias = MakeFromData(biasdata, dims.d_rows_start, dims.d_cols_start, - dims.d_rows_num, dims.d_cols_num, m, bias_scale); + auto bias = MakeFromData(biasdata, dims.d_rows_start, 0, dims.d_rows_num, 1, m, + bias_scale); auto d = Make(dims.d_rows_num, dims.d_cols_num, d_scale); auto aux = Make(dims.d_rows_num, dims.d_cols_num, d_scale); @@ -237,7 +237,7 @@ class CommGemmFixure : public ::testing::TestWithParam { accumulate, 0 /*comm_sm_count*/, stream); auto workspace = Make(1, 32 << 20, 1.0); nvte_cublas_gemm(ga.data(), gb.data(), gd.data(), gbias.data(), gaux.data(), transa, transb, - grad, workspace.data(), accumulate, false /* use_split_accumulator */, + grad, workspace.data(), accumulate, true /* use_split_accumulator */, 0 /* math_sm_count */, stream); NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); @@ -253,7 +253,7 @@ class CommGemmFixure : public ::testing::TestWithParam { dims.d_rows_num, dims.d_cols_num, m); NVTE_CHECK(out.size() == out_golden.size()); for (size_t i = 0; i < out.size(); ++i) { - EXPECT_NEAR(static_cast(out[i]), static_cast(out_golden[i]), tol * k); + EXPECT_NEAR(static_cast(out[i]), static_cast(out_golden[i]), tol); } } @@ -427,35 +427,35 @@ INSTANTIATE_TEST_SUITE_P(AgGemm, AgGemm, INSTANTIATE_TEST_SUITE_P(GemmRs, GemmRs, testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, - false, false, 64, 128, 256, 5e-2}, + false, false, 64, 128, 256, 7e-2}, Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, - false, true, 64, 128, 256, 5e-2}, + false, true, 64, 128, 256, 7e-2}, Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, - true, false, 64, 128, 256, 5e-2}, + true, false, 64, 128, 256, 7e-2}, Params{DType::kBFloat16, DType::kBFloat16, - DType::kBFloat16, false, false, 64, 128, 256, 5e-2}, + DType::kBFloat16, false, false, 64, 128, 256, 6e-1}, Params{DType::kBFloat16, DType::kBFloat16, - DType::kBFloat16, false, true, 64, 128, 256, 5e-2}, + DType::kBFloat16, false, true, 64, 128, 256, 6e-1}, Params{DType::kBFloat16, DType::kBFloat16, - DType::kBFloat16, true, false, 64, 128, 256, 5e-2}, + DType::kBFloat16, true, false, 64, 128, 256, 6e-1}, Params{DType::kFloat8E4M3, DType::kFloat8E4M3, - DType::kFloat16, true, false, 64, 128, 256, 5e-2}, + DType::kFloat16, true, false, 64, 128, 256, 1e-1}, Params{DType::kFloat8E4M3, DType::kFloat8E5M2, - DType::kFloat16, true, false, 64, 128, 256, 5e-2}, + DType::kFloat16, true, false, 64, 128, 256, 7e-2}, Params{DType::kFloat8E5M2, DType::kFloat8E4M3, - DType::kFloat16, true, false, 64, 128, 256, 5e-2}), + DType::kFloat16, true, false, 64, 128, 256, 7e-2}), &ParamSuffix); INSTANTIATE_TEST_SUITE_P( GemmAr, GemmAr, testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, true, false, 64, - 64 * 4, 64 * 4, 5e-2}, + 64 * 4, 64 * 4, 7e-2}, Params{DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, true, false, 64, - 64 * 4, 64 * 4, 5e-2}, + 64 * 4, 64 * 4, 1e-3}, Params{DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kFloat16, true, false, - 128, 128 * 4, 128 * 4, 5e-2}, + 128, 128 * 4, 128 * 4, 1.5e-1}, Params{DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kFloat16, true, false, - 128, 128 * 4, 128 * 4, 5e-2}, + 128, 128 * 4, 128 * 4, 1.5e-1}, Params{DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kFloat16, true, false, - 128, 128 * 4, 128 * 4, 5e-2}), + 128, 128 * 4, 128 * 4, 1.5e-1}), &ParamSuffix);