@@ -137,14 +137,9 @@ from torch._inductor.runtime.triton_compat import libdevice
 
 @triton.jit
 def _attention_kernel(q_view, k_view, v_view, out, q_in_size_1, k_view_stride_0, k_view_stride_1, k_view_stride_2, out_stride_0, out_stride_1, out_stride_2, q_view_stride_0, q_view_stride_1, q_view_stride_2, v_view_stride_0, v_view_stride_1, v_view_stride_2, m_dim, n_dim, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
-    num_pid_m = q_in_size_1
-    num_pid_n = tl.cdiv(m_dim, _BLOCK_SIZE_1)
-    num_pid_in_group = 2 * num_pid_n
-    group_id = tl.program_id(0) // num_pid_in_group
-    first_pid_m = group_id * 2
-    group_size_m = min(num_pid_m - first_pid_m, 2)
-    pid_0 = first_pid_m + tl.program_id(0) % num_pid_in_group % group_size_m
-    pid_1 = tl.program_id(0) % num_pid_in_group // group_size_m
+    num_blocks_0 = q_in_size_1
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
     offset_0 = pid_0
     indices_0 = offset_0 + tl.zeros([1], tl.int32)
     offset_1 = pid_1 * _BLOCK_SIZE_1
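
The removed lines are Triton's grouped ("swizzled") program-id schedule, which walks query blocks in small groups, intended to improve L2 locality; the added lines decode a flat 1D grid row-major instead, with pid_0 sweeping batch-heads fastest. A minimal pure-Python sketch of the new decode (grid sizes are hypothetical):

# Each of the num_blocks_0 * num_blocks_1 programs recovers its 2D coordinates
# from the flat program id with one mod/div pair; pid_0 varies fastest.
num_blocks_0 = 8   # stand-in for q_in_size_1 (batch * heads)
num_blocks_1 = 4   # stand-in for cdiv(m_dim, _BLOCK_SIZE_1)
for pid in range(num_blocks_0 * num_blocks_1):
    pid_0 = pid % num_blocks_0
    pid_1 = pid // num_blocks_0
    assert 0 <= pid_0 < num_blocks_0 and 0 <= pid_1 < num_blocks_1
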
@@ -204,10 +199,10 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
     sm_scale = 1.0 / math.sqrt(head_dim)
-    _BLOCK_SIZE_1 = 128
+    _BLOCK_SIZE_1 = 32
     _RDIM_SIZE_2 = 64
-    _BLOCK_SIZE_3 = 16
-    _attention_kernel[q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),](q_view, k_view, v_view, out, q_in.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_3, num_warps=2, num_stages=3)
+    _BLOCK_SIZE_3 = 32
+    _attention_kernel[q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),](q_view, k_view, v_view, out, q_in.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
     return out.view(q_in.size())
 
 def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
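
With the smaller _BLOCK_SIZE_1 the unchanged launch expression q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1) yields four times as many programs per head. A toy calculation with hypothetical shapes:

# cdiv is the same round-up division triton.cdiv performs.
def cdiv(a, b):
    return -(-a // b)

heads, m_dim = 16, 1024
assert heads * cdiv(m_dim, 128) == 128  # old config: 8 query blocks per head
assert heads * cdiv(m_dim, 32) == 512   # new config: 32 query blocks per head
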
@@ -221,11 +216,11 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
     sm_scale = 1.0 / math.sqrt(head_dim)
-    _BLOCK_SIZE_1 = 128
+    _BLOCK_SIZE_1 = 32
     _RDIM_SIZE_2 = 64
-    _BLOCK_SIZE_3 = 16
+    _BLOCK_SIZE_3 = 32
     from helion.runtime.precompile_shim import make_precompiler
-    return make_precompiler(_attention_kernel)(q_view, k_view, v_view, out, q_in.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_3, num_warps=2, num_stages=3)
+    return make_precompiler(_attention_kernel)(q_view, k_view, v_view, out, q_in.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
 
 --- assertExpectedJournal(TestExamples.test_attention_pointer)
 from __future__ import annotations
@@ -381,16 +376,16 @@ import triton
 import triton.language as tl
 
 @triton.jit
-def _concat2d_dim1_kernel(out, x, y, out_size_1, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
-    num_blocks_0 = tl.cdiv(out_size_1, _BLOCK_SIZE_1)
+def _concat2d_dim1_kernel(x, out, y, out_size_1, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
     pid_0 = tl.program_id(0) % num_blocks_0
     pid_1 = tl.program_id(0) // num_blocks_0
-    offset_1 = pid_0 * _BLOCK_SIZE_1
-    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
-    mask_1 = indices_1 < out_size_1
-    offset_0 = pid_1 * _BLOCK_SIZE_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
     indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
     mask_0 = indices_0 < x_size_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < out_size_1
     v_0 = x_size_1.to(tl.int32)
     v_1 = indices_1 < v_0
     subscript = v_1[None, :]
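
Both versions build a per-axis index vector and bounds mask, then combine them into a 2D tile mask for the loads and stores; the change only swaps which axis the fast program id covers and reorders the block-size arguments to match. A NumPy stand-in for the pattern (sizes are hypothetical):

import numpy as np

x_size_0, out_size_1 = 10, 20          # logical bounds of the two axes
_BLOCK_SIZE_0, _BLOCK_SIZE_1 = 4, 8    # tile shape
offset_0, offset_1 = 8, 16             # this tile's origin
indices_0 = offset_0 + np.arange(_BLOCK_SIZE_0)
indices_1 = offset_1 + np.arange(_BLOCK_SIZE_1)
mask_0 = indices_0 < x_size_0          # in-bounds rows
mask_1 = indices_1 < out_size_1        # in-bounds columns
tile_mask = mask_0[:, None] & mask_1[None, :]   # 2D mask, as in the kernel
assert tile_mask.sum() == 2 * 4        # 2 valid rows x 4 valid columns
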
@@ -410,18 +405,18 @@ def _concat2d_dim1_kernel(out, x, y, out_size_1, x_size_0, x_size_1, out_stride_
 def concat2d_dim1(x: torch.Tensor, y: torch.Tensor):
     assert x.size(0) == y.size(0)
     out = torch.empty([x.size(0), x.size(1) + y.size(1)], dtype=x.dtype, device=x.device)
-    _BLOCK_SIZE_1 = 1024
-    _BLOCK_SIZE_0 = 4
-    _concat2d_dim1_kernel[triton.cdiv(out.size(1), _BLOCK_SIZE_1) * triton.cdiv(x.size(0), _BLOCK_SIZE_0),](out, x, y, out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=2, num_stages=3)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _concat2d_dim1_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),](x, out, y, out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out
 
 def _concat2d_dim1_make_precompiler(x: torch.Tensor, y: torch.Tensor):
     assert x.size(0) == y.size(0)
     out = torch.empty([x.size(0), x.size(1) + y.size(1)], dtype=x.dtype, device=x.device)
-    _BLOCK_SIZE_1 = 1024
-    _BLOCK_SIZE_0 = 4
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
     from helion.runtime.precompile_shim import make_precompiler
-    return make_precompiler(_concat2d_dim1_kernel)(out, x, y, out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=2, num_stages=3)
+    return make_precompiler(_concat2d_dim1_kernel)(x, out, y, out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
 
 --- assertExpectedJournal(TestExamples.test_concat_block_ptr)
 from __future__ import annotations
@@ -634,16 +629,18 @@ import triton
 import triton.language as tl
 
 @triton.jit
-def _jagged_dense_add_2d_kernel(x_offsets, x_data, y, out, out_size_0, out_size_1, x_offsets_size_0, y_size_0, y_size_1, out_stride_0, out_stride_1, x_data_stride_0, x_offsets_stride_0, y_stride_0, y_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+def _jagged_dense_add_2d_kernel(x_offsets, x_data, y, out, y_size_1, out_stride_0, out_stride_1, x_data_stride_0, x_offsets_stride_0, y_stride_0, y_stride_1, num_rows, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
     pid_0 = tl.program_id(0)
-    offset_0 = pid_0
-    indices_0 = offset_0 + tl.zeros([1], tl.int32)
-    starts = tl.load(tl.make_block_ptr(x_offsets, [x_offsets_size_0], [x_offsets_stride_0], [offset_0], [1], [0]), boundary_check=[0], padding_option='zero')
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < num_rows
+    starts = tl.load(x_offsets + indices_0 * x_offsets_stride_0, mask_0, other=0)
     v_0 = tl.full([], 1, tl.int32)
     v_1 = indices_0 + v_0
-    ends = tl.load(x_offsets + v_1 * x_offsets_stride_0, None)
+    ends = tl.load(x_offsets + v_1 * x_offsets_stride_0, mask_0, other=0)
     v_2 = ends - starts
-    max_nnz = tl.max(v_2, 0)
+    _mask_to = tl.where(mask_0, v_2, -9223372036854775808)
+    max_nnz = tl.max(_mask_to, 0)
     for offset_1 in tl.range(0, max_nnz.to(tl.int32), _BLOCK_SIZE_1):
         indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
         mask_1 = indices_1 < max_nnz
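
Because one program now owns a block of rows rather than a single row, lanes past num_rows must not influence the tl.max over row lengths; the generated _mask_to swaps them for the int64 minimum (-9223372036854775808) before reducing. A pure-Python sketch with a hypothetical offsets array:

INT64_MIN = -(2**63)                  # the sentinel used above
offsets = [0, 3, 7, 12]               # jagged rows of length 3, 4, 5
num_rows = 3
lanes = [0, 1, 2, 3]                  # lane 3 falls past the last row
lengths = [offsets[i + 1] - offsets[i] if i < num_rows else INT64_MIN
           for i in lanes]            # dead lanes get the sentinel
assert max(lengths) == 5              # the masked-off lane cannot win the max
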
@@ -659,13 +656,15 @@ def _jagged_dense_add_2d_kernel(x_offsets, x_data, y, out, out_size_0, out_size_
         subscript_3 = v_2_copy_0[:, None]
         v_5 = subscript_2.to(tl.int64)
         v_6 = v_5 < subscript_3
-        x_slice = tl.load(x_data + v_4 * x_data_stride_0, mask_1[None, :] & v_6, other=0)
-        load_1 = tl.load(tl.make_block_ptr(y, [y_size_0, y_size_1], [y_stride_0, y_stride_1], [offset_0, offset_1], [1, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
+        x_slice = tl.load(x_data + v_4 * x_data_stride_0, mask_0[:, None] & mask_1[None, :] & v_6, other=0)
+        load_1 = tl.load(y + (indices_0[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
         v_7 = load_1 + x_slice
-        tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1], [out_stride_0, out_stride_1], [offset_0, offset_1], [1, _BLOCK_SIZE_1], [1, 0]), v_7, boundary_check=[0, 1])
+        tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_7, mask_0[:, None] & mask_1[None, :])
     for offset_2 in tl.range(max_nnz.to(tl.int32), y_size_1.to(tl.int32), _BLOCK_SIZE_2):
-        load = tl.load(tl.make_block_ptr(y, [y_size_0, y_size_1], [y_stride_0, y_stride_1], [offset_0, offset_2], [1, _BLOCK_SIZE_2], [1, 0]), boundary_check=[0, 1], padding_option='zero')
-        tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1], [out_stride_0, out_stride_1], [offset_0, offset_2], [1, _BLOCK_SIZE_2], [1, 0]), load, boundary_check=[0, 1])
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        mask_2 = indices_2 < y_size_1
+        load = tl.load(y + (indices_0[:, None] * y_stride_0 + indices_2[None, :] * y_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        tl.store(out + (indices_0[:, None] * out_stride_0 + indices_2[None, :] * out_stride_1), load, mask_0[:, None] & mask_2[None, :])
 
 def jagged_dense_add_2d(x_data: torch.Tensor, x_offsets: torch.Tensor, y: torch.Tensor):
     """
@@ -686,9 +685,10 @@ def jagged_dense_add_2d(x_data: torch.Tensor, x_offsets: torch.Tensor, y: torch.
     num_rows = y.size(0)
     assert x_offsets.size(0) == num_rows + 1
     out = torch.zeros_like(y)
-    _BLOCK_SIZE_1 = 512
-    _BLOCK_SIZE_2 = 512
-    _jagged_dense_add_2d_kernel[num_rows,](x_offsets, x_data, y, out, out.size(0), out.size(1), x_offsets.size(0), y.size(0), y.size(1), out.stride(0), out.stride(1), x_data.stride(0), x_offsets.stride(0), y.stride(0), y.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=8, num_stages=4)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    _jagged_dense_add_2d_kernel[triton.cdiv(num_rows, _BLOCK_SIZE_0),](x_offsets, x_data, y, out, y.size(1), out.stride(0), out.stride(1), x_data.stride(0), x_offsets.stride(0), y.stride(0), y.stride(1), num_rows, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out
 
 def _jagged_dense_add_2d_make_precompiler(x_data: torch.Tensor, x_offsets: torch.Tensor, y: torch.Tensor):
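
The launch grid changes accordingly: one program per block of _BLOCK_SIZE_0 rows instead of one per row. With a hypothetical row count:

import math

num_rows, _BLOCK_SIZE_0 = 1000, 16
assert math.ceil(num_rows / _BLOCK_SIZE_0) == 63  # triton.cdiv(1000, 16) programs vs. 1000 before
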
@@ -710,10 +710,11 @@ def _jagged_dense_add_2d_make_precompiler(x_data: torch.Tensor, x_offsets: torch
     num_rows = y.size(0)
     assert x_offsets.size(0) == num_rows + 1
     out = torch.zeros_like(y)
-    _BLOCK_SIZE_1 = 512
-    _BLOCK_SIZE_2 = 512
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
     from helion.runtime.precompile_shim import make_precompiler
-    return make_precompiler(_jagged_dense_add_2d_kernel)(x_offsets, x_data, y, out, out.size(0), out.size(1), x_offsets.size(0), y.size(0), y.size(1), out.stride(0), out.stride(1), x_data.stride(0), x_offsets.stride(0), y.stride(0), y.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=8, num_stages=4)
+    return make_precompiler(_jagged_dense_add_2d_kernel)(x_offsets, x_data, y, out, y.size(1), out.stride(0), out.stride(1), x_data.stride(0), x_offsets.stride(0), y.stride(0), y.stride(1), num_rows, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
 
 --- assertExpectedJournal(TestExamples.test_jagged_mean)
 from __future__ import annotations
@@ -1517,21 +1518,22 @@ from torch._inductor.runtime import triton_helpers
 from torch._inductor.runtime.triton_helpers import math as tl_math
 
 @triton.jit
-def _softmax_two_pass_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, n, _BLOCK_SIZE_1: tl.constexpr):
+def _softmax_two_pass_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, m, n, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
     pid_0 = tl.program_id(0)
-    offset_0 = pid_0
-    indices_0 = offset_0 + tl.zeros([1], tl.int32)
-    mi = tl.full([1], float('-inf'), tl.float32)
-    di = tl.full([1], 0.0, tl.float32)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < m
+    mi = tl.full([_BLOCK_SIZE_0], float('-inf'), tl.float32)
+    di = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
     for offset_2 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_1):
         indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
         mask_1 = indices_2 < n
         mi_copy = mi
         di_copy = di
         mi_copy_0 = mi_copy
         di_copy_0 = di_copy
-        values = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_1[None, :], other=0)
-        _mask_to = tl.where(tl.broadcast_to(mask_1[None, :], [1, _BLOCK_SIZE_1]), values, float('-inf'))
+        values = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+        _mask_to = tl.where(mask_0[:, None] & mask_1[None, :], values, float('-inf'))
         local_amax = tl.max(_mask_to, 1)
         v_0 = triton_helpers.maximum(mi_copy_0, local_amax)
         v_1 = mi_copy_0 - v_0
@@ -1540,7 +1542,7 @@ def _softmax_two_pass_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_s
         subscript = v_0[:, None]
         v_4 = values - subscript
         v_5 = tl_math.exp(v_4)
-        _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _BLOCK_SIZE_1]), v_5, 0)
+        _mask_to_1 = tl.where(mask_0[:, None] & mask_1[None, :], v_5, 0)
         sum_1 = tl.sum(_mask_to_1, 1)
         di = v_3 + sum_1
         mi = v_0
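
This first pass is the usual online-softmax recurrence: mi carries the running row maximum and di the running denominator, rescaled by exp(old_max - new_max) whenever a tile raises the maximum. A pure-Python sketch for one row (values and tile width are hypothetical):

import math

row = [1.0, 3.0, 2.0, 5.0, 4.0, 0.0]
mi, di = float('-inf'), 0.0
for start in range(0, len(row), 2):      # tiles of width 2
    tile = row[start:start + 2]
    new_mi = max(mi, max(tile))          # like triton_helpers.maximum(...)
    di = di * math.exp(mi - new_mi) + sum(math.exp(v - new_mi) for v in tile)
    mi = new_mi
# di now equals the full softmax denominator taken at the global maximum.
assert abs(di - sum(math.exp(v - mi) for v in row)) < 1e-12
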
@@ -1551,27 +1553,29 @@ def _softmax_two_pass_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_s
         di_copy_1 = di
         mi_copy_1_0 = mi_copy_1
         di_copy_1_0 = di_copy_1
-        values = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_2[None, :], other=0)
+        values = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
         subscript_1 = mi_copy_1_0[:, None]
         v_7 = values - subscript_1
         v_8 = tl_math.exp(v_7)
         subscript_2 = di_copy_1_0[:, None]
         v_9 = v_8 / subscript_2
-        tl.store(out + (indices_0[:, None] * out_stride_0 + indices_2[None, :] * out_stride_1), v_9, mask_2[None, :])
+        tl.store(out + (indices_0[:, None] * out_stride_0 + indices_2[None, :] * out_stride_1), v_9, mask_0[:, None] & mask_2[None, :])
 
 def softmax_two_pass(x: torch.Tensor):
     m, n = x.size()
     out = torch.empty_like(x)
-    _BLOCK_SIZE_1 = 128
-    _softmax_two_pass_kernel[m,](x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), n, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _softmax_two_pass_kernel[triton.cdiv(m, _BLOCK_SIZE_0),](x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out
 
 def _softmax_two_pass_make_precompiler(x: torch.Tensor):
     m, n = x.size()
     out = torch.empty_like(x)
-    _BLOCK_SIZE_1 = 128
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
     from helion.runtime.precompile_shim import make_precompiler
-    return make_precompiler(_softmax_two_pass_kernel)(x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), n, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return make_precompiler(_softmax_two_pass_kernel)(x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
 
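A quick way to sanity-check either version of the generated code (assumes a CUDA device and the softmax_two_pass defined above):

import torch

x = torch.randn(64, 300, device="cuda")
torch.testing.assert_close(softmax_two_pass(x), torch.softmax(x, dim=1), rtol=1e-4, atol=1e-4)
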
--- assertExpectedJournal(TestExamples.test_softmax_two_pass_block_ptr)
from __future__ import annotations