6 | 6 | from expecttest import TestCase
7 | 7 | from packaging import version
8 | 8 | import torch
| 9 | +from torch._environment import is_fbcode |
9 | 10 |
10 | 11 | from helion._testing import DEVICE
11 | 12 | from helion._testing import code_and_output
@@ -1429,11 +1430,11 @@ def _jagged_dense_add_2d_make_precompiler(x_data: torch.Tensor, x_offsets: torch
1429 | 1430 | return make_precompiler(_jagged_dense_add_2d_kernel)(x_offsets, x_data, y, out, out.size(0), out.size(1), x_offsets.size(0), y.size(0), y.size(1), out.stride(0), out.stride(1), x_data.stride(0), x_offsets.stride(0), y.stride(0), y.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=8, num_stages=4)""",
1430 | 1431 | )
1431 | 1432 |
1432 | | - @unittest.skip("TODO(yf225): fix occasional numerical error") |
1433 | 1433 | @unittest.skipIf(
1434 | 1434 | torch.cuda.get_device_capability(0) < (9, 0),
1435 | 1435 | "Triton internal error on RTX 3090",
1436 | 1436 | )
| 1437 | + @unittest.skipIf(is_fbcode(), "Triton internal error on fbcode Triton pin") |
1437 | 1438 | def test_moe_matmul_ogs(self):
1438 | 1439 | mod = import_path(examples_dir / "moe_matmul_ogs.py")
1439 | 1440 |
@@ -1470,38 +1471,43 @@ def _moe_matmul_ogs_kernel(expert_token_offsets, expert_token_counts, sorted_to_
1470 | 1471 | indices_0 = offset_0 + tl.zeros([1], tl.int32)
1471 | 1472 | start = tl.load(expert_token_offsets + indices_0 * expert_token_offsets_stride_0, None)
1472 | 1473 | num_tokens = tl.load(expert_token_counts + indices_0 * expert_token_counts_stride_0, None)
1473 | | - for offset_1 in range(0, max_T_per_expert.to(tl.int32), _BLOCK_SIZE_1): |
1474 | | - indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) |
1475 | | - mask_1 = indices_1 < max_T_per_expert |
1476 | | - for offset_2 in range(0, N.to(tl.int32), _BLOCK_SIZE_2): |
1477 | | - indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) |
1478 | | - mask_2 = indices_2 < N |
1479 | | - num_tokens_copy = num_tokens |
1480 | | - start_copy = start |
1481 | | - v_0 = num_tokens_copy[None] |
1482 | | - v_1 = indices_1 < v_0 |
1483 | | - v_2 = tl.full([], 0, tl.int32) |
1484 | | - v_3 = v_2[None] |
1485 | | - v_4 = tl.where(v_1, indices_1, v_3) |
1486 | | - v_5 = start_copy[None] |
1487 | | - v_6 = v_5 + v_4 |
1488 | | - squeeze = tl.reshape(v_6, [_BLOCK_SIZE_1]) |
1489 | | - expert_orig_token_indices = tl.load(sorted_to_orig_token_idx + squeeze * sorted_to_orig_token_idx_stride_0, mask_1, other=0) |
1490 | | - acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32) |
1491 | | - for offset_3 in range(0, K.to(tl.int32), _BLOCK_SIZE_3): |
1492 | | - indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) |
1493 | | - mask_3 = indices_3 < K |
1494 | | - expert_orig_token_indices_copy = expert_orig_token_indices |
1495 | | - acc_copy = acc |
1496 | | - A_frag = tl.load(A + (expert_orig_token_indices_copy[:, None] * A_stride_0 + indices_3[None, :] * A_stride_1), mask_1[:, None] & mask_3[None, :], other=0) |
1497 | | - W_frag = tl.load(W + (indices_0 * W_stride_0 + indices_3[:, None] * W_stride_1 + indices_2[None, :] * W_stride_2), mask_3[:, None] & mask_2[None, :], other=0) |
1498 | | - acc = tl.dot(A_frag, W_frag, acc=acc_copy, input_precision='tf32') |
1499 | | - existing_values = tl.load(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), mask_1[:, None] & mask_2[None, :], other=0) |
1500 | | - view = tl.reshape(v_1, [_BLOCK_SIZE_1, 1]) |
1501 | | - mask_2d = tl.broadcast_to(view, [_BLOCK_SIZE_1, _BLOCK_SIZE_2]) |
1502 | | - v_7 = acc.to(tl.float16) |
1503 | | - v_8 = tl.where(mask_2d, v_7, existing_values) |
1504 | | - tl.store(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_8, mask_1[:, None] & mask_2[None, :]) |
| 1474 | + v_0 = tl.full([], 0, tl.int32) |
| 1475 | + v_1 = num_tokens != v_0 |
| 1476 | + if v_1: |
| 1477 | + num_tokens_copy = num_tokens |
| 1478 | + start_copy = start |
| 1479 | + for offset_1 in range(0, max_T_per_expert.to(tl.int32), _BLOCK_SIZE_1): |
| 1480 | + indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) |
| 1481 | + mask_1 = indices_1 < max_T_per_expert |
| 1482 | + for offset_2 in range(0, N.to(tl.int32), _BLOCK_SIZE_2): |
| 1483 | + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) |
| 1484 | + mask_2 = indices_2 < N |
| 1485 | + num_tokens_copy_copy = num_tokens_copy |
| 1486 | + start_copy_copy = start_copy |
| 1487 | + v_2 = num_tokens_copy_copy[None] |
| 1488 | + v_3 = indices_1 < v_2 |
| 1489 | + v_4 = tl.full([], 0, tl.int32) |
| 1490 | + v_5 = v_4[None] |
| 1491 | + v_6 = tl.where(v_3, indices_1, v_5) |
| 1492 | + v_7 = start_copy_copy[None] |
| 1493 | + v_8 = v_7 + v_6 |
| 1494 | + squeeze = tl.reshape(v_8, [_BLOCK_SIZE_1]) |
| 1495 | + expert_orig_token_indices = tl.load(sorted_to_orig_token_idx + squeeze * sorted_to_orig_token_idx_stride_0, mask_1, other=0) |
| 1496 | + acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32) |
| 1497 | + for offset_3 in range(0, K.to(tl.int32), _BLOCK_SIZE_3): |
| 1498 | + indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) |
| 1499 | + mask_3 = indices_3 < K |
| 1500 | + expert_orig_token_indices_copy = expert_orig_token_indices |
| 1501 | + acc_copy = acc |
| 1502 | + A_frag = tl.load(A + (expert_orig_token_indices_copy[:, None] * A_stride_0 + indices_3[None, :] * A_stride_1), mask_1[:, None] & mask_3[None, :], other=0) |
| 1503 | + W_frag = tl.load(W + (indices_0 * W_stride_0 + indices_3[:, None] * W_stride_1 + indices_2[None, :] * W_stride_2), mask_3[:, None] & mask_2[None, :], other=0) |
| 1504 | + acc = tl.dot(A_frag, W_frag, acc=acc_copy, input_precision='tf32') |
| 1505 | + existing_values = tl.load(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), mask_1[:, None] & mask_2[None, :], other=0) |
| 1506 | + view = tl.reshape(v_3, [_BLOCK_SIZE_1, 1]) |
| 1507 | + mask_2d = tl.broadcast_to(view, [_BLOCK_SIZE_1, _BLOCK_SIZE_2]) |
| 1508 | + v_9 = acc.to(tl.float16) |
| 1509 | + v_10 = tl.where(mask_2d, v_9, existing_values) |
| 1510 | + tl.store(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_10, mask_1[:, None] & mask_2[None, :]) |
1505 | 1511 |
1506 | 1512 | def moe_matmul_ogs(A: torch.Tensor, W: torch.Tensor, expert_token_counts: torch.Tensor, expert_token_offsets: torch.Tensor, sorted_to_orig_token_idx: torch.Tensor, max_T_per_expert_tensor: torch.Tensor):
1507 | 1513 | T, K = A.shape
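For context on the regenerated expected output: the new kernel text computes `v_0 = 0`, `v_1 = num_tokens != v_0`, and only enters the per-expert tile loops under `if v_1:`, so experts that received no tokens skip the matmul entirely. Below is a minimal Python sketch of that zero-token guard; the function and argument names are hypothetical and it uses plain tensor ops, not the generated Triton.

import torch

def expert_matmul_with_guard(A, W_e, token_idx, num_tokens, C):
    # Mirrors the guard in the generated kernel: do no work for an
    # expert that owns zero tokens.
    if num_tokens == 0:
        return C
    rows = token_idx[:num_tokens]  # original row indices routed to this expert
    # Accumulate in float32, then cast back to the output dtype.
    C[rows] = (A[rows].to(torch.float32) @ W_e.to(torch.float32)).to(C.dtype)
    return C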