@@ -459,6 +459,57 @@ def _concat2d_dim1_make_precompiler(x: torch.Tensor, y: torch.Tensor):
459
459
from helion.runtime.precompile_shim import make_precompiler
460
460
return make_precompiler(_concat2d_dim1_kernel)(x, out, y, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
461
461
462
+ --- assertExpectedJournal(TestExamples.test_cross_entropy)
463
+ from __future__ import annotations
464
+
465
+ import torch
466
+ import triton
467
+ import triton.language as tl
468
+ from torch._inductor.runtime.triton_helpers import math as tl_math
469
+
470
@triton.jit
def _cross_entropy_kernel(
    labels,
    base_indices,
    logits_flat,
    logits,
    losses,
    base_indices_stride_0,
    labels_stride_0,
    logits_stride_0,
    logits_stride_1,
    logits_flat_stride_0,
    losses_stride_0,
    v,
    _RDIM_SIZE_1: tl.constexpr,
):
    # Each program instance handles the single row selected by its index on
    # grid axis 0 and writes one value into `losses`:
    #     max(row) + log(sum(exp(row - max(row)))) - logits_flat[base + label]
    row = tl.program_id(0)
    row_idx = row + tl.zeros([1], tl.int32)

    # The column range is padded to the compile-time size _RDIM_SIZE_1;
    # lanes at or beyond `v` are masked out of the load and both reductions.
    col_idx = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
    col_mask = col_idx < v
    col_mask_2d = tl.broadcast_to(col_mask[None, :], [1, _RDIM_SIZE_1])

    # Gather the logit at the target position: the per-row entry of
    # `base_indices` plus the per-row entry of `labels` indexes `logits_flat`.
    label = tl.load(labels + row_idx * labels_stride_0, None)
    row_base = tl.load(base_indices + row_idx * base_indices_stride_0, None)
    flat_index = row_base + label
    target_logit = tl.load(logits_flat + flat_index * logits_flat_stride_0, None)

    # Load the whole row of logits; masked-off lanes read as 0.
    row_ptrs = logits + (
        row_idx[:, None] * logits_stride_0 + col_idx[None, :] * logits_stride_1
    )
    row_logits = tl.load(row_ptrs, col_mask[None, :], other=0)

    # Row maximum; padded lanes are forced to -inf so they cannot win.
    logits_for_max = tl.where(col_mask_2d, row_logits, float('-inf'))
    row_max = tl.reshape(tl.max(logits_for_max, 1), [1, 1])

    # Sum of exponentials of the max-shifted logits; padded lanes contribute 0.
    shifted = row_logits - row_max
    exp_shifted = tl_math.exp(shifted)
    exp_for_sum = tl.where(col_mask_2d, exp_shifted, 0)
    exp_sum = tl.reshape(tl.sum(exp_for_sum, 1), [1, 1])

    # Collapse the [1, 1] reduction results to shape [1] before combining.
    row_max_1d = tl.reshape(row_max, [1])
    exp_sum_1d = tl.reshape(exp_sum, [1])
    log_sum_exp = row_max_1d + tl_math.log(exp_sum_1d)
    loss = log_sum_exp - target_logit

    tl.store(losses + row_idx * losses_stride_0, loss, None)
494
+
495
def cross_entropy(logits: torch.Tensor, labels: torch.Tensor):
    """Launch ``_cross_entropy_kernel`` over the rows of ``logits``.

    ``logits`` must unpack into exactly two dimensions ``(n, v)``. One kernel
    program is launched per row, each writing one entry of an ``n``-element
    buffer; the mean of that buffer is returned.

    Args:
        logits: 2-D tensor of shape ``(n, v)``.
        labels: Tensor indexed once per row by the kernel; presumably the
            target class index for each row (confirm against callers).

    Returns:
        The mean of the ``n`` values written by the kernel.
    """
    n, v = logits.shape
    per_row_loss = torch.zeros([n], dtype=logits.dtype, device=logits.device)
    # Offset of the first element of each row inside the flattened logits.
    base_indices = torch.arange(n, device=logits.device) * v
    logits_flat = logits.view(-1)
    # The kernel's column range is padded up to a power of two.
    _RDIM_SIZE_1 = triton.next_power_of_2(v)

    grid = (n,)
    _cross_entropy_kernel[grid](
        labels,
        base_indices,
        logits_flat,
        logits,
        per_row_loss,
        base_indices.stride(0),
        labels.stride(0),
        logits.stride(0),
        logits.stride(1),
        logits_flat.stride(0),
        per_row_loss.stride(0),
        v,
        _RDIM_SIZE_1,
        num_warps=4,
        num_stages=3,
    )
    return per_row_loss.mean()
503
+
504
def _cross_entropy_make_precompiler(logits: torch.Tensor, labels: torch.Tensor):
    """Prepare the same arguments as ``cross_entropy`` for the precompile shim.

    Rather than launching ``_cross_entropy_kernel`` on a grid, the argument
    list is handed to the callable produced by
    ``make_precompiler(_cross_entropy_kernel)`` and that call's result is
    returned.

    Args:
        logits: 2-D tensor of shape ``(n, v)``.
        labels: Tensor passed through to the kernel as its first argument.

    Returns:
        Whatever ``make_precompiler(_cross_entropy_kernel)(...)`` returns.
    """
    n, v = logits.shape
    per_row_loss = torch.zeros([n], dtype=logits.dtype, device=logits.device)
    # Offset of the first element of each row inside the flattened logits.
    base_indices = torch.arange(n, device=logits.device) * v
    logits_flat = logits.view(-1)
    _RDIM_SIZE_1 = triton.next_power_of_2(v)

    from helion.runtime.precompile_shim import make_precompiler

    precompile = make_precompiler(_cross_entropy_kernel)
    return precompile(
        labels,
        base_indices,
        logits_flat,
        logits,
        per_row_loss,
        base_indices.stride(0),
        labels.stride(0),
        logits.stride(0),
        logits.stride(1),
        logits_flat.stride(0),
        per_row_loss.stride(0),
        v,
        _RDIM_SIZE_1,
        num_warps=4,
        num_stages=3,
    )
512
+
462
513
--- assertExpectedJournal(TestExamples.test_embedding_block_ptr)
463
514
from __future__ import annotations
464
515
0 commit comments