Commit 18a07a3

Remove default configs from examples (#295)
Default configs could lead to bad performance comparisons if people don't know how to trigger autotuning.
1 parent 148398a commit 18a07a3
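The pattern applied across the examples is sketched below using the softmax kernel from this diff: the pinned config (tuned on one specific GPU) is dropped from the decorator, so a config has to come from autotuning on the GPU the kernel actually runs on. This is a minimal sketch, not new functionality; softmax_pinned is a hypothetical name used only to show both forms side by side, and the kernel bodies are elided.

import torch

import helion
import helion.language as hl  # noqa: F401  (used by the real example bodies, not this sketch)


# Old form: a config tuned for one card was hard-coded in the decorator, so
# benchmarks on any other GPU silently reused someone else's tuning.
@helion.kernel(config={"block_size": 1})
def softmax_pinned(x: torch.Tensor) -> torch.Tensor:
    n, _m = x.size()
    out = torch.empty_like(x)
    ...  # body as in examples/softmax.py
    return out


# New form: no config is supplied, so the kernel must be tuned (or given a
# config explicitly) before its numbers are meaningful on a given card.
@helion.kernel()
def softmax(x: torch.Tensor) -> torch.Tensor:
    n, _m = x.size()
    out = torch.empty_like(x)
    ...  # body as in examples/softmax.py
    return out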

7 files changed: +68 −92 lines changed
examples/attention.py

Lines changed: 0 additions & 9 deletions
@@ -13,15 +13,6 @@
 
 
 @helion.kernel(
-    config=helion.Config(
-        # This config was autotuned on a 5090, it won't be fast for other cards
-        block_sizes=[128, 16],
-        loop_orders=[[0, 1]],
-        l2_groupings=[2],
-        num_warps=2,
-        num_stages=3,
-        indexing="pointer",
-    ),
     # Static shapes provides a speedup for attention
     static_shapes=True,
 )
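A config can still be pinned explicitly when that is what you want; the sketch below simply puts back the values deleted above (autotuned on a 5090), so it illustrates the mechanism rather than a recommendation for other cards. The kernel body is elided and unchanged from examples/attention.py.

import torch

import helion


@helion.kernel(
    # The values removed in this commit, autotuned on a 5090; on other GPUs,
    # prefer omitting config and letting autotuning choose one.
    config=helion.Config(
        block_sizes=[128, 16],
        loop_orders=[[0, 1]],
        l2_groupings=[2],
        num_warps=2,
        num_stages=3,
        indexing="pointer",
    ),
    # Static shapes provides a speedup for attention
    static_shapes=True,
)
def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
    ...  # body unchanged from examples/attention.py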

examples/concatenate.py

Lines changed: 1 addition & 3 deletions
@@ -7,9 +7,7 @@
 import helion.language as hl
 
 
-@helion.kernel(
-    config=helion.Config(block_size=[4, 1024], loop_order=[1, 0], num_warps=2)
-)
+@helion.kernel()
 def concat2d_dim1(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     assert x.size(0) == y.size(0)
     out = torch.empty(

examples/embedding.py

Lines changed: 1 addition & 5 deletions
@@ -7,11 +7,7 @@
 import helion.language as hl
 
 
-@helion.kernel(
-    config=helion.Config(
-        block_sizes=[512, 32], loop_order=[0, 1], num_warps=8, indexing="block_ptr"
-    )
-)
+@helion.kernel()
 def embedding(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
     x_flat = x.reshape(-1)  # collapse x into a single dimension
     _, embedding_dim = weight.size()

examples/jagged_dense_add.py

Lines changed: 1 addition & 5 deletions
@@ -20,11 +20,7 @@
 """
 
 
-@helion.kernel(
-    config=helion.Config(
-        block_sizes=[1, 512, 512], num_warps=8, num_stages=4, indexing="block_ptr"
-    )
-)
+@helion.kernel()
 def jagged_dense_add_2d(
     x_data: torch.Tensor, x_offsets: torch.Tensor, y: torch.Tensor
 ) -> torch.Tensor:

examples/softmax.py

Lines changed: 4 additions & 4 deletions
@@ -7,7 +7,7 @@
 import helion.language as hl
 
 
-@helion.kernel(config={"block_size": 1})
+@helion.kernel()
 def softmax(x: torch.Tensor) -> torch.Tensor:
     n, _m = x.size()
     out = torch.empty_like(x)
@@ -17,7 +17,7 @@ def softmax(x: torch.Tensor) -> torch.Tensor:
 
 
 # This generates the same code as the above, but avoids using the pytorch softmax decomposition
-@helion.kernel(config={"block_size": 1})
+@helion.kernel()
 def softmax_decomposed(x: torch.Tensor) -> torch.Tensor:
     n, _m = x.size()
     out = torch.empty_like(x)
@@ -31,7 +31,7 @@ def softmax_decomposed(x: torch.Tensor) -> torch.Tensor:
 
 
 # This optimization does softmax in fewer passes, but is less numerically stable
-@helion.kernel(config={"block_sizes": [1, 128]})
+@helion.kernel()
 def softmax_two_pass(x: torch.Tensor) -> torch.Tensor:
     m, n = x.size()
     out = torch.empty_like(x)
@@ -58,7 +58,7 @@ def check(m: int, n: int) -> None:
     x = torch.randn([m, n], device="cuda", dtype=torch.float16)
     kernels = {
         "helion simple": softmax,
-        "helion decomposed": softmax_decomposed,
+        # "helion decomposed": softmax_decomposed,
         "helion two pass": softmax_two_pass,
     }
     run_example(kernels, lambda x: torch.nn.functional.softmax(x, dim=1), (x,))
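For context, a hedged sketch of exercising the benchmark harness in the last hunk: it assumes it runs inside examples/softmax.py (so softmax, softmax_two_pass, and run_example are already in scope), and the (4096, 2560) shape is a hypothetical choice rather than one taken from the example.

import torch

# softmax_decomposed stays commented out, matching the diff above.
x = torch.randn([4096, 2560], device="cuda", dtype=torch.float16)
kernels = {
    "helion simple": softmax,
    "helion two pass": softmax_two_pass,
}
run_example(kernels, lambda x: torch.nn.functional.softmax(x, dim=1), (x,))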

examples/template_via_closure.py

Lines changed: 0 additions & 9 deletions
@@ -14,15 +14,6 @@
 
 
 @helion.kernel(
-    # This was tuned on a 5090 and likely isn't optimal for other cards
-    config=helion.Config(
-        block_sizes=[64, 128, 64],
-        loop_orders=[[0, 1]],
-        l2_groupings=[2],
-        num_warps=8,
-        num_stages=5,
-        indexing="pointer",
-    ),
     # static_shapes=True gives a performance boost for matmuls
     static_shapes=True,
 )

test/test_examples.expected

Lines changed: 61 additions & 57 deletions
@@ -137,14 +137,9 @@ from torch._inductor.runtime.triton_compat import libdevice
 
 @triton.jit
 def _attention_kernel(q_view, k_view, v_view, out, q_in_size_1, k_view_stride_0, k_view_stride_1, k_view_stride_2, out_stride_0, out_stride_1, out_stride_2, q_view_stride_0, q_view_stride_1, q_view_stride_2, v_view_stride_0, v_view_stride_1, v_view_stride_2, m_dim, n_dim, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
-    num_pid_m = q_in_size_1
-    num_pid_n = tl.cdiv(m_dim, _BLOCK_SIZE_1)
-    num_pid_in_group = 2 * num_pid_n
-    group_id = tl.program_id(0) // num_pid_in_group
-    first_pid_m = group_id * 2
-    group_size_m = min(num_pid_m - first_pid_m, 2)
-    pid_0 = first_pid_m + tl.program_id(0) % num_pid_in_group % group_size_m
-    pid_1 = tl.program_id(0) % num_pid_in_group // group_size_m
+    num_blocks_0 = q_in_size_1
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
     offset_0 = pid_0
     indices_0 = offset_0 + tl.zeros([1], tl.int32)
     offset_1 = pid_1 * _BLOCK_SIZE_1
@@ -204,10 +199,10 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
     sm_scale = 1.0 / math.sqrt(head_dim)
-    _BLOCK_SIZE_1 = 128
+    _BLOCK_SIZE_1 = 32
     _RDIM_SIZE_2 = 64
-    _BLOCK_SIZE_3 = 16
-    _attention_kernel[q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),](q_view, k_view, v_view, out, q_in.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_3, num_warps=2, num_stages=3)
+    _BLOCK_SIZE_3 = 32
+    _attention_kernel[q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),](q_view, k_view, v_view, out, q_in.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
     return out.view(q_in.size())
 
 def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
@@ -221,11 +216,11 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
     sm_scale = 1.0 / math.sqrt(head_dim)
-    _BLOCK_SIZE_1 = 128
+    _BLOCK_SIZE_1 = 32
     _RDIM_SIZE_2 = 64
-    _BLOCK_SIZE_3 = 16
+    _BLOCK_SIZE_3 = 32
     from helion.runtime.precompile_shim import make_precompiler
-    return make_precompiler(_attention_kernel)(q_view, k_view, v_view, out, q_in.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_3, num_warps=2, num_stages=3)
+    return make_precompiler(_attention_kernel)(q_view, k_view, v_view, out, q_in.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
 
 --- assertExpectedJournal(TestExamples.test_attention_pointer)
 from __future__ import annotations
@@ -381,16 +376,16 @@ import triton
 import triton.language as tl
 
 @triton.jit
-def _concat2d_dim1_kernel(out, x, y, out_size_1, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
-    num_blocks_0 = tl.cdiv(out_size_1, _BLOCK_SIZE_1)
+def _concat2d_dim1_kernel(x, out, y, out_size_1, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
     pid_0 = tl.program_id(0) % num_blocks_0
     pid_1 = tl.program_id(0) // num_blocks_0
-    offset_1 = pid_0 * _BLOCK_SIZE_1
-    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
-    mask_1 = indices_1 < out_size_1
-    offset_0 = pid_1 * _BLOCK_SIZE_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
     indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
     mask_0 = indices_0 < x_size_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < out_size_1
     v_0 = x_size_1.to(tl.int32)
     v_1 = indices_1 < v_0
     subscript = v_1[None, :]
@@ -410,18 +405,18 @@ def _concat2d_dim1_kernel(out, x, y, out_size_1, x_size_0, x_size_1, out_stride_
 def concat2d_dim1(x: torch.Tensor, y: torch.Tensor):
     assert x.size(0) == y.size(0)
     out = torch.empty([x.size(0), x.size(1) + y.size(1)], dtype=x.dtype, device=x.device)
-    _BLOCK_SIZE_1 = 1024
-    _BLOCK_SIZE_0 = 4
-    _concat2d_dim1_kernel[triton.cdiv(out.size(1), _BLOCK_SIZE_1) * triton.cdiv(x.size(0), _BLOCK_SIZE_0),](out, x, y, out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=2, num_stages=3)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _concat2d_dim1_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(out.size(1), _BLOCK_SIZE_1),](x, out, y, out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out
 
 def _concat2d_dim1_make_precompiler(x: torch.Tensor, y: torch.Tensor):
     assert x.size(0) == y.size(0)
     out = torch.empty([x.size(0), x.size(1) + y.size(1)], dtype=x.dtype, device=x.device)
-    _BLOCK_SIZE_1 = 1024
-    _BLOCK_SIZE_0 = 4
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
     from helion.runtime.precompile_shim import make_precompiler
-    return make_precompiler(_concat2d_dim1_kernel)(out, x, y, out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=2, num_stages=3)
+    return make_precompiler(_concat2d_dim1_kernel)(x, out, y, out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
 
 --- assertExpectedJournal(TestExamples.test_concat_block_ptr)
 from __future__ import annotations
@@ -634,16 +629,18 @@ import triton
 import triton.language as tl
 
 @triton.jit
-def _jagged_dense_add_2d_kernel(x_offsets, x_data, y, out, out_size_0, out_size_1, x_offsets_size_0, y_size_0, y_size_1, out_stride_0, out_stride_1, x_data_stride_0, x_offsets_stride_0, y_stride_0, y_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+def _jagged_dense_add_2d_kernel(x_offsets, x_data, y, out, y_size_1, out_stride_0, out_stride_1, x_data_stride_0, x_offsets_stride_0, y_stride_0, y_stride_1, num_rows, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
     pid_0 = tl.program_id(0)
-    offset_0 = pid_0
-    indices_0 = offset_0 + tl.zeros([1], tl.int32)
-    starts = tl.load(tl.make_block_ptr(x_offsets, [x_offsets_size_0], [x_offsets_stride_0], [offset_0], [1], [0]), boundary_check=[0], padding_option='zero')
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < num_rows
+    starts = tl.load(x_offsets + indices_0 * x_offsets_stride_0, mask_0, other=0)
     v_0 = tl.full([], 1, tl.int32)
     v_1 = indices_0 + v_0
-    ends = tl.load(x_offsets + v_1 * x_offsets_stride_0, None)
+    ends = tl.load(x_offsets + v_1 * x_offsets_stride_0, mask_0, other=0)
     v_2 = ends - starts
-    max_nnz = tl.max(v_2, 0)
+    _mask_to = tl.where(mask_0, v_2, -9223372036854775808)
+    max_nnz = tl.max(_mask_to, 0)
     for offset_1 in tl.range(0, max_nnz.to(tl.int32), _BLOCK_SIZE_1):
         indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
         mask_1 = indices_1 < max_nnz
@@ -659,13 +656,15 @@ def _jagged_dense_add_2d_kernel(x_offsets, x_data, y, out, out_size_0, out_size_
         subscript_3 = v_2_copy_0[:, None]
         v_5 = subscript_2.to(tl.int64)
         v_6 = v_5 < subscript_3
-        x_slice = tl.load(x_data + v_4 * x_data_stride_0, mask_1[None, :] & v_6, other=0)
-        load_1 = tl.load(tl.make_block_ptr(y, [y_size_0, y_size_1], [y_stride_0, y_stride_1], [offset_0, offset_1], [1, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
+        x_slice = tl.load(x_data + v_4 * x_data_stride_0, mask_0[:, None] & mask_1[None, :] & v_6, other=0)
+        load_1 = tl.load(y + (indices_0[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
         v_7 = load_1 + x_slice
-        tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1], [out_stride_0, out_stride_1], [offset_0, offset_1], [1, _BLOCK_SIZE_1], [1, 0]), v_7, boundary_check=[0, 1])
+        tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_7, mask_0[:, None] & mask_1[None, :])
     for offset_2 in tl.range(max_nnz.to(tl.int32), y_size_1.to(tl.int32), _BLOCK_SIZE_2):
-        load = tl.load(tl.make_block_ptr(y, [y_size_0, y_size_1], [y_stride_0, y_stride_1], [offset_0, offset_2], [1, _BLOCK_SIZE_2], [1, 0]), boundary_check=[0, 1], padding_option='zero')
-        tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1], [out_stride_0, out_stride_1], [offset_0, offset_2], [1, _BLOCK_SIZE_2], [1, 0]), load, boundary_check=[0, 1])
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        mask_2 = indices_2 < y_size_1
+        load = tl.load(y + (indices_0[:, None] * y_stride_0 + indices_2[None, :] * y_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        tl.store(out + (indices_0[:, None] * out_stride_0 + indices_2[None, :] * out_stride_1), load, mask_0[:, None] & mask_2[None, :])
 
 def jagged_dense_add_2d(x_data: torch.Tensor, x_offsets: torch.Tensor, y: torch.Tensor):
     """
@@ -686,9 +685,10 @@ def jagged_dense_add_2d(x_data: torch.Tensor, x_offsets: torch.Tensor, y: torch.
     num_rows = y.size(0)
     assert x_offsets.size(0) == num_rows + 1
     out = torch.zeros_like(y)
-    _BLOCK_SIZE_1 = 512
-    _BLOCK_SIZE_2 = 512
-    _jagged_dense_add_2d_kernel[num_rows,](x_offsets, x_data, y, out, out.size(0), out.size(1), x_offsets.size(0), y.size(0), y.size(1), out.stride(0), out.stride(1), x_data.stride(0), x_offsets.stride(0), y.stride(0), y.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=8, num_stages=4)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    _jagged_dense_add_2d_kernel[triton.cdiv(num_rows, _BLOCK_SIZE_0),](x_offsets, x_data, y, out, y.size(1), out.stride(0), out.stride(1), x_data.stride(0), x_offsets.stride(0), y.stride(0), y.stride(1), num_rows, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out
 
 def _jagged_dense_add_2d_make_precompiler(x_data: torch.Tensor, x_offsets: torch.Tensor, y: torch.Tensor):
@@ -710,10 +710,11 @@ def _jagged_dense_add_2d_make_precompiler(x_data: torch.Tensor, x_offsets: torch
     num_rows = y.size(0)
     assert x_offsets.size(0) == num_rows + 1
     out = torch.zeros_like(y)
-    _BLOCK_SIZE_1 = 512
-    _BLOCK_SIZE_2 = 512
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
     from helion.runtime.precompile_shim import make_precompiler
-    return make_precompiler(_jagged_dense_add_2d_kernel)(x_offsets, x_data, y, out, out.size(0), out.size(1), x_offsets.size(0), y.size(0), y.size(1), out.stride(0), out.stride(1), x_data.stride(0), x_offsets.stride(0), y.stride(0), y.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=8, num_stages=4)
+    return make_precompiler(_jagged_dense_add_2d_kernel)(x_offsets, x_data, y, out, y.size(1), out.stride(0), out.stride(1), x_data.stride(0), x_offsets.stride(0), y.stride(0), y.stride(1), num_rows, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
 
 --- assertExpectedJournal(TestExamples.test_jagged_mean)
 from __future__ import annotations
@@ -1517,21 +1518,22 @@ from torch._inductor.runtime import triton_helpers
 from torch._inductor.runtime.triton_helpers import math as tl_math
 
 @triton.jit
-def _softmax_two_pass_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, n, _BLOCK_SIZE_1: tl.constexpr):
+def _softmax_two_pass_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, m, n, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
     pid_0 = tl.program_id(0)
-    offset_0 = pid_0
-    indices_0 = offset_0 + tl.zeros([1], tl.int32)
-    mi = tl.full([1], float('-inf'), tl.float32)
-    di = tl.full([1], 0.0, tl.float32)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < m
+    mi = tl.full([_BLOCK_SIZE_0], float('-inf'), tl.float32)
+    di = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
     for offset_2 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_1):
         indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
         mask_1 = indices_2 < n
         mi_copy = mi
         di_copy = di
         mi_copy_0 = mi_copy
         di_copy_0 = di_copy
-        values = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_1[None, :], other=0)
-        _mask_to = tl.where(tl.broadcast_to(mask_1[None, :], [1, _BLOCK_SIZE_1]), values, float('-inf'))
+        values = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+        _mask_to = tl.where(mask_0[:, None] & mask_1[None, :], values, float('-inf'))
         local_amax = tl.max(_mask_to, 1)
         v_0 = triton_helpers.maximum(mi_copy_0, local_amax)
         v_1 = mi_copy_0 - v_0
@@ -1540,7 +1542,7 @@ def _softmax_two_pass_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_s
         subscript = v_0[:, None]
         v_4 = values - subscript
         v_5 = tl_math.exp(v_4)
-        _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _BLOCK_SIZE_1]), v_5, 0)
+        _mask_to_1 = tl.where(mask_0[:, None] & mask_1[None, :], v_5, 0)
         sum_1 = tl.sum(_mask_to_1, 1)
         di = v_3 + sum_1
         mi = v_0
@@ -1551,27 +1553,29 @@ def _softmax_two_pass_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_s
         di_copy_1 = di
         mi_copy_1_0 = mi_copy_1
        di_copy_1_0 = di_copy_1
-        values = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_2[None, :], other=0)
+        values = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
         subscript_1 = mi_copy_1_0[:, None]
         v_7 = values - subscript_1
         v_8 = tl_math.exp(v_7)
         subscript_2 = di_copy_1_0[:, None]
         v_9 = v_8 / subscript_2
-        tl.store(out + (indices_0[:, None] * out_stride_0 + indices_2[None, :] * out_stride_1), v_9, mask_2[None, :])
+        tl.store(out + (indices_0[:, None] * out_stride_0 + indices_2[None, :] * out_stride_1), v_9, mask_0[:, None] & mask_2[None, :])
 
 def softmax_two_pass(x: torch.Tensor):
     m, n = x.size()
     out = torch.empty_like(x)
-    _BLOCK_SIZE_1 = 128
-    _softmax_two_pass_kernel[m,](x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), n, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _softmax_two_pass_kernel[triton.cdiv(m, _BLOCK_SIZE_0),](x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out
 
 def _softmax_two_pass_make_precompiler(x: torch.Tensor):
     m, n = x.size()
     out = torch.empty_like(x)
-    _BLOCK_SIZE_1 = 128
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
     from helion.runtime.precompile_shim import make_precompiler
-    return make_precompiler(_softmax_two_pass_kernel)(x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), n, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return make_precompiler(_softmax_two_pass_kernel)(x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
 
 --- assertExpectedJournal(TestExamples.test_softmax_two_pass_block_ptr)
 from __future__ import annotations