@@ -469,6 +469,50 @@ def concat2d_dim1(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launch
    _launcher(_concat2d_dim1_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),), x, out, y, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
    return out
+ --- assertExpectedJournal(TestExamples.test_cross_entropy)
+ from __future__ import annotations
+
+ import torch
+ import triton
+ import triton.language as tl
+ from torch._inductor.runtime.triton_helpers import math as tl_math
+ from helion.runtime import default_launcher as _default_launcher
+
+ @triton.jit
+ def _cross_entropy_kernel(labels, logits_flat, logits, losses, labels_stride_0, logits_stride_0, logits_stride_1, logits_flat_stride_0, losses_stride_0, v, _RDIM_SIZE_1: tl.constexpr):
+     pid_0 = tl.program_id(0)
+     offset_0 = pid_0
+     indices_0 = offset_0 + tl.zeros([1], tl.int32)
+     indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+     mask_1 = indices_1 < v
+     labels_tile = tl.load(labels + indices_0 * labels_stride_0, None)
+     v_0 = v.to(tl.int32)
+     v_1 = indices_0 * v_0
+     v_2 = v_1.to(tl.int64)
+     v_3 = v_2 + labels_tile
+     logits_at_target = tl.load(logits_flat + v_3 * logits_flat_stride_0, None)
+     logits_rows = tl.load(logits + (indices_0[:, None] * logits_stride_0 + indices_1[None, :] * logits_stride_1), mask_1[None, :], other=0)
+     _mask_to = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), logits_rows, float('-inf'))
+     max_logits = tl.reshape(tl.max(_mask_to, 1), [1, 1])
+     v_4 = logits_rows - max_logits
+     v_5 = tl_math.exp(v_4)
+     _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), v_5, 0)
+     sum_exp = tl.reshape(tl.sum(_mask_to_1, 1), [1, 1])
+     squeeze = tl.reshape(max_logits, [1])
+     squeeze_1 = tl.reshape(sum_exp, [1])
+     v_6 = tl_math.log(squeeze_1)
+     v_7 = squeeze + v_6
+     v_8 = v_7 - logits_at_target
+     tl.store(losses + indices_0 * losses_stride_0, v_8, None)
+
+ def cross_entropy(logits: torch.Tensor, labels: torch.Tensor, *, _launcher=_default_launcher):
+     n, v = logits.shape
+     losses = torch.zeros([n], dtype=logits.dtype, device=logits.device)
+     logits_flat = logits.view(-1)
+     _RDIM_SIZE_1 = triton.next_power_of_2(v)
+     _launcher(_cross_entropy_kernel, (n,), labels, logits_flat, logits, losses, labels.stride(0), logits.stride(0), logits.stride(1), logits_flat.stride(0), losses.stride(0), v, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+     return losses.mean()
+
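Note (not part of the journal): the generated kernel computes, per row, logsumexp(logits) - logits[label], and the host wrapper returns the mean over rows, so it should agree with PyTorch's reference cross entropy. A minimal sketch of such a check, assuming a CUDA device and hypothetical sizes:

import torch

n, v = 128, 1000  # hypothetical batch and vocabulary sizes
logits = torch.randn(n, v, device="cuda", dtype=torch.float32)
labels = torch.randint(0, v, (n,), device="cuda")
# cross_entropy is the host wrapper emitted above; F.cross_entropy is the reference.
torch.testing.assert_close(
    cross_entropy(logits, labels),
    torch.nn.functional.cross_entropy(logits, labels),
    rtol=1e-4,
    atol=1e-4,
)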
--- assertExpectedJournal(TestExamples.test_embedding_block_ptr)
from __future__ import annotations
@@ -530,6 +574,94 @@ def embedding(x: torch.Tensor, weight: torch.Tensor, *, _launcher=_default_launc
    _launcher(_embedding_kernel, (x_flat.size(0) * triton.cdiv(embedding_dim, _BLOCK_SIZE_1),), x_flat, weight, out, x_flat.size(0), out.stride(0), out.stride(1), weight.stride(0), weight.stride(1), x_flat.stride(0), embedding_dim, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
    return out.view(*x.size(), embedding_dim)
+ --- assertExpectedJournal(TestExamples.test_fp8_attention)
+ from __future__ import annotations
+
+ import math
+ import torch
+ import triton
+ import triton.language as tl
+ from torch._inductor.runtime import triton_helpers
+ from torch._inductor.runtime.triton_compat import libdevice
+
+ @triton.jit
+ def _fp8_attention_kernel_kernel(q, k, v, out, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
+     pid_0 = tl.program_id(0)
+     offset_0 = pid_0
+     indices_5 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+     for offset_4 in tl.range(0, 256, _BLOCK_SIZE_1):
+         indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+         m_i = tl.full([_BLOCK_SIZE_1], float('-inf'), tl.float32)
+         l_i = tl.full([_BLOCK_SIZE_1], 0.0, tl.float32)
+         acc = tl.full([_BLOCK_SIZE_1, 64], 0.0, tl.float32)
+         q_tile = tl.load(q + (offset_0 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), None)
+         for offset_2 in tl.range(0, 256, _BLOCK_SIZE_3):
+             indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+             q_tile_copy = q_tile
+             m_i_copy = m_i
+             l_i_copy = l_i
+             acc_copy = acc
+             q_tile_copy_0 = q_tile_copy
+             m_i_copy_0 = m_i_copy
+             l_i_copy_0 = l_i_copy
+             acc_copy_0 = acc_copy
+             k_tile = tl.load(k + (offset_0 * 16384 + indices_2[:, None] * 64 + indices_5[None, :] * 1), None)
+             k_tile_t = tl.permute(k_tile, [1, 0])
+             mm = tl.dot(q_tile_copy_0, k_tile_t)
+             v_0 = mm.to(tl.float32)
+             v_1 = 0.18033688
+             v_2 = v_0 * v_1
+             qk_max = tl.max(v_2, 1)
+             v_3 = triton_helpers.maximum(m_i_copy_0, qk_max)
+             subscript = v_3[:, None]
+             v_4 = v_2 - subscript
+             v_5 = libdevice.exp2(v_4)
+             l_ij = tl.sum(v_5, 1)
+             v_6 = m_i_copy_0 - v_3
+             v_7 = libdevice.exp2(v_6)
+             v_8 = l_i_copy_0 * v_7
+             l_i = v_8 + l_ij
+             subscript_1 = v_7[:, None]
+             v_10 = acc_copy_0 * subscript_1
+             v_tile = tl.load(v + (offset_0 * 16384 + indices_5[:, None] * 1 + indices_2[None, :] * 64), None)
+             v_11 = v_5.to(tl.float8e5)
+             v_t = tl.permute(v_tile, [1, 0])
+             mm_1 = tl.dot(v_11, v_t)
+             v_12 = mm_1.to(tl.float32)
+             acc = v_10 + v_12
+             m_i = v_3
+         subscript_2 = l_i[:, None]
+         v_14 = acc / subscript_2
+         tl.store(out + (offset_0 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), v_14, None)
+
+ def fp8_attention_kernel(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+     """FP8 attention kernel processing batch*heads in parallel."""
+     batch_heads = q.size(0)
+     seq_len = q.size(1)
+     head_dim = q.size(2)
+     out = torch.empty([batch_heads, seq_len, head_dim], dtype=torch.float32, device=q.device)
+     sm_scale = 1.0 / math.sqrt(float(head_dim))
+     sm_scale = sm_scale * 1.44269504
+     _RDIM_SIZE_2 = 64
+     _BLOCK_SIZE_1 = 64
+     _BLOCK_SIZE_3 = 64
+     _fp8_attention_kernel_kernel[8,](q, k, v, out, _RDIM_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+     return out
+
+ def _fp8_attention_kernel_make_precompiler(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+     """FP8 attention kernel processing batch*heads in parallel."""
+     batch_heads = q.size(0)
+     seq_len = q.size(1)
+     head_dim = q.size(2)
+     out = torch.empty([batch_heads, seq_len, head_dim], dtype=torch.float32, device=q.device)
+     sm_scale = 1.0 / math.sqrt(float(head_dim))
+     sm_scale = sm_scale * 1.44269504
+     _RDIM_SIZE_2 = 64
+     _BLOCK_SIZE_1 = 64
+     _BLOCK_SIZE_3 = 64
+     from helion.runtime.precompile_shim import make_precompiler
+     return make_precompiler(_fp8_attention_kernel_kernel)(q, k, v, out, _RDIM_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+
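Note (not part of the journal): the literal 0.18033688 inside the kernel is the host-side sm_scale with log2(e) folded in, which lets the online softmax use libdevice.exp2 instead of exp. A small sketch of the arithmetic, assuming head_dim = 64 as in this test:

import math

head_dim = 64
sm_scale = (1.0 / math.sqrt(float(head_dim))) * 1.44269504  # log2(e) ~= 1.44269504
assert abs(sm_scale - 0.18033688) < 1e-9  # the constant baked into the kernel
# For any raw q.k score s, exp(s / sqrt(head_dim)) == 2 ** (s * sm_scale).
s = 3.7  # hypothetical score
assert abs(math.exp(s / math.sqrt(head_dim)) - 2.0 ** (s * sm_scale)) < 1e-8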
--- assertExpectedJournal(TestExamples.test_fp8_gemm)
from __future__ import annotations
@@ -762,6 +894,139 @@ def jagged_mean_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_
    _launcher(_jagged_mean_kernel_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_feature_counts, x_flat, out, out.stride(0), out.stride(1), x_feature_counts.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
    return out
+ --- assertExpectedJournal(TestExamples.test_jagged_mean_2d)
+ from __future__ import annotations
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ @triton.jit
+ def _jagged_mean_kernel_2d_kernel(x_offsets, x_feature_counts, x_flat, out, out_stride_0, out_stride_1, x_feature_counts_stride_0, x_flat_stride_0, x_offsets_stride_0, num_rows, max_M, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+     pid_0 = tl.program_id(0)
+     offset_0 = pid_0 * _BLOCK_SIZE_0
+     indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+     mask_0 = indices_0 < num_rows
+     starts = tl.load(x_offsets + indices_0 * x_offsets_stride_0, mask_0, other=0)
+     v_0 = tl.full([], 1, tl.int32)
+     v_1 = indices_0 + v_0
+     ends = tl.load(x_offsets + v_1 * x_offsets_stride_0, mask_0, other=0)
+     v_2 = ends - starts
+     _mask_to = tl.where(mask_0, v_2, -9223372036854775808)
+     max_nnz = tl.max(_mask_to, 0)
+     feature_counts = tl.load(x_feature_counts + indices_0 * x_feature_counts_stride_0, mask_0, other=0)
+     for offset_1 in tl.range(0, max_M.to(tl.int32), step=_BLOCK_SIZE_1):
+         indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+         mask_1 = indices_1 < max_M
+         feature_counts_copy = feature_counts
+         max_nnz_copy = max_nnz
+         starts_copy = starts
+         v_2_copy = v_2
+         feature_counts_copy_0 = feature_counts_copy
+         max_nnz_copy_0 = max_nnz_copy
+         starts_copy_0 = starts_copy
+         v_2_copy_0 = v_2_copy
+         subscript = feature_counts_copy_0[:, None]
+         v_3 = indices_1[None, :]
+         v_4 = v_3 < subscript
+         row_sums = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+         for offset_2 in tl.range(0, max_nnz_copy_0.to(tl.int32), step=_BLOCK_SIZE_2):
+             indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+             mask_2 = indices_2 < max_nnz_copy_0
+             starts_copy_0_copy = starts_copy_0
+             v_2_copy_0_copy = v_2_copy_0
+             v_4_copy = v_4
+             row_sums_copy = row_sums
+             starts_copy_0_copy_0 = starts_copy_0_copy
+             v_2_copy_0_copy_0 = v_2_copy_0_copy
+             v_4_copy_0 = v_4_copy
+             row_sums_copy_0 = row_sums_copy
+             subscript_1 = starts_copy_0_copy_0[:, None]
+             subscript_2 = indices_2[None, :]
+             v_5 = subscript_2.to(tl.int64)
+             v_6 = subscript_1 + v_5
+             subscript_3 = v_6[:, :, None]
+             v_7 = subscript_3 * max_M
+             subscript_4 = indices_1[None, None, :]
+             v_8 = subscript_4.to(tl.int64)
+             v_9 = v_7 + v_8
+             subscript_5 = indices_2[None, :]
+             subscript_6 = v_2_copy_0_copy_0[:, None]
+             v_10 = subscript_5.to(tl.int64)
+             v_11 = v_10 < subscript_6
+             subscript_7 = v_11[:, :, None]
+             subscript_8 = v_4_copy_0[:, None, :]
+             v_12 = subscript_7 & subscript_8
+             x_slice = tl.load(x_flat + v_9 * x_flat_stride_0, mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :] & v_12, other=0)
+             sum_1 = tl.sum(x_slice, 1)
+             row_sums = row_sums_copy_0 + sum_1
+         v_14 = v_2_copy_0.to(tl.float32)
+         nnz_expanded = v_14[:, None]
+         v_15 = 0.0
+         v_16 = nnz_expanded > v_15
+         v_17 = row_sums / nnz_expanded
+         v_18 = 0.0
+         v_19 = v_18[None, None]
+         v_20 = tl.where(v_16, v_17, v_19)
+         v_21 = 0.0
+         v_22 = v_21[None, None]
+         v_23 = tl.where(v_4, v_20, v_22)
+         tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_23, mask_0[:, None] & mask_1[None, :])
+
+ def jagged_mean_kernel_2d(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_counts: torch.Tensor, max_M_tensor: torch.Tensor):
+     """
+     Compute the mean of each row in a 2D jagged tensor with variable features per row.
+
+     Args
+     ----
+     x_data : 2-D tensor of shape (total_elements, max_M) holding all elements.
+     x_offsets : (num_rows + 1) tensor. Row i is the slice
+                 x_data[x_offsets[i] : x_offsets[i+1], :].
+     x_feature_counts: (num_rows) tensor. Number of valid features for each row.
+     max_M_tensor : Dummy tensor whose numel() gives max number of features.
+
+     Returns
+     -------
+     result : 2-D tensor of shape (num_rows, max_M) containing the mean of each row.
+              Invalid features (beyond x_feature_counts[i]) are set to 0.
+     """
+     num_rows = x_offsets.size(0) - 1
+     max_M = max_M_tensor.numel()
+     out = torch.zeros([num_rows, max_M], dtype=x_data.dtype, device=x_data.device)
+     x_flat = x_data.view(-1)
+     _BLOCK_SIZE_0 = 16
+     _BLOCK_SIZE_1 = 8
+     _BLOCK_SIZE_2 = 16
+     _jagged_mean_kernel_2d_kernel[triton.cdiv(num_rows, _BLOCK_SIZE_0),](x_offsets, x_feature_counts, x_flat, out, out.stride(0), out.stride(1), x_feature_counts.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+     return out
+
+ def _jagged_mean_kernel_2d_make_precompiler(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_counts: torch.Tensor, max_M_tensor: torch.Tensor):
+     """
+     Compute the mean of each row in a 2D jagged tensor with variable features per row.
+
+     Args
+     ----
+     x_data : 2-D tensor of shape (total_elements, max_M) holding all elements.
+     x_offsets : (num_rows + 1) tensor. Row i is the slice
+                 x_data[x_offsets[i] : x_offsets[i+1], :].
+     x_feature_counts: (num_rows) tensor. Number of valid features for each row.
+     max_M_tensor : Dummy tensor whose numel() gives max number of features.
+
+     Returns
+     -------
+     result : 2-D tensor of shape (num_rows, max_M) containing the mean of each row.
+              Invalid features (beyond x_feature_counts[i]) are set to 0.
+     """
+     num_rows = x_offsets.size(0) - 1
+     max_M = max_M_tensor.numel()
+     out = torch.zeros([num_rows, max_M], dtype=x_data.dtype, device=x_data.device)
+     x_flat = x_data.view(-1)
+     _BLOCK_SIZE_0 = 16
+     _BLOCK_SIZE_1 = 8
+     _BLOCK_SIZE_2 = 16
+     from helion.runtime.precompile_shim import make_precompiler
+     return make_precompiler(_jagged_mean_kernel_2d_kernel)(x_offsets, x_feature_counts, x_flat, out, out.stride(0), out.stride(1), x_feature_counts.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+
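Note (not part of the journal): a minimal usage sketch for the host wrapper above, with hypothetical shapes and a pure-PyTorch reference, assuming a CUDA device:

import torch

max_M = 3
x_offsets = torch.tensor([0, 2, 5], device="cuda")       # two jagged rows: 2 and 3 stored slices
x_feature_counts = torch.tensor([2, 3], device="cuda")   # valid features per row
x_data = torch.arange(5 * max_M, dtype=torch.float32, device="cuda").reshape(5, max_M)
max_M_tensor = torch.empty(max_M, device="cuda")          # only numel() is inspected

result = jagged_mean_kernel_2d(x_data, x_offsets, x_feature_counts, max_M_tensor)

# Reference: mean over each row's slice, with features beyond the count zeroed out.
expected = torch.zeros(2, max_M, device="cuda")
for i in range(2):
    count = int(x_feature_counts[i])
    rows = x_data[int(x_offsets[i]):int(x_offsets[i + 1])]
    expected[i, :count] = rows.mean(0)[:count]
torch.testing.assert_close(result, expected)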
--- assertExpectedJournal(TestExamples.test_matmul)
from __future__ import annotations