From f084ead5590b168e14df5afbd782637d4007b3ab Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Mon, 14 Jul 2025 13:05:50 -0700
Subject: [PATCH] Add cross_entropy example and unit test

stack-info: PR: https://github.com/pytorch-labs/helion/pull/320, branch: yf225/stack/28
---
 examples/cross_entropy.py   | 75 +++++++++++++++++++++++++++++++++++++
 test/test_examples.expected | 44 ++++++++++++++++++++++
 test/test_examples.py       | 14 +++++++
 3 files changed, 133 insertions(+)
 create mode 100644 examples/cross_entropy.py

diff --git a/examples/cross_entropy.py b/examples/cross_entropy.py
new file mode 100644
index 00000000..28f36cd1
--- /dev/null
+++ b/examples/cross_entropy.py
@@ -0,0 +1,75 @@
+from __future__ import annotations
+
+import os
+
+import torch
+
+import helion
+from helion._testing import run_example
+import helion.language as hl
+
+# TritonBench configuration - adjust based on HELION_DEV_LOW_VRAM environment variable
+if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1":
+    # Low memory configuration
+    TRITONBENCH_ARGS = {"B": 4, "T": 512, "v_range": "10,15"}
+
+
+@helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper])
+def cross_entropy(
+    logits: torch.Tensor,  # [N, V] input logits
+    labels: torch.Tensor,  # [N] target labels
+) -> torch.Tensor:
+    n, v = logits.shape
+    losses = torch.zeros([n], dtype=logits.dtype, device=logits.device)
+
+    # Flatten logits once at the beginning
+    logits_flat = logits.view(-1)
+
+    for tile_n in hl.tile(n):
+        # Get data for this tile
+        labels_tile = labels[tile_n]  # [tile_size]
+        base_indices_tile = tile_n.index * v  # [tile_size]
+
+        # Compute the actual flat indices by adding the label offset
+        flat_indices = base_indices_tile + labels_tile
+
+        # Load the logits at the target indices
+        logits_at_target = hl.load(logits_flat, [flat_indices])
+
+        # Compute log_softmax for numerical stability
+        # Load the full rows for this tile
+        logits_rows = logits[tile_n, :]  # [tile_size, V]
+
+        # Compute log-sum-exp
+        max_logits = torch.amax(logits_rows, dim=-1, keepdim=True)
+        shifted = logits_rows - max_logits
+        exp_shifted = torch.exp(shifted)
+        sum_exp = torch.sum(exp_shifted, dim=-1, keepdim=True)
+        log_sum_exp = max_logits.squeeze(-1) + torch.log(sum_exp.squeeze(-1))
+
+        # Cross entropy loss: log_sum_exp - logit_at_target
+        losses[tile_n] = log_sum_exp - logits_at_target
+
+    return losses.mean()
+
+
+def main() -> None:
+    """Run cross entropy benchmark with different input sizes."""
+    # Test with moderate size
+    n, v = 128, 1000
+    logits = torch.randn(n, v, device="cuda", dtype=torch.float32)
+    labels = torch.randint(0, v, (n,), device="cuda", dtype=torch.long)
+
+    run_example(
+        cross_entropy,
+        torch.nn.functional.cross_entropy,
+        (logits, labels),
+        kernel_name="helion",
+        baseline_name="torch",
+        rtol=1e-4,
+        atol=1e-4,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/test_examples.expected b/test/test_examples.expected
index 32dc18c1..cc597ddd 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -469,6 +469,50 @@ def concat2d_dim1(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launch
     _launcher(_concat2d_dim1_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), x, out, y, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestExamples.test_cross_entropy)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_helpers import math as tl_math
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _cross_entropy_kernel(labels, logits_flat, logits, losses, labels_stride_0, logits_stride_0, logits_stride_1, logits_flat_stride_0, losses_stride_0, v, _RDIM_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0
+    indices_0 = offset_0 + tl.zeros([1], tl.int32)
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    mask_1 = indices_1 < v
+    labels_tile = tl.load(labels + indices_0 * labels_stride_0, None)
+    v_0 = v.to(tl.int32)
+    v_1 = indices_0 * v_0
+    v_2 = v_1.to(tl.int64)
+    v_3 = v_2 + labels_tile
+    logits_at_target = tl.load(logits_flat + v_3 * logits_flat_stride_0, None)
+    logits_rows = tl.load(logits + (indices_0[:, None] * logits_stride_0 + indices_1[None, :] * logits_stride_1), mask_1[None, :], other=0)
+    _mask_to = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), logits_rows, float('-inf'))
+    max_logits = tl.reshape(tl.max(_mask_to, 1), [1, 1])
+    v_4 = logits_rows - max_logits
+    v_5 = tl_math.exp(v_4)
+    _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), v_5, 0)
+    sum_exp = tl.reshape(tl.sum(_mask_to_1, 1), [1, 1])
+    squeeze = tl.reshape(max_logits, [1])
+    squeeze_1 = tl.reshape(sum_exp, [1])
+    v_6 = tl_math.log(squeeze_1)
+    v_7 = squeeze + v_6
+    v_8 = v_7 - logits_at_target
+    tl.store(losses + indices_0 * losses_stride_0, v_8, None)
+
+def cross_entropy(logits: torch.Tensor, labels: torch.Tensor, *, _launcher=_default_launcher):
+    n, v = logits.shape
+    losses = torch.zeros([n], dtype=logits.dtype, device=logits.device)
+    logits_flat = logits.view(-1)
+    _RDIM_SIZE_1 = triton.next_power_of_2(v)
+    _launcher(_cross_entropy_kernel, (n,), labels, logits_flat, logits, losses, labels.stride(0), logits.stride(0), logits.stride(1), logits_flat.stride(0), losses.stride(0), v, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    return losses.mean()
+
 --- assertExpectedJournal(TestExamples.test_embedding_block_ptr)
 from __future__ import annotations
 
diff --git a/test/test_examples.py b/test/test_examples.py
index 7f036854..16bf049d 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -267,6 +267,20 @@ def test_softmax_two_pass_block_ptr(self):
             )
         )
 
+    def test_cross_entropy(self):
+        n, v = 128, 1000
+        args = (
+            torch.randn(n, v, device=DEVICE, dtype=torch.float32),
+            torch.randint(0, v, (n,), device=DEVICE, dtype=torch.long),
+        )
+        self.assertExpectedJournal(
+            check_example(
+                "cross_entropy",
+                args,
+                torch.nn.functional.cross_entropy(*args),
+            )
+        )
+
     def test_rms_norm(self):
         args = (
             torch.randn([128, 256], device=DEVICE, dtype=torch.float16),
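
Reference note (not part of the patch): the kernel above computes the per-sample loss as log-sum-exp of each logit row minus the logit at the target index, then averages. Below is a minimal eager-PyTorch sketch of that same formulation for comparison against torch.nn.functional.cross_entropy; the helper name cross_entropy_reference is hypothetical and only illustrates the math used in the example.

import torch

def cross_entropy_reference(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: [N, V], labels: [N] (int64 class indices)
    # Numerically stable log-sum-exp over the vocabulary dimension.
    max_logits = logits.amax(dim=-1, keepdim=True)
    log_sum_exp = max_logits.squeeze(-1) + (logits - max_logits).exp().sum(dim=-1).log()
    # Logit of the target class for each row.
    logits_at_target = logits.gather(1, labels.unsqueeze(1)).squeeze(1)
    # Mean of per-sample losses, matching losses.mean() in the kernel wrapper.
    return (log_sum_exp - logits_at_target).mean()

Up to floating-point tolerance this should agree with torch.nn.functional.cross_entropy(logits, labels), which is the baseline both run_example and the unit test check against.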