@@ -469,6 +469,50 @@ def concat2d_dim1(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launch
    _launcher(_concat2d_dim1_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), x, out, y, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
    return out

+ --- assertExpectedJournal(TestExamples.test_cross_entropy)
+ from __future__ import annotations
+
+ import torch
+ import triton
+ import triton.language as tl
+ from torch._inductor.runtime.triton_helpers import math as tl_math
+ from helion.runtime import default_launcher as _default_launcher
+
+ @triton.jit
+ def _cross_entropy_kernel(labels, logits_flat, logits, losses, labels_stride_0, logits_stride_0, logits_stride_1, logits_flat_stride_0, losses_stride_0, v, _RDIM_SIZE_1: tl.constexpr):
+     pid_0 = tl.program_id(0)
+     offset_0 = pid_0
+     indices_0 = offset_0 + tl.zeros([1], tl.int32)
+     indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+     mask_1 = indices_1 < v
+     labels_tile = tl.load(labels + indices_0 * labels_stride_0, None)
+     v_0 = v.to(tl.int32)
+     v_1 = indices_0 * v_0
+     v_2 = v_1.to(tl.int64)
+     v_3 = v_2 + labels_tile
+     logits_at_target = tl.load(logits_flat + v_3 * logits_flat_stride_0, None)
+     logits_rows = tl.load(logits + (indices_0[:, None] * logits_stride_0 + indices_1[None, :] * logits_stride_1), mask_1[None, :], other=0)
+     _mask_to = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), logits_rows, float('-inf'))
+     max_logits = tl.reshape(tl.max(_mask_to, 1), [1, 1])
+     v_4 = logits_rows - max_logits
+     v_5 = tl_math.exp(v_4)
+     _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), v_5, 0)
+     sum_exp = tl.reshape(tl.sum(_mask_to_1, 1), [1, 1])
+     squeeze = tl.reshape(max_logits, [1])
+     squeeze_1 = tl.reshape(sum_exp, [1])
+     v_6 = tl_math.log(squeeze_1)
+     v_7 = squeeze + v_6
+     v_8 = v_7 - logits_at_target
+     tl.store(losses + indices_0 * losses_stride_0, v_8, None)
+
+ def cross_entropy(logits: torch.Tensor, labels: torch.Tensor, *, _launcher=_default_launcher):
+     n, v = logits.shape
+     losses = torch.zeros([n], dtype=logits.dtype, device=logits.device)
+     logits_flat = logits.view(-1)
+     _RDIM_SIZE_1 = triton.next_power_of_2(v)
+     _launcher(_cross_entropy_kernel, (n,), labels, logits_flat, logits, losses, labels.stride(0), logits.stride(0), logits.stride(1), logits_flat.stride(0), losses.stride(0), v, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+     return losses.mean()
+

--- assertExpectedJournal(TestExamples.test_embedding_block_ptr)
from __future__ import annotations
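Note on the test_cross_entropy journal entry added above: for each row i the generated kernel computes max_i + log(sum_j exp(logits[i, j] - max_i)) - logits[i, labels[i]], a numerically stable log-sum-exp cross entropy, with the columns padded up to _RDIM_SIZE_1 masked to -inf for the max and to 0 for the sum; the host wrapper then returns losses.mean(). Below is a minimal sanity-check sketch, assuming the generated cross_entropy wrapper above is importable, a CUDA device is available for Triton, and that the shapes n=128, v=1000 are arbitrary illustrative choices.

import torch

# Illustrative shapes (assumed); any (n, v) float logits with integer labels in [0, v) work.
n, v = 128, 1000
logits = torch.randn(n, v, device="cuda", dtype=torch.float32)
labels = torch.randint(0, v, (n,), device="cuda")

# torch.nn.functional.cross_entropy with its default 'mean' reduction computes the same
# mean of per-row log-sum-exp losses that the generated kernel plus losses.mean() produces.
expected = torch.nn.functional.cross_entropy(logits, labels)
actual = cross_entropy(logits, labels)  # the generated host wrapper above
torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4)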