Skip to content

Commit 80510e0

Browse files
authored
Re-enable unit test for moe_matmul_ogs example; skip in fbcode (#123)
1 parent 31acab6 commit 80510e0

File tree

2 files changed

+84
-76
lines changed

2 files changed

+84
-76
lines changed

examples/moe_matmul_ogs.py

Lines changed: 45 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -39,49 +39,51 @@ def moe_matmul_ogs(
3939
start = expert_token_offsets[e_idx] # Starting index in sorted token array
4040
num_tokens = expert_token_counts[e_idx] # Number of tokens for this expert
4141

42-
# Tile over tokens and output features for this expert
43-
for tile_t, tile_n in hl.tile([max_T_per_expert, N]):
44-
# Get local token offsets for this tile
45-
# (i.e. the tile's corresponding chunk in [0 .. max_T_per_expert-1] token range)
46-
local_token_offsets = tile_t.index # [BLOCK_T]
47-
48-
# Create mask for valid tokens (some tiles may be partially filled)
49-
token_valid = local_token_offsets < num_tokens # bool[BLOCK_T]
50-
51-
# For invalid tokens, use 0 as a dummy index (will be masked out later)
52-
local_token_offsets_valid = torch.where(
53-
token_valid,
54-
local_token_offsets,
55-
0,
56-
) # [BLOCK_T]
57-
58-
# Convert local offsets to global sorted indices
59-
expert_sorted_token_indices = (
60-
start + local_token_offsets_valid
61-
) # [1, BLOCK_T]
62-
63-
# Map sorted indices back to global original token positions
64-
expert_orig_token_indices = sorted_to_orig_token_idx[
65-
expert_sorted_token_indices.squeeze(0)
66-
] # [BLOCK_T]
67-
68-
acc = hl.zeros([tile_t, tile_n], dtype=torch.float32)
69-
70-
# Perform tiled matrix multiplication: A[tokens, :] @ W[expert, :, :]
71-
for tile_k in hl.tile(K):
72-
A_frag = A[expert_orig_token_indices, tile_k] # [BLOCK_T, BLOCK_K]
73-
W_frag = W[e_idx, tile_k, tile_n] # [BLOCK_K, BLOCK_N]
74-
acc = torch.addmm(acc, A_frag, W_frag)
75-
76-
# Write results back to output tensor, masking out invalid tokens
77-
block_T = acc.size(0)
78-
block_N = acc.size(1)
79-
existing_values = C[expert_orig_token_indices, tile_n]
80-
mask_2d = token_valid.view(block_T, 1).expand(block_T, block_N)
81-
# Write results only for valid tokens, preserve existing values for invalid ones
82-
C[expert_orig_token_indices, tile_n] = torch.where(
83-
mask_2d, acc.to(C.dtype), existing_values
84-
)
42+
# Skip experts with no assigned tokens
43+
if num_tokens != 0:
44+
# Tile over tokens and output features for this expert
45+
for tile_t, tile_n in hl.tile([max_T_per_expert, N]):
46+
# Get local token offsets for this tile
47+
# (i.e. the tile's corresponding chunk in [0 .. max_T_per_expert-1] token range)
48+
local_token_offsets = tile_t.index # [BLOCK_T]
49+
50+
# Create mask for valid tokens (some tiles may be partially filled)
51+
token_valid = local_token_offsets < num_tokens # bool[BLOCK_T]
52+
53+
# For invalid tokens, use 0 as a dummy index (will be masked out later)
54+
local_token_offsets_valid = torch.where(
55+
token_valid,
56+
local_token_offsets,
57+
0,
58+
) # [BLOCK_T]
59+
60+
# Convert local offsets to global sorted indices
61+
expert_sorted_token_indices = (
62+
start + local_token_offsets_valid
63+
) # [1, BLOCK_T]
64+
65+
# Map sorted indices back to global original token positions
66+
expert_orig_token_indices = sorted_to_orig_token_idx[
67+
expert_sorted_token_indices.squeeze(0)
68+
] # [BLOCK_T]
69+
70+
acc = hl.zeros([tile_t, tile_n], dtype=torch.float32)
71+
72+
# Perform tiled matrix multiplication: A[tokens, :] @ W[expert, :, :]
73+
for tile_k in hl.tile(K):
74+
A_frag = A[expert_orig_token_indices, tile_k] # [BLOCK_T, BLOCK_K]
75+
W_frag = W[e_idx, tile_k, tile_n] # [BLOCK_K, BLOCK_N]
76+
acc = torch.addmm(acc, A_frag, W_frag)
77+
78+
# Write results back to output tensor, masking out invalid tokens
79+
block_T = acc.size(0)
80+
block_N = acc.size(1)
81+
existing_values = C[expert_orig_token_indices, tile_n]
82+
mask_2d = token_valid.view(block_T, 1).expand(block_T, block_N)
83+
# Write results only for valid tokens, preserve existing values for invalid ones
84+
C[expert_orig_token_indices, tile_n] = torch.where(
85+
mask_2d, acc.to(C.dtype), existing_values
86+
)
8587

8688
return C
8789

test/test_examples.py

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from expecttest import TestCase
77
from packaging import version
88
import torch
9+
from torch._environment import is_fbcode
910

1011
from helion._testing import DEVICE
1112
from helion._testing import code_and_output
@@ -1429,11 +1430,11 @@ def _jagged_dense_add_2d_make_precompiler(x_data: torch.Tensor, x_offsets: torch
14291430
return make_precompiler(_jagged_dense_add_2d_kernel)(x_offsets, x_data, y, out, out.size(0), out.size(1), x_offsets.size(0), y.size(0), y.size(1), out.stride(0), out.stride(1), x_data.stride(0), x_offsets.stride(0), y.stride(0), y.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=8, num_stages=4)""",
14301431
)
14311432

1432-
@unittest.skip("TODO(yf225): fix occasional numerical error")
14331433
@unittest.skipIf(
14341434
torch.cuda.get_device_capability(0) < (9, 0),
14351435
"Triton internal error on RTX 3090",
14361436
)
1437+
@unittest.skipIf(is_fbcode(), "Triton internal error on fbcode Triton pin")
14371438
def test_moe_matmul_ogs(self):
14381439
mod = import_path(examples_dir / "moe_matmul_ogs.py")
14391440

@@ -1470,38 +1471,43 @@ def _moe_matmul_ogs_kernel(expert_token_offsets, expert_token_counts, sorted_to_
14701471
indices_0 = offset_0 + tl.zeros([1], tl.int32)
14711472
start = tl.load(expert_token_offsets + indices_0 * expert_token_offsets_stride_0, None)
14721473
num_tokens = tl.load(expert_token_counts + indices_0 * expert_token_counts_stride_0, None)
1473-
for offset_1 in range(0, max_T_per_expert.to(tl.int32), _BLOCK_SIZE_1):
1474-
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
1475-
mask_1 = indices_1 < max_T_per_expert
1476-
for offset_2 in range(0, N.to(tl.int32), _BLOCK_SIZE_2):
1477-
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
1478-
mask_2 = indices_2 < N
1479-
num_tokens_copy = num_tokens
1480-
start_copy = start
1481-
v_0 = num_tokens_copy[None]
1482-
v_1 = indices_1 < v_0
1483-
v_2 = tl.full([], 0, tl.int32)
1484-
v_3 = v_2[None]
1485-
v_4 = tl.where(v_1, indices_1, v_3)
1486-
v_5 = start_copy[None]
1487-
v_6 = v_5 + v_4
1488-
squeeze = tl.reshape(v_6, [_BLOCK_SIZE_1])
1489-
expert_orig_token_indices = tl.load(sorted_to_orig_token_idx + squeeze * sorted_to_orig_token_idx_stride_0, mask_1, other=0)
1490-
acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
1491-
for offset_3 in range(0, K.to(tl.int32), _BLOCK_SIZE_3):
1492-
indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
1493-
mask_3 = indices_3 < K
1494-
expert_orig_token_indices_copy = expert_orig_token_indices
1495-
acc_copy = acc
1496-
A_frag = tl.load(A + (expert_orig_token_indices_copy[:, None] * A_stride_0 + indices_3[None, :] * A_stride_1), mask_1[:, None] & mask_3[None, :], other=0)
1497-
W_frag = tl.load(W + (indices_0 * W_stride_0 + indices_3[:, None] * W_stride_1 + indices_2[None, :] * W_stride_2), mask_3[:, None] & mask_2[None, :], other=0)
1498-
acc = tl.dot(A_frag, W_frag, acc=acc_copy, input_precision='tf32')
1499-
existing_values = tl.load(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
1500-
view = tl.reshape(v_1, [_BLOCK_SIZE_1, 1])
1501-
mask_2d = tl.broadcast_to(view, [_BLOCK_SIZE_1, _BLOCK_SIZE_2])
1502-
v_7 = acc.to(tl.float16)
1503-
v_8 = tl.where(mask_2d, v_7, existing_values)
1504-
tl.store(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_8, mask_1[:, None] & mask_2[None, :])
1474+
v_0 = tl.full([], 0, tl.int32)
1475+
v_1 = num_tokens != v_0
1476+
if v_1:
1477+
num_tokens_copy = num_tokens
1478+
start_copy = start
1479+
for offset_1 in range(0, max_T_per_expert.to(tl.int32), _BLOCK_SIZE_1):
1480+
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
1481+
mask_1 = indices_1 < max_T_per_expert
1482+
for offset_2 in range(0, N.to(tl.int32), _BLOCK_SIZE_2):
1483+
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
1484+
mask_2 = indices_2 < N
1485+
num_tokens_copy_copy = num_tokens_copy
1486+
start_copy_copy = start_copy
1487+
v_2 = num_tokens_copy_copy[None]
1488+
v_3 = indices_1 < v_2
1489+
v_4 = tl.full([], 0, tl.int32)
1490+
v_5 = v_4[None]
1491+
v_6 = tl.where(v_3, indices_1, v_5)
1492+
v_7 = start_copy_copy[None]
1493+
v_8 = v_7 + v_6
1494+
squeeze = tl.reshape(v_8, [_BLOCK_SIZE_1])
1495+
expert_orig_token_indices = tl.load(sorted_to_orig_token_idx + squeeze * sorted_to_orig_token_idx_stride_0, mask_1, other=0)
1496+
acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
1497+
for offset_3 in range(0, K.to(tl.int32), _BLOCK_SIZE_3):
1498+
indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
1499+
mask_3 = indices_3 < K
1500+
expert_orig_token_indices_copy = expert_orig_token_indices
1501+
acc_copy = acc
1502+
A_frag = tl.load(A + (expert_orig_token_indices_copy[:, None] * A_stride_0 + indices_3[None, :] * A_stride_1), mask_1[:, None] & mask_3[None, :], other=0)
1503+
W_frag = tl.load(W + (indices_0 * W_stride_0 + indices_3[:, None] * W_stride_1 + indices_2[None, :] * W_stride_2), mask_3[:, None] & mask_2[None, :], other=0)
1504+
acc = tl.dot(A_frag, W_frag, acc=acc_copy, input_precision='tf32')
1505+
existing_values = tl.load(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
1506+
view = tl.reshape(v_3, [_BLOCK_SIZE_1, 1])
1507+
mask_2d = tl.broadcast_to(view, [_BLOCK_SIZE_1, _BLOCK_SIZE_2])
1508+
v_9 = acc.to(tl.float16)
1509+
v_10 = tl.where(mask_2d, v_9, existing_values)
1510+
tl.store(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_10, mask_1[:, None] & mask_2[None, :])
15051511
15061512
def moe_matmul_ogs(A: torch.Tensor, W: torch.Tensor, expert_token_counts: torch.Tensor, expert_token_offsets: torch.Tensor, sorted_to_orig_token_idx: torch.Tensor, max_T_per_expert_tensor: torch.Tensor):
15071513
T, K = A.shape

0 commit comments

Comments
 (0)