Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/api/pytorch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ Quantizers

.. autoapiclass:: transformer_engine.pytorch.Float8Quantizer(scale, amax, fp8_dtype, *, rowwise=True, columnwise=True)

.. autoapiclass:: transformer_engine.pytorch.Float8CurrentScalingQuantizer(fp8_dtype, device, *, rowwise=True, columnwise=True, **kwargs)
.. autoapiclass:: transformer_engine.pytorch.Float8CurrentScalingQuantizer(fp8_dtype, device=None, *, rowwise=True, columnwise=True, **kwargs)

.. autoapiclass:: transformer_engine.pytorch.MXFP8Quantizer(fp8_dtype, *, rowwise=True, columnwise=True)

Expand Down
8 changes: 4 additions & 4 deletions tests/pytorch/test_custom_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,10 +290,10 @@ def quantizer_factory(role):
if role in counts:
counts[role] += 1
if role in ("linear_input", "linear_weight", "linear_output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
if role in ("linear_grad_output", "linear_grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device=torch.device("cuda"))
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")

custom = recipe.CustomRecipe(qfactory=quantizer_factory)

Expand All @@ -319,7 +319,7 @@ def test_factories_return_distinct_instances_and_buffers():

# Two calls should produce distinct quantizer objects and distinct tensor buffers
def factory():
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")

q1 = factory()
q2 = factory()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -456,10 +456,10 @@ def forward(
fp8_recipe = fp8_meta["local_recipes"][0]
if fp8_recipe.float8_current_scaling():
S_quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=S_quantizer.dtype, device="cuda"
fp8_dtype=S_quantizer.dtype,
)
dP_quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=dP_quantizer.dtype, device="cuda"
fp8_dtype=dP_quantizer.dtype,
)

if "2" in qkv_layout or "3" in qkv_layout:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1344,7 +1344,7 @@ def forward(
cu_seqlens_kv_per_step = [None for _ in range(cp_size)]

fused_attn_backend = None
amax_per_step = None
delayed_scaling_amax_per_step = None
S_quantizer_per_step = [None for _ in range(cp_size)]
O_quantizer_per_step = [None for _ in range(cp_size)]
max_logit_per_step = [None for _ in range(cp_size)]
Expand Down Expand Up @@ -1421,16 +1421,19 @@ def forward(
dP_quantizer,
)

# amax_per_step[0]: amax_s x cp_size
# amax_per_step[1]: amax_o x cp_size
amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
# per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True;
# only used to hold temporary scale/amax values (output only, no quantization op)
for i in range(cp_size):
S_quantizer_per_step[i] = S_quantizer.copy()
S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,))
O_quantizer_per_step[i] = O_quantizer.copy()
O_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,))

if fp8_recipe.delayed():
# delayed_scaling_amax_per_step[0]: amax_s x cp_size
# delayed_scaling_amax_per_step[1]: amax_o x cp_size
delayed_scaling_amax_per_step = torch.zeros(
(2, cp_size), dtype=torch.float32, device=q.device
)
for i in range(cp_size):
S_quantizer_per_step[i].amax = delayed_scaling_amax_per_step[0][i].reshape((1,))
O_quantizer_per_step[i].amax = delayed_scaling_amax_per_step[1][i].reshape((1,))
else:
# q_f16: torch.Tensor, dtype=fwd_nominal_dtype
# q, k, v: torch.Tensor, dtype=fwd_nominal_dtype
Expand Down Expand Up @@ -1918,9 +1921,9 @@ def forward(
elif not use_fused_attention:
out = out.view(-1, *out.shape[-2:])

# update FP8 quantizers: amax across cp_size steps
if fp8 and use_fused_attention:
amax_cp_fwd = amax_per_step.amax(dim=1)
# update FP8 quantizers: amax across cp_size steps (delayed scaling only)
if fp8 and use_fused_attention and fp8_recipe.delayed():
amax_cp_fwd = delayed_scaling_amax_per_step.amax(dim=1)
S_quantizer.amax.copy_(amax_cp_fwd[0])
O_quantizer.amax.copy_(amax_cp_fwd[1])

Expand Down Expand Up @@ -2034,7 +2037,7 @@ def forward(
ctx.QKV_quantizer = QKV_quantizer
ctx.O_quantizer = O_quantizer
ctx.S_quantizer = S_quantizer
if ctx.fp8:
if ctx.fp8 and fp8_recipe.delayed():
ctx.QKV_quantizer = QKV_quantizer.copy()
ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone()
ctx.O_quantizer = O_quantizer.copy()
Expand Down Expand Up @@ -2152,7 +2155,7 @@ def backward(ctx, dout, *_args):

# convert out, dout to the right type
fused_attn_backend = None
amax_per_step = None
delayed_scaling_amax_per_step = None
dP_quantizer_per_step = [None for _ in range(cp_size)]
dQKV_quantizer_per_step = [None for _ in range(cp_size)]
buffer_dtype = torch.uint8
Expand Down Expand Up @@ -2224,16 +2227,23 @@ def backward(ctx, dout, *_args):
device=kv.device,
)

# amax_per_step[0]: amax_dp x cp_size
# amax_per_step[1]: amax_dqkv x cp_size
amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
# per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True;
# only used to hold temporary scale/amax values (output only, no quantization op)
for i in range(cp_size):
dP_quantizer_per_step[i] = ctx.dP_quantizer.copy()
dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,))
dQKV_quantizer_per_step[i] = ctx.dQKV_quantizer.copy()
dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,))

if ctx.fp8_recipe.delayed():
# delayed_scaling_amax_per_step[0]: amax_dp x cp_size
# delayed_scaling_amax_per_step[1]: amax_dqkv x cp_size
delayed_scaling_amax_per_step = torch.zeros(
(2, cp_size), dtype=torch.float32, device=q.device
)
for i in range(cp_size):
dP_quantizer_per_step[i].amax = delayed_scaling_amax_per_step[0][i].reshape(
(1,)
)
dQKV_quantizer_per_step[i].amax = delayed_scaling_amax_per_step[1][i].reshape(
(1,)
)
else:
if isinstance(dout, QuantizedTensorStorage):
dout = dout.dequantize(dtype=bwd_nominal_dtype)
Expand Down Expand Up @@ -2645,9 +2655,10 @@ def backward(ctx, dout, *_args):

# sum up all cp_size for dq, dk, dv
if ctx.fp8 and ctx.use_fused_attention:
amax_cp_bwd = amax_per_step.amax(dim=1)
ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0])
ctx.dQKV_quantizer.amax.copy_(amax_cp_bwd[1])
if ctx.fp8_recipe.delayed():
amax_cp_bwd = delayed_scaling_amax_per_step.amax(dim=1)
ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0])
ctx.dQKV_quantizer.amax.copy_(amax_cp_bwd[1])

dq = dq_buffer
if ctx.fp8_recipe.delayed():
Expand Down Expand Up @@ -3647,7 +3658,7 @@ def forward(
ctx.QKV_quantizer = QKV_quantizer
ctx.O_quantizer = O_quantizer
ctx.S_quantizer = S_quantizer
if ctx.fp8:
if ctx.fp8 and fp8_recipe.delayed():
ctx.QKV_quantizer = QKV_quantizer.copy()
ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone()
ctx.O_quantizer = O_quantizer.copy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2180,10 +2180,7 @@ def print_quantizers(
type_str = "DS"
elif isinstance(q, Float8CurrentScalingQuantizer):
type_str = "CS"
print(
f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x"
f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}"
)
print(f"{label} >> {names[i]:14s}: {type_str}")


def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer):
Expand Down
39 changes: 27 additions & 12 deletions transformer_engine/pytorch/csrc/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ class Float8Quantizer : public Quantizer {

class Float8CurrentScalingQuantizer : public Quantizer {
public:
at::Tensor scale;
at::Tensor scale_inv;
at::Tensor amax;
DType dtype;
bool with_amax_reduction;
c10::intrusive_ptr<dist_group_type> amax_reduction_group;
Expand All @@ -218,33 +216,50 @@ class Float8CurrentScalingQuantizer : public Quantizer {
py::object quantizer, const std::optional<at::Tensor>& first_dims, size_t logical_first_dim,
size_t logical_last_dim) const override;

/*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer.
/*! @brief Construct an unquantized tensor with an amax buffer.
*
* The amax is zeroed out. Most TE kernels that output amax expect
* amax to be initialized to zero.
* The provided amax tensor is zeroed out and set on the output tensor.
* Most TE kernels that output amax expect amax to be initialized to zero.
*/
std::pair<TensorWrapper, py::object> create_unquantized_tensor_with_amax(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> data = std::nullopt);
const std::vector<size_t>& shape, DType dtype, at::Tensor amax,
std::optional<at::Tensor> data = std::nullopt);

std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;

/*! @brief Quantize to FP8 (virtual fallback, allocates local amax/scale) */
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;

/*! @brief Quantize to FP8, skipping local amax computation
/*! @brief Quantize to FP8 using provided amax/scale workspace buffers */
void quantize(const TensorWrapper& input, TensorWrapper& out, at::Tensor amax, at::Tensor scale,
const std::optional<TensorWrapper>& noop_flag = std::nullopt);

/*! @brief Quantize to FP8, skipping local amax computation.
*
* The quantizer's amax pointer is assumed to already hold the local
* amax. The amax may still be reduced across the amax reduction
* group.
* The provided amax tensor is assumed to already hold the local
* amax (e.g. computed by a fused LN kernel). The amax may still
* be reduced across the amax reduction group.
*/
void quantize_with_amax(TensorWrapper& input, TensorWrapper& out,
void quantize_with_amax(TensorWrapper& input, TensorWrapper& out, at::Tensor amax,
at::Tensor scale,
const std::optional<TensorWrapper>& noop_flag = std::nullopt);

private:
void quantize_impl(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag, bool compute_amax);
const std::optional<TensorWrapper>& noop_flag, bool compute_amax,
at::Tensor amax, at::Tensor scale);
};

/*! @brief Split a quantizer workspace tensor into its amax and scale views.
 *
 * Workspace layout: [amax, scale] (2 float32).
 */
inline std::pair<at::Tensor, at::Tensor> split_quantizer_workspace(const at::Tensor& workspace) {
  NVTE_CHECK(workspace.numel() >= 2, "Quantizer workspace must have at least 2 float32 elements");
  // Element 0 holds amax, element 1 holds scale; contiguous() so each piece
  // can be handed to kernels expecting a dense 1-element buffer.
  auto amax_view = workspace.slice(0, 0, 1).contiguous();
  auto scale_view = workspace.slice(0, 1, 2).contiguous();
  return {std::move(amax_view), std::move(scale_view)};
}

class Float8BlockQuantizer : public Quantizer {
public:
// Which float8 type is used for q data.
Expand Down
Loading