@@ -34,7 +34,7 @@ def pplx_hidden_dim_scale_bytes(
     if per_act_token_quant:
         # per-token
         assert block_shape is None
-        hidden_scale_bytes = max_num_tokens * elem_size
+        hidden_scale_bytes = elem_size
     elif block_shape is not None:
         # per-group
         block_size = block_shape[1]
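The substantive fix in this hunk: the pplx buffer is laid out per token row, so in per-act-token mode each row needs space for exactly one scale value, not one per possible token. A minimal sketch of the sizing logic as I read it (the helper name and signature here are illustrative, a simplified re-implementation rather than the file's actual code):

```python
def round_up(x: int, y: int) -> int:
    # Smallest multiple of y that is >= x.
    return ((x + y - 1) // y) * y

def scale_bytes(hidden_dim: int, elem_size: int,
                per_act_token_quant: bool,
                block_shape: tuple[int, int] | None) -> int:
    if per_act_token_quant:
        # One scale per token row; the buffer is sized per row, so a
        # single scale value suffices (the fix in this hunk).
        assert block_shape is None
        return elem_size
    elif block_shape is not None:
        # Per-group: one scale per block_shape[1] columns, ceil-divided.
        block_size = block_shape[1]
        num_blocks = (hidden_dim + block_size - 1) // block_size
        return num_blocks * elem_size
    else:
        # Per-tensor quantization (or none): no per-row scale storage.
        return 0

print(scale_bytes(7168, 4, True, None))         # 4: one fp32 scale per row
print(scale_bytes(7168, 4, False, (128, 128)))  # 224: 56 groups * 4 bytes
```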
@@ -47,8 +47,10 @@ def pplx_hidden_dim_scale_bytes(
         hidden_dim_bytes = hidden_dim * in_dtype.itemsize
         hidden_scale_bytes = 0
 
-    return round_up(hidden_dim_bytes, align), round_up(hidden_scale_bytes,
-                                                       align)
+    return (
+        round_up(hidden_dim_bytes, align),
+        round_up(hidden_scale_bytes, align),
+    )
 
 
 # The max_num_tokens, world_size and dp_size must be the same
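The second hunk is a pure reflow of the return statement; the behavior to keep in mind is that both byte counts are padded up to `align`. Assuming a conventional `round_up`, with values invented purely for illustration:

```python
def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y

align = 16
hidden_dim_bytes = 7168    # already a multiple of 16
hidden_scale_bytes = 4     # one fp32 scale in the per-token case

print(round_up(hidden_dim_bytes, align))    # 7168 (unchanged)
print(round_up(hidden_scale_bytes, align))  # 16 (padded to alignment)
```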
@@ -111,7 +113,7 @@ def prepare(
             a1 = a1 * topk_weights.to(a1.dtype)
 
         repeat_cols = 4
-        repeat_rows = 1 if quant_config.per_act_token_quant else a1.shape[0]
+        repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
         a1q, a1q_scale = moe_kernel_quantize_input(
             a1, (None if quant_config.per_act_token_quant else a1_scale),
             quant_dtype=quant_config.quant_dtype,
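`a1.shape[0]` and `a1.size(0)` are equivalent on a `torch.Tensor`, so this hunk is stylistic. The logic it touches: with per-token quantization every row already has its own scale (`repeat_rows = 1`); with a single per-tensor scale, the value is presumably replicated across all token rows (and 4 columns) to match the layout pplx expects. A hedged illustration of that repeat, with made-up sizes; the surrounding replication call is my assumption, not shown in this diff:

```python
import torch

per_act_token_quant = False
a1 = torch.randn(8, 128)            # 8 tokens, hidden dim 128

repeat_cols = 4
repeat_rows = 1 if per_act_token_quant else a1.size(0)

scale = torch.tensor([[0.02]])      # per-tensor scale, shape (1, 1)
scale = scale.repeat(repeat_rows, repeat_cols)
print(scale.shape)                  # torch.Size([8, 4])
```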
@@ -146,16 +148,12 @@ def prepare(
 
         expert_x_scale: Optional[torch.Tensor] = None
         if a1q.dtype.itemsize == 1:
-            float32_size = torch.float32.itemsize
             block_size = (quant_config.block_shape[1]
-                          if quant_config.block_shape is not None else
-                          float32_size)
+                          if quant_config.block_shape is not None else 1)
             expert_x_scale = torch.empty(
-                (
-                    num_local_experts,
-                    expert_x.size(1),
-                    (expert_x.size(2) + block_size - 1) // block_size,
-                ),
+                (num_local_experts, expert_x.size(1),
+                 round_up(
+                     (expert_x.size(2) + block_size - 1) // block_size, 4)),
                 dtype=torch.float32,
                 device=device,
             )
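Two things change in this allocation: the fallback group size drops from `torch.float32.itemsize` (i.e. 4) to 1 when no block shape is set, and the number of scale columns, the ceil-division of the hidden dimension by the group size, is now padded to a multiple of 4, consistent with the 4-column replication above. The arithmetic, with invented sizes:

```python
def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y

# Hypothetical sizes, for illustration only.
num_local_experts, max_tokens, hidden_dim = 2, 64, 7168
block_size = 384                 # a group size that doesn't divide 7168

num_groups = (hidden_dim + block_size - 1) // block_size  # ceil -> 19
scale_cols = round_up(num_groups, 4)                      # padded -> 20

print((num_local_experts, max_tokens, scale_cols))        # (2, 64, 20)
```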