|
15 | 15 |
|
16 | 16 | @triton.jit
|
17 | 17 | def moe_mmk(
|
18 |
| - a_ptrs, |
19 |
| - b_ptrs, |
20 |
| - K, |
21 |
| - expert_id, |
22 |
| - a_scale_ptr, |
23 |
| - b_scale_ptr, |
24 |
| - # The stride variables represent how much to increase the ptr by when |
25 |
| - # moving by 1 element in a particular dimension. E.g. `stride_am` is |
26 |
| - # how much to increase `a_ptr` by to get the element one row down |
27 |
| - # (A has M rows). |
28 |
| - stride_ak, |
29 |
| - stride_bk, |
30 |
| - stride_asm, |
31 |
| - stride_ask, |
32 |
| - stride_bse, |
33 |
| - stride_bsk, |
34 |
| - stride_bsn, |
35 |
| - # Offsets and masks |
36 |
| - offs_m, |
37 |
| - offs_n, |
38 |
| - mask_m, |
39 |
| - # Block size for block-wise quantization |
40 |
| - group_n: tl.constexpr, |
41 |
| - group_k: tl.constexpr, |
42 |
| - # Meta-parameters |
43 |
| - BLOCK_M: tl.constexpr, |
44 |
| - BLOCK_N: tl.constexpr, |
45 |
| - BLOCK_K: tl.constexpr, |
46 |
| - compute_type: tl.constexpr, |
47 |
| - use_w8a8: tl.constexpr, |
48 |
| - use_w8a16: tl.constexpr): |
49 |
| - |
| 18 | + a_ptrs, |
| 19 | + b_ptrs, |
| 20 | + K, |
| 21 | + expert_id, |
| 22 | + a_scale_ptr, |
| 23 | + b_scale_ptr, |
| 24 | + # The stride variables represent how much to increase the ptr by when |
| 25 | + # moving by 1 element in a particular dimension. E.g. `stride_am` is |
| 26 | + # how much to increase `a_ptr` by to get the element one row down |
| 27 | + # (A has M rows). |
| 28 | + stride_ak, |
| 29 | + stride_bk, |
| 30 | + stride_asm, |
| 31 | + stride_ask, |
| 32 | + stride_bse, |
| 33 | + stride_bsk, |
| 34 | + stride_bsn, |
| 35 | + # Offsets and masks |
| 36 | + offs_m, |
| 37 | + offs_n, |
| 38 | + mask_m, |
| 39 | + # Block size for block-wise quantization |
| 40 | + group_n: tl.constexpr, |
| 41 | + group_k: tl.constexpr, |
| 42 | + # Meta-parameters |
| 43 | + BLOCK_M: tl.constexpr, |
| 44 | + BLOCK_N: tl.constexpr, |
| 45 | + BLOCK_K: tl.constexpr, |
| 46 | + compute_type: tl.constexpr, |
| 47 | + use_w8a8: tl.constexpr, |
| 48 | + use_w8a16: tl.constexpr |
| 49 | +): |
50 | 50 | offs_k = tl.arange(0, BLOCK_K)
|
51 | 51 |
|
52 | 52 | if use_w8a16:
|
@@ -310,22 +310,22 @@ def batched_triton_kernel(
|
310 | 310 |
|
311 | 311 |
|
312 | 312 | def invoke_moe_batched_triton_kernel(
|
313 |
| - A: torch.Tensor, # [E, max_tokens, K] |
314 |
| - B: torch.Tensor, # [E, K, N] |
315 |
| - C: torch.Tensor, # [E, max_tokens, N] |
316 |
| - expert_num_tokens: torch.Tensor, # [E] |
317 |
| - compute_type: tl.dtype, |
318 |
| - # Quantization data |
319 |
| - A_scale: Optional[torch.Tensor], |
320 |
| - B_scale: Optional[torch.Tensor], |
321 |
| - B_zp: torch.Tensor, |
322 |
| - # Quantization schemes |
323 |
| - use_fp8_w8a8: bool, |
324 |
| - use_int8_w8a16: bool, |
325 |
| - use_int4_w4a16: bool, |
326 |
| - config: dict[str, int], |
327 |
| - block_shape: Optional[list[int]] = None): |
328 |
| - |
| 313 | + A: torch.Tensor, # [E, max_tokens, K] |
| 314 | + B: torch.Tensor, # [E, K, N] |
| 315 | + C: torch.Tensor, # [E, max_tokens, N] |
| 316 | + expert_num_tokens: torch.Tensor, # [E] |
| 317 | + compute_type: tl.dtype, |
| 318 | + # Quantization data |
| 319 | + A_scale: Optional[torch.Tensor], |
| 320 | + B_scale: Optional[torch.Tensor], |
| 321 | + B_zp: torch.Tensor, |
| 322 | + # Quantization schemes |
| 323 | + use_fp8_w8a8: bool, |
| 324 | + use_int8_w8a16: bool, |
| 325 | + use_int4_w4a16: bool, |
| 326 | + config: dict[str, int], |
| 327 | + block_shape: Optional[list[int]] = None |
| 328 | +): |
329 | 329 | assert not use_int4_w4a16
|
330 | 330 | max_num_tokens = A.size(1)
|
331 | 331 | K = A.size(2)
|
|
0 commit comments