@@ -33,26 +33,26 @@ def fused_moe_kernel(
33
33
expert_ids_ptr ,
34
34
num_tokens_post_padded_ptr ,
35
35
# Matrix dimensions
36
- N ,
37
- K ,
38
- EM ,
39
- num_valid_tokens ,
36
+ N : tl . int64 ,
37
+ K : tl . int64 ,
38
+ EM : tl . int64 ,
39
+ num_valid_tokens : tl . int64 ,
40
40
# The stride variables represent how much to increase the ptr by when
41
41
# moving by 1 element in a particular dimension. E.g. `stride_am` is
42
42
# how much to increase `a_ptr` by to get the element one row down
43
43
# (A has M rows).
44
- stride_am ,
45
- stride_ak ,
46
- stride_be ,
47
- stride_bk ,
48
- stride_bn ,
49
- stride_cm ,
50
- stride_cn ,
51
- stride_asm ,
52
- stride_ask ,
53
- stride_bse ,
54
- stride_bsk ,
55
- stride_bsn ,
44
+ stride_am : tl . int64 ,
45
+ stride_ak : tl . int64 ,
46
+ stride_be : tl . int64 ,
47
+ stride_bk : tl . int64 ,
48
+ stride_bn : tl . int64 ,
49
+ stride_cm : tl . int64 ,
50
+ stride_cn : tl . int64 ,
51
+ stride_asm : tl . int64 ,
52
+ stride_ask : tl . int64 ,
53
+ stride_bse : tl . int64 ,
54
+ stride_bsk : tl . int64 ,
55
+ stride_bsn : tl . int64 ,
56
56
# Block size for block-wise quantization
57
57
group_n : tl .constexpr ,
58
58
group_k : tl .constexpr ,
@@ -114,18 +114,16 @@ def fused_moe_kernel(
114
114
num_tokens_post_padded = tl .load (num_tokens_post_padded_ptr )
115
115
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded :
116
116
return
117
- offs_token_id = pid_m * BLOCK_SIZE_M + tl .arange (0 , BLOCK_SIZE_M ).to (
118
- tl .int64 )
117
+ offs_token_id = pid_m * BLOCK_SIZE_M + tl .arange (0 , BLOCK_SIZE_M )
119
118
offs_token = tl .load (sorted_token_ids_ptr + offs_token_id )
120
119
token_mask = offs_token < num_valid_tokens
121
120
122
- offs_bn = (pid_n * BLOCK_SIZE_N +
123
- tl .arange (0 , BLOCK_SIZE_N ).to (tl .int64 )) % N
121
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl .arange (0 , BLOCK_SIZE_N )) % N
124
122
offs_k = tl .arange (0 , BLOCK_SIZE_K )
125
123
a_ptrs = a_ptr + (offs_token [:, None ] // top_k * stride_am +
126
124
offs_k [None , :] * stride_ak )
127
125
128
- off_experts = tl .load (expert_ids_ptr + pid_m ). to ( tl . int64 )
126
+ off_experts = tl .load (expert_ids_ptr + pid_m )
129
127
b_ptrs = b_ptr + off_experts * stride_be + (offs_k [:, None ] * stride_bk +
130
128
offs_bn [None , :] * stride_bn )
131
129
if use_int8_w8a16 :
0 commit comments