
Add cross_entropy example and unit test #320


Merged: 1 commit, Jul 15, 2025
75 changes: 75 additions & 0 deletions examples/cross_entropy.py
@@ -0,0 +1,75 @@
from __future__ import annotations

import os

import torch

import helion
from helion._testing import run_example
import helion.language as hl

# TritonBench configuration - adjust based on HELION_DEV_LOW_VRAM environment variable
if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1":
    # Low memory configuration
    TRITONBENCH_ARGS = {"B": 4, "T": 512, "v_range": "10,15"}
Contributor: Where is this used?

Contributor (author): It's for development on low-VRAM machines; switched to use HELION_DEV_LOW_VRAM to gate this.
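Below is a minimal sketch of how that gate could be exercised, assuming the example is importable as examples.cross_entropy and that the environment variable is read at module import time as in the snippet above; the module path is illustrative only and not confirmed by this diff.

import importlib
import os

# The variable must be set before the module-level check runs, i.e. before import.
os.environ["HELION_DEV_LOW_VRAM"] = "1"

ce = importlib.import_module("examples.cross_entropy")
print(ce.TRITONBENCH_ARGS)  # low-memory sizes: {'B': 4, 'T': 512, 'v_range': '10,15'}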



@helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper])
def cross_entropy(
    logits: torch.Tensor,  # [N, V] input logits
    labels: torch.Tensor,  # [N] target labels
) -> torch.Tensor:
    n, v = logits.shape
    losses = torch.zeros([n], dtype=logits.dtype, device=logits.device)

    # Flatten logits once at the beginning
    logits_flat = logits.view(-1)

    for tile_n in hl.tile(n):
        # Get data for this tile
        labels_tile = labels[tile_n]  # [tile_size]
        base_indices_tile = tile_n.index * v  # [tile_size]

        # Compute the actual flat indices by adding the label offset
        flat_indices = base_indices_tile + labels_tile

        # Load the logits at the target indices
        logits_at_target = hl.load(logits_flat, [flat_indices])

        # Compute log_softmax for numerical stability
        # Load the full rows for this tile
        logits_rows = logits[tile_n, :]  # [tile_size, V]

        # Compute log-sum-exp
        max_logits = torch.amax(logits_rows, dim=-1, keepdim=True)
        shifted = logits_rows - max_logits
        exp_shifted = torch.exp(shifted)
        sum_exp = torch.sum(exp_shifted, dim=-1, keepdim=True)
        log_sum_exp = max_logits.squeeze(-1) + torch.log(sum_exp.squeeze(-1))

        # Cross entropy loss: log_sum_exp - logit_at_target
        losses[tile_n] = log_sum_exp - logits_at_target

    return losses.mean()


def main() -> None:
    """Run the cross entropy example against the PyTorch baseline."""
    # Test with a moderate size
    n, v = 128, 1000
    logits = torch.randn(n, v, device="cuda", dtype=torch.float32)
    labels = torch.randint(0, v, (n,), device="cuda", dtype=torch.long)

    run_example(
        cross_entropy,
        torch.nn.functional.cross_entropy,
        (logits, labels),
        kernel_name="helion",
        baseline_name="torch",
        rtol=1e-4,
        atol=1e-4,
    )


if __name__ == "__main__":
    main()
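For readers comparing against the kernel above, here is a small eager-PyTorch sketch (not part of this PR) of the same numerically stable log-sum-exp formulation; the names mirror the example, and the agreement with torch.nn.functional.cross_entropy under mean reduction is what run_example checks.

import torch


def cross_entropy_reference(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Same math as the Helion kernel, written with plain tensor ops.
    # log_softmax is computed via the log-sum-exp trick for stability.
    max_logits = logits.amax(dim=-1, keepdim=True)                           # [N, 1]
    shifted = logits - max_logits                                            # [N, V]
    log_sum_exp = max_logits.squeeze(-1) + shifted.exp().sum(dim=-1).log()   # [N]
    logits_at_target = logits.gather(1, labels.unsqueeze(1)).squeeze(1)      # [N]
    # Per-row loss is log_sum_exp - logit_at_target; the mean matches F.cross_entropy.
    return (log_sum_exp - logits_at_target).mean()


if __name__ == "__main__":
    n, v = 8, 16
    logits = torch.randn(n, v)
    labels = torch.randint(0, v, (n,))
    ref = cross_entropy_reference(logits, labels)
    torch.testing.assert_close(ref, torch.nn.functional.cross_entropy(logits, labels))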
44 changes: 44 additions & 0 deletions test/test_examples.expected
@@ -469,6 +469,50 @@ def concat2d_dim1(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launch
    _launcher(_concat2d_dim1_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), x, out, y, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
    return out

--- assertExpectedJournal(TestExamples.test_cross_entropy)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_helpers import math as tl_math
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _cross_entropy_kernel(labels, logits_flat, logits, losses, labels_stride_0, logits_stride_0, logits_stride_1, logits_flat_stride_0, losses_stride_0, v, _RDIM_SIZE_1: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0
    indices_0 = offset_0 + tl.zeros([1], tl.int32)
    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
    mask_1 = indices_1 < v
    labels_tile = tl.load(labels + indices_0 * labels_stride_0, None)
    v_0 = v.to(tl.int32)
    v_1 = indices_0 * v_0
    v_2 = v_1.to(tl.int64)
    v_3 = v_2 + labels_tile
    logits_at_target = tl.load(logits_flat + v_3 * logits_flat_stride_0, None)
    logits_rows = tl.load(logits + (indices_0[:, None] * logits_stride_0 + indices_1[None, :] * logits_stride_1), mask_1[None, :], other=0)
    _mask_to = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), logits_rows, float('-inf'))
    max_logits = tl.reshape(tl.max(_mask_to, 1), [1, 1])
    v_4 = logits_rows - max_logits
    v_5 = tl_math.exp(v_4)
    _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), v_5, 0)
    sum_exp = tl.reshape(tl.sum(_mask_to_1, 1), [1, 1])
    squeeze = tl.reshape(max_logits, [1])
    squeeze_1 = tl.reshape(sum_exp, [1])
    v_6 = tl_math.log(squeeze_1)
    v_7 = squeeze + v_6
    v_8 = v_7 - logits_at_target
    tl.store(losses + indices_0 * losses_stride_0, v_8, None)

def cross_entropy(logits: torch.Tensor, labels: torch.Tensor, *, _launcher=_default_launcher):
    n, v = logits.shape
    losses = torch.zeros([n], dtype=logits.dtype, device=logits.device)
    logits_flat = logits.view(-1)
    _RDIM_SIZE_1 = triton.next_power_of_2(v)
    _launcher(_cross_entropy_kernel, (n,), labels, logits_flat, logits, losses, labels.stride(0), logits.stride(0), logits.stride(1), logits_flat.stride(0), losses.stride(0), v, _RDIM_SIZE_1, num_warps=4, num_stages=3)
    return losses.mean()

--- assertExpectedJournal(TestExamples.test_embedding_block_ptr)
from __future__ import annotations

14 changes: 14 additions & 0 deletions test/test_examples.py
@@ -267,6 +267,20 @@ def test_softmax_two_pass_block_ptr(self):
            )
        )

    def test_cross_entropy(self):
        n, v = 128, 1000
        args = (
            torch.randn(n, v, device=DEVICE, dtype=torch.float32),
            torch.randint(0, v, (n,), device=DEVICE, dtype=torch.long),
        )
        self.assertExpectedJournal(
            check_example(
                "cross_entropy",
                args,
                torch.nn.functional.cross_entropy(*args),
            )
        )

    def test_rms_norm(self):
        args = (
            torch.randn([128, 256], device=DEVICE, dtype=torch.float16),