
Commit deee62f

Add cross_entropy example and unit test

stack-info: PR: #320, branch: yf225/stack/28

1 parent 5eaa289 commit deee62f

File tree

3 files changed: +145 -0 lines changed


examples/cross_entropy.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
"""Cross entropy with index computation that works within Helion's constraints."""

from __future__ import annotations

import torch

import helion
from helion._testing import run_example
import helion.language as hl
from helion.utils import get_gpu_memory_info

# TritonBench configuration - adjust based on available GPU memory
if get_gpu_memory_info()[0] < 16.0:
    # Low memory configuration for GPUs with less than 16GB
    TRITONBENCH_ARGS = {"B": 4, "T": 512, "v_range": "10,15"}


@helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper])
def cross_entropy(
    logits: torch.Tensor,  # [N, V] input logits
    labels: torch.Tensor,  # [N] target labels
) -> torch.Tensor:
    n, v = logits.shape
    losses = torch.zeros([n], dtype=logits.dtype, device=logits.device)

    # Pre-compute base indices: [0, V, 2V, 3V, ...]
    base_indices = torch.arange(n, device=logits.device) * v

    # Flatten logits once at the beginning
    logits_flat = logits.view(-1)

    for tile_n in hl.tile(n):
        # Get data for this tile
        labels_tile = labels[tile_n]  # [tile_size]
        base_indices_tile = base_indices[tile_n]  # [tile_size]

        # Compute the actual flat indices by adding the label offset:
        # flat_index[i] = base_indices[i] + labels[i] = i*V + labels[i]
        flat_indices = base_indices_tile + labels_tile

        # Load the logits at the target indices
        logits_at_target = hl.load(logits_flat, [flat_indices])

        # Compute log_softmax for numerical stability:
        # load the full rows for this tile
        logits_rows = logits[tile_n, :]  # [tile_size, V]

        # Compute log-sum-exp
        max_logits = torch.amax(logits_rows, dim=-1, keepdim=True)
        shifted = logits_rows - max_logits
        exp_shifted = torch.exp(shifted)
        sum_exp = torch.sum(exp_shifted, dim=-1, keepdim=True)
        log_sum_exp = max_logits.squeeze(-1) + torch.log(sum_exp.squeeze(-1))

        # Cross entropy loss: log_sum_exp - logit_at_target
        losses[tile_n] = log_sum_exp - logits_at_target

    return losses.mean()


def main() -> None:
    """Run the cross entropy kernel against the PyTorch baseline."""
    # Test with a moderate size
    n, v = 128, 1000
    logits = torch.randn(n, v, device="cuda", dtype=torch.float32)
    labels = torch.randint(0, v, (n,), device="cuda", dtype=torch.long)

    run_example(
        cross_entropy,
        torch.nn.functional.cross_entropy,
        (logits, labels),
        kernel_name="helion",
        baseline_name="torch",
        rtol=1e-4,
        atol=1e-4,
    )


if __name__ == "__main__":
    main()
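The key trick in this example is gathering each row's target logit through a flattened view: the flat index i*V + labels[i] addresses logits[i, labels[i]], which sidesteps 2D advanced indexing inside the kernel. As a sanity check outside Helion, the same loss can be reproduced in plain eager-mode PyTorch (a minimal sketch, not part of this commit):

import torch

n, v = 8, 17
logits = torch.randn(n, v)
labels = torch.randint(0, v, (n,))

# The kernel's flattened gather: flat[i*V + labels[i]] == logits[i, labels[i]]
flat = logits.view(-1)
at_target = flat[torch.arange(n) * v + labels]
assert torch.equal(at_target, logits[torch.arange(n), labels])

# Per-row loss is logsumexp(row) - row[label]; the mean matches F.cross_entropy
loss = (torch.logsumexp(logits, dim=-1) - at_target).mean()
torch.testing.assert_close(loss, torch.nn.functional.cross_entropy(logits, labels))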

test/test_examples.expected

Lines changed: 51 additions & 0 deletions
@@ -459,6 +459,57 @@ def _concat2d_dim1_make_precompiler(x: torch.Tensor, y: torch.Tensor):
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_concat2d_dim1_kernel)(x, out, y, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestExamples.test_cross_entropy)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_helpers import math as tl_math

@triton.jit
def _cross_entropy_kernel(labels, base_indices, logits_flat, logits, losses, base_indices_stride_0, labels_stride_0, logits_stride_0, logits_stride_1, logits_flat_stride_0, losses_stride_0, v, _RDIM_SIZE_1: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0
    indices_0 = offset_0 + tl.zeros([1], tl.int32)
    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
    mask_1 = indices_1 < v
    labels_tile = tl.load(labels + indices_0 * labels_stride_0, None)
    base_indices_tile = tl.load(base_indices + indices_0 * base_indices_stride_0, None)
    v_0 = base_indices_tile + labels_tile
    logits_at_target = tl.load(logits_flat + v_0 * logits_flat_stride_0, None)
    logits_rows = tl.load(logits + (indices_0[:, None] * logits_stride_0 + indices_1[None, :] * logits_stride_1), mask_1[None, :], other=0)
    _mask_to = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), logits_rows, float('-inf'))
    max_logits = tl.reshape(tl.max(_mask_to, 1), [1, 1])
    v_1 = logits_rows - max_logits
    v_2 = tl_math.exp(v_1)
    _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), v_2, 0)
    sum_exp = tl.reshape(tl.sum(_mask_to_1, 1), [1, 1])
    squeeze = tl.reshape(max_logits, [1])
    squeeze_1 = tl.reshape(sum_exp, [1])
    v_3 = tl_math.log(squeeze_1)
    v_4 = squeeze + v_3
    v_5 = v_4 - logits_at_target
    tl.store(losses + indices_0 * losses_stride_0, v_5, None)

def cross_entropy(logits: torch.Tensor, labels: torch.Tensor):
    n, v = logits.shape
    losses = torch.zeros([n], dtype=logits.dtype, device=logits.device)
    base_indices = torch.arange(n, device=logits.device) * v
    logits_flat = logits.view(-1)
    _RDIM_SIZE_1 = triton.next_power_of_2(v)
    _cross_entropy_kernel[n,](labels, base_indices, logits_flat, logits, losses, base_indices.stride(0), labels.stride(0), logits.stride(0), logits.stride(1), logits_flat.stride(0), losses.stride(0), v, _RDIM_SIZE_1, num_warps=4, num_stages=3)
    return losses.mean()

def _cross_entropy_make_precompiler(logits: torch.Tensor, labels: torch.Tensor):
    n, v = logits.shape
    losses = torch.zeros([n], dtype=logits.dtype, device=logits.device)
    base_indices = torch.arange(n, device=logits.device) * v
    logits_flat = logits.view(-1)
    _RDIM_SIZE_1 = triton.next_power_of_2(v)
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_cross_entropy_kernel)(labels, base_indices, logits_flat, logits, losses, base_indices.stride(0), labels.stride(0), logits.stride(0), logits.stride(1), logits_flat.stride(0), losses.stride(0), v, _RDIM_SIZE_1, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestExamples.test_embedding_block_ptr)
from __future__ import annotations
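Note how the generated kernel pads the class dimension to _RDIM_SIZE_1 = triton.next_power_of_2(v) and masks the out-of-range lanes: -inf for the max and 0 for the sum, which leaves the log-sum-exp unchanged. A small eager-mode sketch (plain PyTorch, not part of the commit) illustrating why that masking is safe:

import torch

v, rdim = 1000, 1024  # rdim is the next power of two >= v
row = torch.randn(v)

# Pad as the kernel's masks do: -inf never wins the max, and exp(-inf) == 0
padded = torch.full((rdim,), float("-inf"))
padded[:v] = row

m = padded.max()
lse = m + torch.log(torch.exp(padded - m).sum())
torch.testing.assert_close(lse, torch.logsumexp(row, dim=0))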

test/test_examples.py

Lines changed: 14 additions & 0 deletions
@@ -267,6 +267,20 @@ def test_softmax_two_pass_block_ptr(self):
            )
        )

    def test_cross_entropy(self):
        n, v = 128, 1000
        args = (
            torch.randn(n, v, device=DEVICE, dtype=torch.float32),
            torch.randint(0, v, (n,), device=DEVICE, dtype=torch.long),
        )
        self.assertExpectedJournal(
            check_example(
                "cross_entropy",
                args,
                torch.nn.functional.cross_entropy(*args),
            )
        )

    def test_rms_norm(self):
        args = (
            torch.randn([128, 256], device=DEVICE, dtype=torch.float16),
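The journal assertion compares the Helion kernel's generated code and output against the eager baseline passed as the third argument. Numerically, the check amounts to something like the following hypothetical sketch (not check_example's actual implementation; it assumes a CUDA GPU, Helion installed, and the examples directory on the import path, with tolerances borrowed from the example's run_example call):

import torch

from examples.cross_entropy import cross_entropy  # the Helion kernel above

n, v = 128, 1000
logits = torch.randn(n, v, device="cuda", dtype=torch.float32)
labels = torch.randint(0, v, (n,), device="cuda", dtype=torch.long)

expected = torch.nn.functional.cross_entropy(logits, labels)
actual = cross_entropy(logits, labels)
torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4)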
