2 changes: 1 addition & 1 deletion fast_llm/engine/config_utils/run.py
@@ -101,7 +101,7 @@ def configure_logging(
def get_run(self, distributed: "Distributed") -> "Run":
from fast_llm.functional.config import TritonConfig

TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels # and distributed.config.use_cuda
TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels
TritonConfig.TRITON_LINEAR = self.run.triton_linear_kernels
run = Run(config=self, distributed=distributed)
set_global_variables(not self.run.torch_dynamo_enable)
13 changes: 13 additions & 0 deletions fast_llm/functional/config.py
@@ -14,6 +14,19 @@ class TritonConfig:
POINTWISE_BLOCK_SIZE = 1024
MAX_BLOCK_SIZE_BYTES = 65536

@classmethod
def enabled(cls, device: "torch.device|None" = None, default: bool | None = None) -> bool:
if default is False:
return False
from fast_llm.functional.triton import triton_available, triton_interpret

available = triton_available and (device is None or device.type == "cuda" or triton_interpret)
if default is None:
default = available and cls.TRITON_ENABLED
else:
assert available
return default


class MLPRecomputeLevel(enum.StrEnum):
none = "none"
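A minimal usage sketch for the new `TritonConfig.enabled` helper, inferred from the hunk above; the call site and tensor are hypothetical. Passing `None` defers to Triton availability and the global `TRITON_ENABLED` flag, `True` asserts availability, and `False` forces the fallback path.

```python
import torch

from fast_llm.functional.config import TritonConfig

# Hypothetical call site: decide whether to launch a Triton kernel for this tensor.
params = torch.empty(1024)
if TritonConfig.enabled(params.device, default=None):
    ...  # Triton path
else:
    ...  # plain torch fallback
```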
15 changes: 9 additions & 6 deletions fast_llm/functional/entropy_loss.py
@@ -14,14 +14,15 @@ def torch_entropy_loss_forward_backward(
logits_scale_factor: float,
target_format: TargetFormat,
entropy_loss_type: EntropyLossType,
group: ProcessGroup | None = None,
temperature: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor | None]: # (), (*batch, vocab)
"""
A wrapper for the pytorch implementation of cross-entropy.
The cross-entropy kernels themselves are well-optimized, but the need for explicit casting
and separate forward and backward kernels lead to poor performance.
"""

assert group is None
# Torch methods require flattened batch dimension.
target = target.flatten() if target_format == TargetFormat.labels else target.flatten(0, -2)
if target_format == TargetFormat.labels:
@@ -120,7 +121,7 @@ def fused_softmax_base(


@torch.compile
def _fused_reverse_kl_base(
def _fused_reverse_kl_base_from_distribution(
logits: torch.Tensor, # (*batch, vocab)
target: torch.Tensor, # (*batch, vocab)
grad_output: float | None,
@@ -160,7 +161,7 @@ def _fused_reverse_kl_base(


@torch.compile
def _fused_cross_entropy_base(
def _fused_cross_entropy_base_from_distribution(
logits: torch.Tensor, # (*batch, vocab)
target: torch.Tensor, # (*batch, vocab)
grad_output: float | None,
@@ -182,7 +183,7 @@ def _fused_cross_entropy_base(
# KL loss = mean(log(sum_exp_logits) - sum(probabilities * (logits - log_probabilities))
if return_kl_loss:
if target_format == TargetFormat.logits:
target_log_probability = target_logits_norm - sum_exp_target_logits.log().unsqueeze(-1)
target_log_probability = target_logits_norm
else:
target_log_probability = torch.log(target)
logits_norm = logits_norm - target_log_probability
@@ -193,6 +194,8 @@ def _fused_cross_entropy_base(
all_reduce(predicted_logits, op=ReduceOp.SUM, group=group)

per_sample_loss = sum_exp_logits.log() - predicted_logits
if return_kl_loss and target_format == TargetFormat.logits:
per_sample_loss = per_sample_loss - sum_exp_target_logits.log()

if grad_output is None:
grad = None
@@ -301,7 +304,7 @@ def fused_entropy_loss_forward_backward(
group,
)
elif entropy_loss_type in (EntropyLossType.cross_entropy, EntropyLossType.forward_kl):
per_sample_loss, grad = _fused_cross_entropy_base(
per_sample_loss, grad = _fused_cross_entropy_base_from_distribution(
logits,
target,
grad_output,
@@ -312,7 +315,7 @@
return_kl_loss=entropy_loss_type == EntropyLossType.forward_kl,
)
elif entropy_loss_type == EntropyLossType.reverse_kl:
per_sample_loss, grad = _fused_reverse_kl_base(
per_sample_loss, grad = _fused_reverse_kl_base_from_distribution(
logits,
target,
grad_output,
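The `target_log_probability` and `per_sample_loss` hunks above factor the target log-normalizer out of the per-token term: because the target probabilities sum to one, subtracting `log Z_target` once per sample is algebraically equivalent to folding it into each target log-probability. A small self-contained PyTorch check of that identity (illustrative only, not the fused kernel, which also handles scaling and vocab-parallel reduction):

```python
import torch

logits = torch.randn(4, 10)
target_logits = torch.randn(4, 10)
p = torch.softmax(target_logits, dim=-1)

# Normalizer folded into the target log-probabilities (old formulation).
old = (p * (torch.log_softmax(target_logits, -1) - torch.log_softmax(logits, -1))).sum(-1)
# Normalizer subtracted once per sample (new formulation).
new = (
    (p * (target_logits - logits)).sum(-1)
    + torch.logsumexp(logits, -1)
    - torch.logsumexp(target_logits, -1)
)
torch.testing.assert_close(old, new)  # both equal forward KL(target || predicted)
```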
25 changes: 25 additions & 0 deletions fast_llm/functional/triton/__init__.py
@@ -1,16 +1,41 @@
import torch

from fast_llm.utils import InvalidObject, try_decorate

try:
import triton
import triton.knobs
import triton.language as tl

tl_constexpr = tl.constexpr
TritonConfig = triton.Config
# Use `TRITON_INTERPRET=1` to enable triton on CPU.
triton_interpret = triton.knobs.runtime.interpret
triton_available = torch.cuda.is_available() or triton_interpret
except ImportError as e:
triton = InvalidObject(e)
tl = triton
tl_constexpr = None
TritonConfig = lambda *args, **kwargs: None
triton_interpret = False
triton_available = False

triton_jit = try_decorate(lambda: triton.jit)
triton_autotune = try_decorate(lambda: triton.autotune)

if not triton_available:
tl_arange = None
tl_full = None
elif triton_interpret:
# Workaround for a triton bug.
@triton_jit
def tl_arange(start, end):
return tl.arange(int(start), int(end))

@triton_jit
def tl_full(shape, value, dtype):
return tl.full(tuple(int(x) for x in shape), value, dtype)

else:
tl_arange = tl.arange
tl_full = tl.full
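A sketch of how a kernel module might consume these wrappers so the same source runs compiled on GPU and under `TRITON_INTERPRET=1` on CPU; `scale_kernel` is hypothetical and only mirrors the pattern of the `adam.py` change below.

```python
from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton_jit


# Hypothetical kernel, for illustration only.
@triton_jit()
def scale_kernel(input_ptr, output_ptr, scale, numel, block_size: tl_constexpr):
    # `tl_arange` stands in for `tl.arange`, working around the interpreter issue noted above.
    offsets = tl.program_id(axis=0) * block_size + tl_arange(0, block_size)
    mask = offsets < numel
    values = tl.load(input_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, values * scale, mask=mask)
```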
6 changes: 3 additions & 3 deletions fast_llm/functional/triton/adam.py
@@ -8,7 +8,7 @@
from torch.optim.adamw import adamw # noqa

from fast_llm.functional.config import TritonConfig
from fast_llm.functional.triton import tl, tl_constexpr, triton, triton_jit
from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit


@triton_jit()
@@ -37,7 +37,7 @@ def triton_adam_kernel(

# TODO: Int64 ptr only if needed?
block_start = tl.program_id(axis=0).to(tl.int64) * block_size
offsets = block_start + tl.arange(0, block_size)
offsets = block_start + tl_arange(0, block_size)
mask = offsets < numel

params = tl.load(params_ptr + offsets, mask=mask)
@@ -75,7 +75,7 @@ def triton_adam(
epsilon: float,
use_triton=True,
) -> None:
if not use_triton or (use_triton is None and TritonConfig.TRITON_ENABLED):
if not TritonConfig.enabled(params.device, use_triton):
if noop_flag.item() == 0:
return adamw(
[params],
184 changes: 0 additions & 184 deletions fast_llm/functional/triton/cross_entropy.py

This file was deleted.
