
Commit 28a5978

Add cross_entropy example and unit test
stack-info: PR: #320, branch: yf225/stack/28
1 parent 7a981c0 commit 28a5978

File tree

3 files changed: +355 -0 lines changed

examples/cross_entropy.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
from __future__ import annotations

import os

import torch

import helion
from helion._testing import run_example
import helion.language as hl

# TritonBench configuration - adjust based on HELION_DEV_LOW_VRAM environment variable
if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1":
    # Low memory configuration
    TRITONBENCH_ARGS = {"B": 4, "T": 512, "v_range": "10,15"}


@helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper])
def cross_entropy(
    logits: torch.Tensor,  # [N, V] input logits
    labels: torch.Tensor,  # [N] target labels
) -> torch.Tensor:
    n, v = logits.shape
    losses = torch.zeros([n], dtype=logits.dtype, device=logits.device)

    # Flatten logits once at the beginning
    logits_flat = logits.view(-1)

    for tile_n in hl.tile(n):
        # Get data for this tile
        labels_tile = labels[tile_n]  # [tile_size]
        base_indices_tile = tile_n.index * v  # [tile_size]

        # Compute the actual flat indices by adding the label offset
        # flat_index[i] = base_indices[i] + labels[i] = i*V + labels[i]
        flat_indices = base_indices_tile + labels_tile

        # Load the logits at the target indices
        logits_at_target = hl.load(logits_flat, [flat_indices])

        # Compute log_softmax for numerical stability
        # Load the full rows for this tile
        logits_rows = logits[tile_n, :]  # [tile_size, V]

        # Compute log-sum-exp
        max_logits = torch.amax(logits_rows, dim=-1, keepdim=True)
        shifted = logits_rows - max_logits
        exp_shifted = torch.exp(shifted)
        sum_exp = torch.sum(exp_shifted, dim=-1, keepdim=True)
        log_sum_exp = max_logits.squeeze(-1) + torch.log(sum_exp.squeeze(-1))

        # Cross entropy loss: log_sum_exp - logit_at_target
        losses[tile_n] = log_sum_exp - logits_at_target

    return losses.mean()


def main() -> None:
    """Run cross entropy benchmark with different input sizes."""
    # Test with moderate size
    n, v = 128, 1000
    logits = torch.randn(n, v, device="cuda", dtype=torch.float32)
    labels = torch.randint(0, v, (n,), device="cuda", dtype=torch.long)

    run_example(
        cross_entropy,
        torch.nn.functional.cross_entropy,
        (logits, labels),
        kernel_name="helion",
        baseline_name="torch",
        rtol=1e-4,
        atol=1e-4,
    )


if __name__ == "__main__":
    main()
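For reference, the kernel relies on the standard log-sum-exp form of cross entropy: loss_i = logsumexp(logits_i) - logits_i[labels_i]. Below is a minimal eager-PyTorch sketch of that identity (illustrative, not part of this commit):

import torch

logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
# Per-sample loss via the identity used by the kernel above
per_sample = torch.logsumexp(logits, dim=-1) - logits[torch.arange(4), labels]
reference = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
torch.testing.assert_close(per_sample, reference)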

test/test_examples.expected

Lines changed: 265 additions & 0 deletions
@@ -469,6 +469,50 @@ def concat2d_dim1(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launch
    _launcher(_concat2d_dim1_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), x, out, y, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
    return out

--- assertExpectedJournal(TestExamples.test_cross_entropy)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_helpers import math as tl_math
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _cross_entropy_kernel(labels, logits_flat, logits, losses, labels_stride_0, logits_stride_0, logits_stride_1, logits_flat_stride_0, losses_stride_0, v, _RDIM_SIZE_1: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0
    indices_0 = offset_0 + tl.zeros([1], tl.int32)
    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
    mask_1 = indices_1 < v
    labels_tile = tl.load(labels + indices_0 * labels_stride_0, None)
    v_0 = v.to(tl.int32)
    v_1 = indices_0 * v_0
    v_2 = v_1.to(tl.int64)
    v_3 = v_2 + labels_tile
    logits_at_target = tl.load(logits_flat + v_3 * logits_flat_stride_0, None)
    logits_rows = tl.load(logits + (indices_0[:, None] * logits_stride_0 + indices_1[None, :] * logits_stride_1), mask_1[None, :], other=0)
    _mask_to = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), logits_rows, float('-inf'))
    max_logits = tl.reshape(tl.max(_mask_to, 1), [1, 1])
    v_4 = logits_rows - max_logits
    v_5 = tl_math.exp(v_4)
    _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), v_5, 0)
    sum_exp = tl.reshape(tl.sum(_mask_to_1, 1), [1, 1])
    squeeze = tl.reshape(max_logits, [1])
    squeeze_1 = tl.reshape(sum_exp, [1])
    v_6 = tl_math.log(squeeze_1)
    v_7 = squeeze + v_6
    v_8 = v_7 - logits_at_target
    tl.store(losses + indices_0 * losses_stride_0, v_8, None)

def cross_entropy(logits: torch.Tensor, labels: torch.Tensor, *, _launcher=_default_launcher):
    n, v = logits.shape
    losses = torch.zeros([n], dtype=logits.dtype, device=logits.device)
    logits_flat = logits.view(-1)
    _RDIM_SIZE_1 = triton.next_power_of_2(v)
    _launcher(_cross_entropy_kernel, (n,), labels, logits_flat, logits, losses, labels.stride(0), logits.stride(0), logits.stride(1), logits_flat.stride(0), losses.stride(0), v, _RDIM_SIZE_1, num_warps=4, num_stages=3)
    return losses.mean()
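A note on the generated kernel above: tl.arange requires a power-of-two extent, so the reduction dimension is padded to _RDIM_SIZE_1 = triton.next_power_of_2(v) and the tail is masked out of the max and sum reductions via mask_1 = indices_1 < v. A tiny sketch of the padding arithmetic (illustrative, assuming triton is installed):

import triton

v = 1000
assert triton.next_power_of_2(v) == 1024  # V is padded from 1000 to 1024; the extra lanes are masked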
--- assertExpectedJournal(TestExamples.test_embedding_block_ptr)
from __future__ import annotations

@@ -530,6 +574,94 @@ def embedding(x: torch.Tensor, weight: torch.Tensor, *, _launcher=_default_launc
    _launcher(_embedding_kernel, (x_flat.size(0) * triton.cdiv(embedding_dim, _BLOCK_SIZE_1),), x_flat, weight, out, x_flat.size(0), out.stride(0), out.stride(1), weight.stride(0), weight.stride(1), x_flat.stride(0), embedding_dim, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
    return out.view(*x.size(), embedding_dim)

--- assertExpectedJournal(TestExamples.test_fp8_attention)
from __future__ import annotations

import math
import torch
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_compat import libdevice

@triton.jit
def _fp8_attention_kernel_kernel(q, k, v, out, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0
    indices_5 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
    for offset_4 in tl.range(0, 256, _BLOCK_SIZE_1):
        indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
        m_i = tl.full([_BLOCK_SIZE_1], float('-inf'), tl.float32)
        l_i = tl.full([_BLOCK_SIZE_1], 0.0, tl.float32)
        acc = tl.full([_BLOCK_SIZE_1, 64], 0.0, tl.float32)
        q_tile = tl.load(q + (offset_0 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), None)
        for offset_2 in tl.range(0, 256, _BLOCK_SIZE_3):
            indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
            q_tile_copy = q_tile
            m_i_copy = m_i
            l_i_copy = l_i
            acc_copy = acc
            q_tile_copy_0 = q_tile_copy
            m_i_copy_0 = m_i_copy
            l_i_copy_0 = l_i_copy
            acc_copy_0 = acc_copy
            k_tile = tl.load(k + (offset_0 * 16384 + indices_2[:, None] * 64 + indices_5[None, :] * 1), None)
            k_tile_t = tl.permute(k_tile, [1, 0])
            mm = tl.dot(q_tile_copy_0, k_tile_t)
            v_0 = mm.to(tl.float32)
            v_1 = 0.18033688
            v_2 = v_0 * v_1
            qk_max = tl.max(v_2, 1)
            v_3 = triton_helpers.maximum(m_i_copy_0, qk_max)
            subscript = v_3[:, None]
            v_4 = v_2 - subscript
            v_5 = libdevice.exp2(v_4)
            l_ij = tl.sum(v_5, 1)
            v_6 = m_i_copy_0 - v_3
            v_7 = libdevice.exp2(v_6)
            v_8 = l_i_copy_0 * v_7
            l_i = v_8 + l_ij
            subscript_1 = v_7[:, None]
            v_10 = acc_copy_0 * subscript_1
            v_tile = tl.load(v + (offset_0 * 16384 + indices_5[:, None] * 1 + indices_2[None, :] * 64), None)
            v_11 = v_5.to(tl.float8e5)
            v_t = tl.permute(v_tile, [1, 0])
            mm_1 = tl.dot(v_11, v_t)
            v_12 = mm_1.to(tl.float32)
            acc = v_10 + v_12
            m_i = v_3
        subscript_2 = l_i[:, None]
        v_14 = acc / subscript_2
        tl.store(out + (offset_0 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), v_14, None)

def fp8_attention_kernel(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """FP8 attention kernel processing batch*heads in parallel."""
    batch_heads = q.size(0)
    seq_len = q.size(1)
    head_dim = q.size(2)
    out = torch.empty([batch_heads, seq_len, head_dim], dtype=torch.float32, device=q.device)
    sm_scale = 1.0 / math.sqrt(float(head_dim))
    sm_scale = sm_scale * 1.44269504
    _RDIM_SIZE_2 = 64
    _BLOCK_SIZE_1 = 64
    _BLOCK_SIZE_3 = 64
    _fp8_attention_kernel_kernel[8,](q, k, v, out, _RDIM_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
    return out

def _fp8_attention_kernel_make_precompiler(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """FP8 attention kernel processing batch*heads in parallel."""
    batch_heads = q.size(0)
    seq_len = q.size(1)
    head_dim = q.size(2)
    out = torch.empty([batch_heads, seq_len, head_dim], dtype=torch.float32, device=q.device)
    sm_scale = 1.0 / math.sqrt(float(head_dim))
    sm_scale = sm_scale * 1.44269504
    _RDIM_SIZE_2 = 64
    _BLOCK_SIZE_1 = 64
    _BLOCK_SIZE_3 = 64
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_fp8_attention_kernel_kernel)(q, k, v, out, _RDIM_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
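The sm_scale * 1.44269504 line above folds log2(e) into the softmax scale so the kernel can use the faster exp2 in place of exp, since exp(x) = 2^(x * log2(e)); for head_dim = 64 this yields the 0.18033688 constant seen in the kernel body. A quick sanity check of that arithmetic (illustrative, not part of this commit):

import math

head_dim = 64
sm_scale = 1.0 / math.sqrt(float(head_dim))    # 0.125
print(sm_scale * 1.44269504)                   # 0.18033688
x = 0.7
assert abs(math.exp(x) - 2.0 ** (x * 1.44269504)) < 1e-8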
--- assertExpectedJournal(TestExamples.test_fp8_gemm)
from __future__ import annotations

@@ -762,6 +894,139 @@ def jagged_mean_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_
    _launcher(_jagged_mean_kernel_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_feature_counts, x_flat, out, out.stride(0), out.stride(1), x_feature_counts.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
    return out

--- assertExpectedJournal(TestExamples.test_jagged_mean_2d)
from __future__ import annotations

import torch
import triton
import triton.language as tl

@triton.jit
def _jagged_mean_kernel_2d_kernel(x_offsets, x_feature_counts, x_flat, out, out_stride_0, out_stride_1, x_feature_counts_stride_0, x_flat_stride_0, x_offsets_stride_0, num_rows, max_M, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    mask_0 = indices_0 < num_rows
    starts = tl.load(x_offsets + indices_0 * x_offsets_stride_0, mask_0, other=0)
    v_0 = tl.full([], 1, tl.int32)
    v_1 = indices_0 + v_0
    ends = tl.load(x_offsets + v_1 * x_offsets_stride_0, mask_0, other=0)
    v_2 = ends - starts
    _mask_to = tl.where(mask_0, v_2, -9223372036854775808)
    max_nnz = tl.max(_mask_to, 0)
    feature_counts = tl.load(x_feature_counts + indices_0 * x_feature_counts_stride_0, mask_0, other=0)
    for offset_1 in tl.range(0, max_M.to(tl.int32), step=_BLOCK_SIZE_1):
        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
        mask_1 = indices_1 < max_M
        feature_counts_copy = feature_counts
        max_nnz_copy = max_nnz
        starts_copy = starts
        v_2_copy = v_2
        feature_counts_copy_0 = feature_counts_copy
        max_nnz_copy_0 = max_nnz_copy
        starts_copy_0 = starts_copy
        v_2_copy_0 = v_2_copy
        subscript = feature_counts_copy_0[:, None]
        v_3 = indices_1[None, :]
        v_4 = v_3 < subscript
        row_sums = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
        for offset_2 in tl.range(0, max_nnz_copy_0.to(tl.int32), step=_BLOCK_SIZE_2):
            indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
            mask_2 = indices_2 < max_nnz_copy_0
            starts_copy_0_copy = starts_copy_0
            v_2_copy_0_copy = v_2_copy_0
            v_4_copy = v_4
            row_sums_copy = row_sums
            starts_copy_0_copy_0 = starts_copy_0_copy
            v_2_copy_0_copy_0 = v_2_copy_0_copy
            v_4_copy_0 = v_4_copy
            row_sums_copy_0 = row_sums_copy
            subscript_1 = starts_copy_0_copy_0[:, None]
            subscript_2 = indices_2[None, :]
            v_5 = subscript_2.to(tl.int64)
            v_6 = subscript_1 + v_5
            subscript_3 = v_6[:, :, None]
            v_7 = subscript_3 * max_M
            subscript_4 = indices_1[None, None, :]
            v_8 = subscript_4.to(tl.int64)
            v_9 = v_7 + v_8
            subscript_5 = indices_2[None, :]
            subscript_6 = v_2_copy_0_copy_0[:, None]
            v_10 = subscript_5.to(tl.int64)
            v_11 = v_10 < subscript_6
            subscript_7 = v_11[:, :, None]
            subscript_8 = v_4_copy_0[:, None, :]
            v_12 = subscript_7 & subscript_8
            x_slice = tl.load(x_flat + v_9 * x_flat_stride_0, mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :] & v_12, other=0)
            sum_1 = tl.sum(x_slice, 1)
            row_sums = row_sums_copy_0 + sum_1
        v_14 = v_2_copy_0.to(tl.float32)
        nnz_expanded = v_14[:, None]
        v_15 = 0.0
        v_16 = nnz_expanded > v_15
        v_17 = row_sums / nnz_expanded
        v_18 = 0.0
        v_19 = v_18[None, None]
        v_20 = tl.where(v_16, v_17, v_19)
        v_21 = 0.0
        v_22 = v_21[None, None]
        v_23 = tl.where(v_4, v_20, v_22)
        tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_23, mask_0[:, None] & mask_1[None, :])

def jagged_mean_kernel_2d(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_counts: torch.Tensor, max_M_tensor: torch.Tensor):
    """
    Compute the mean of each row in a 2D jagged tensor with variable features per row.

    Args
    ----
    x_data : 2-D tensor of shape (total_elements, max_M) holding all elements.
    x_offsets : (num_rows + 1) tensor. Row i is the slice
                x_data[x_offsets[i] : x_offsets[i+1], :].
    x_feature_counts: (num_rows) tensor. Number of valid features for each row.
    max_M_tensor : Dummy tensor whose numel() gives max number of features.

    Returns
    -------
    result : 2-D tensor of shape (num_rows, max_M) containing the mean of each row.
             Invalid features (beyond x_feature_counts[i]) are set to 0.
    """
    num_rows = x_offsets.size(0) - 1
    max_M = max_M_tensor.numel()
    out = torch.zeros([num_rows, max_M], dtype=x_data.dtype, device=x_data.device)
    x_flat = x_data.view(-1)
    _BLOCK_SIZE_0 = 16
    _BLOCK_SIZE_1 = 8
    _BLOCK_SIZE_2 = 16
    _jagged_mean_kernel_2d_kernel[triton.cdiv(num_rows, _BLOCK_SIZE_0),](x_offsets, x_feature_counts, x_flat, out, out.stride(0), out.stride(1), x_feature_counts.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
    return out

def _jagged_mean_kernel_2d_make_precompiler(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_counts: torch.Tensor, max_M_tensor: torch.Tensor):
    """
    Compute the mean of each row in a 2D jagged tensor with variable features per row.

    Args
    ----
    x_data : 2-D tensor of shape (total_elements, max_M) holding all elements.
    x_offsets : (num_rows + 1) tensor. Row i is the slice
                x_data[x_offsets[i] : x_offsets[i+1], :].
    x_feature_counts: (num_rows) tensor. Number of valid features for each row.
    max_M_tensor : Dummy tensor whose numel() gives max number of features.

    Returns
    -------
    result : 2-D tensor of shape (num_rows, max_M) containing the mean of each row.
             Invalid features (beyond x_feature_counts[i]) are set to 0.
    """
    num_rows = x_offsets.size(0) - 1
    max_M = max_M_tensor.numel()
    out = torch.zeros([num_rows, max_M], dtype=x_data.dtype, device=x_data.device)
    x_flat = x_data.view(-1)
    _BLOCK_SIZE_0 = 16
    _BLOCK_SIZE_1 = 8
    _BLOCK_SIZE_2 = 16
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_jagged_mean_kernel_2d_kernel)(x_offsets, x_feature_counts, x_flat, out, out.stride(0), out.stride(1), x_feature_counts.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
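For orientation, here is a minimal eager-PyTorch sketch of the jagged row-mean that the generated kernel above computes (toy shapes, illustrative, not part of this commit): rows of x_data are delimited by x_offsets, and columns at or beyond each row's feature count are left at zero.

import torch

x_data = torch.arange(12.0).reshape(6, 2)   # (total_elements, max_M=2)
x_offsets = torch.tensor([0, 2, 6])         # two jagged rows
feature_counts = torch.tensor([2, 1])       # valid features per row
out = torch.zeros(2, 2)
for i in range(2):
    rows = x_data[x_offsets[i]:x_offsets[i + 1], :feature_counts[i]]
    out[i, :feature_counts[i]] = rows.mean(dim=0)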
--- assertExpectedJournal(TestExamples.test_matmul)
from __future__ import annotations

test/test_examples.py

Lines changed: 14 additions & 0 deletions
@@ -267,6 +267,20 @@ def test_softmax_two_pass_block_ptr(self):
            )
        )

    def test_cross_entropy(self):
        n, v = 128, 1000
        args = (
            torch.randn(n, v, device=DEVICE, dtype=torch.float32),
            torch.randint(0, v, (n,), device=DEVICE, dtype=torch.long),
        )
        self.assertExpectedJournal(
            check_example(
                "cross_entropy",
                args,
                torch.nn.functional.cross_entropy(*args),
            )
        )

    def test_rms_norm(self):
        args = (
            torch.randn([128, 256], device=DEVICE, dtype=torch.float16),
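The new test exercises the expected-journal machinery shown above; assuming a standard pytest setup and a CUDA device, it can be run in isolation with something like:

python -m pytest test/test_examples.py -k test_cross_entropy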
