Commit deb7c8a

Add cross_entropy example and unit test (#320)
1 parent 7c8a560 commit deb7c8a

File tree: 3 files changed (+133, −0 lines)


examples/cross_entropy.py

Lines changed: 75 additions & 0 deletions

@@ -0,0 +1,75 @@
from __future__ import annotations

import os

import torch

import helion
from helion._testing import run_example
import helion.language as hl

# TritonBench configuration - adjust based on HELION_DEV_LOW_VRAM environment variable
if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1":
    # Low memory configuration
    TRITONBENCH_ARGS = {"B": 4, "T": 512, "v_range": "10,15"}


@helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper])
def cross_entropy(
    logits: torch.Tensor,  # [N, V] input logits
    labels: torch.Tensor,  # [N] target labels
) -> torch.Tensor:
    n, v = logits.shape
    losses = torch.zeros([n], dtype=logits.dtype, device=logits.device)

    # Flatten logits once at the beginning
    logits_flat = logits.view(-1)

    for tile_n in hl.tile(n):
        # Get data for this tile
        labels_tile = labels[tile_n]  # [tile_size]
        base_indices_tile = tile_n.index * v  # [tile_size]

        # Compute the actual flat indices by adding the label offset
        flat_indices = base_indices_tile + labels_tile

        # Load the logits at the target indices
        logits_at_target = hl.load(logits_flat, [flat_indices])

        # Compute log_softmax for numerical stability
        # Load the full rows for this tile
        logits_rows = logits[tile_n, :]  # [tile_size, V]

        # Compute log-sum-exp
        max_logits = torch.amax(logits_rows, dim=-1, keepdim=True)
        shifted = logits_rows - max_logits
        exp_shifted = torch.exp(shifted)
        sum_exp = torch.sum(exp_shifted, dim=-1, keepdim=True)
        log_sum_exp = max_logits.squeeze(-1) + torch.log(sum_exp.squeeze(-1))

        # Cross entropy loss: log_sum_exp - logit_at_target
        losses[tile_n] = log_sum_exp - logits_at_target

    return losses.mean()


def main() -> None:
    """Run cross entropy benchmark with different input sizes."""
    # Test with moderate size
    n, v = 128, 1000
    logits = torch.randn(n, v, device="cuda", dtype=torch.float32)
    labels = torch.randint(0, v, (n,), device="cuda", dtype=torch.long)

    run_example(
        cross_entropy,
        torch.nn.functional.cross_entropy,
        (logits, labels),
        kernel_name="helion",
        baseline_name="torch",
        rtol=1e-4,
        atol=1e-4,
    )


if __name__ == "__main__":
    main()
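
For reference, the per-row loss the kernel computes follows the standard identity loss_i = logsumexp(logits_i) - logits_i[label_i], which is why the example can compare directly against torch.nn.functional.cross_entropy. A minimal PyTorch sketch of that identity (not part of the commit; names are illustrative):

import torch

# Per-row cross entropy as log-sum-exp minus the logit at the target index,
# mirroring what the Helion kernel computes tile by tile.
n, v = 128, 1000
logits = torch.randn(n, v)
labels = torch.randint(0, v, (n,))

log_sum_exp = torch.logsumexp(logits, dim=-1)        # [N]
logits_at_target = logits[torch.arange(n), labels]   # [N]
manual_losses = log_sum_exp - logits_at_target       # [N]

reference = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
torch.testing.assert_close(manual_losses, reference, rtol=1e-4, atol=1e-4)
print(manual_losses.mean())  # same scalar reduction the kernel returns via losses.mean()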

test/test_examples.expected

Lines changed: 44 additions & 0 deletions

@@ -469,6 +469,50 @@ def concat2d_dim1(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launch
    _launcher(_concat2d_dim1_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), x, out, y, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
    return out

--- assertExpectedJournal(TestExamples.test_cross_entropy)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_helpers import math as tl_math
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _cross_entropy_kernel(labels, logits_flat, logits, losses, labels_stride_0, logits_stride_0, logits_stride_1, logits_flat_stride_0, losses_stride_0, v, _RDIM_SIZE_1: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0
    indices_0 = offset_0 + tl.zeros([1], tl.int32)
    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
    mask_1 = indices_1 < v
    labels_tile = tl.load(labels + indices_0 * labels_stride_0, None)
    v_0 = v.to(tl.int32)
    v_1 = indices_0 * v_0
    v_2 = v_1.to(tl.int64)
    v_3 = v_2 + labels_tile
    logits_at_target = tl.load(logits_flat + v_3 * logits_flat_stride_0, None)
    logits_rows = tl.load(logits + (indices_0[:, None] * logits_stride_0 + indices_1[None, :] * logits_stride_1), mask_1[None, :], other=0)
    _mask_to = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), logits_rows, float('-inf'))
    max_logits = tl.reshape(tl.max(_mask_to, 1), [1, 1])
    v_4 = logits_rows - max_logits
    v_5 = tl_math.exp(v_4)
    _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), v_5, 0)
    sum_exp = tl.reshape(tl.sum(_mask_to_1, 1), [1, 1])
    squeeze = tl.reshape(max_logits, [1])
    squeeze_1 = tl.reshape(sum_exp, [1])
    v_6 = tl_math.log(squeeze_1)
    v_7 = squeeze + v_6
    v_8 = v_7 - logits_at_target
    tl.store(losses + indices_0 * losses_stride_0, v_8, None)

def cross_entropy(logits: torch.Tensor, labels: torch.Tensor, *, _launcher=_default_launcher):
    n, v = logits.shape
    losses = torch.zeros([n], dtype=logits.dtype, device=logits.device)
    logits_flat = logits.view(-1)
    _RDIM_SIZE_1 = triton.next_power_of_2(v)
    _launcher(_cross_entropy_kernel, (n,), labels, logits_flat, logits, losses, labels.stride(0), logits.stride(0), logits.stride(1), logits_flat.stride(0), losses.stride(0), v, _RDIM_SIZE_1, num_warps=4, num_stages=3)
    return losses.mean()

--- assertExpectedJournal(TestExamples.test_embedding_block_ptr)
from __future__ import annotations
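
One detail the generated code makes visible: V = 1000 is not a power of two, so the kernel rounds the reduction width up to _RDIM_SIZE_1 = triton.next_power_of_2(v) and masks the padded columns, using -inf before the max and 0 before the sum so padding cannot affect the log-sum-exp. A small PyTorch sketch of that masking idea (illustrative only, not from the commit):

import torch

# Illustrative only: emulate the generated kernel's masked reduction over a
# row padded from v=1000 real columns up to the next power of two (1024).
v, rdim = 1000, 1024
row = torch.randn(rdim)          # padded row; only the first v entries are real
mask = torch.arange(rdim) < v    # True for real columns, False for padding

masked_for_max = torch.where(mask, row, torch.tensor(float("-inf")))  # padding ignored by max
max_logit = masked_for_max.max()
exp_shifted = torch.exp(row - max_logit)
masked_for_sum = torch.where(mask, exp_shifted, torch.tensor(0.0))    # padding contributes 0
log_sum_exp = max_logit + torch.log(masked_for_sum.sum())

# Same result as a log-sum-exp over just the real columns.
torch.testing.assert_close(log_sum_exp, torch.logsumexp(row[:v], dim=0), rtol=1e-5, atol=1e-5)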

test/test_examples.py

Lines changed: 14 additions & 0 deletions

@@ -267,6 +267,20 @@ def test_softmax_two_pass_block_ptr(self):
            )
        )

    def test_cross_entropy(self):
        n, v = 128, 1000
        args = (
            torch.randn(n, v, device=DEVICE, dtype=torch.float32),
            torch.randint(0, v, (n,), device=DEVICE, dtype=torch.long),
        )
        self.assertExpectedJournal(
            check_example(
                "cross_entropy",
                args,
                torch.nn.functional.cross_entropy(*args),
            )
        )

    def test_rms_norm(self):
        args = (
            torch.randn([128, 256], device=DEVICE, dtype=torch.float16),
