@@ -22,13 +22,17 @@ def pplx_hidden_dim_scale_bytes(
     # ceil_div(hidden_dim, block_size) * sizeof(float32)
     # For per-token: set to 4 * sizeof(float32) (x4 for alignment)
     if quant_dtype is not None and quant_dtype.itemsize == 1:
-        block_size = block_shape[0] if block_shape is not None else 128
         hidden_dim_bytes = hidden_dim * quant_dtype.itemsize
-        if per_act_token_quant:
-            hidden_scale_bytes = 4 * torch.float32.itemsize  #?
-        else:
+        elem_size = torch.float32.itemsize
+        if block_shape is not None:
+            assert not per_act_token_quant
+            block_size = block_shape[1]
             hidden_scale_bytes = round_up(
-                (cdiv(hidden_dim, block_size) * torch.float32.itemsize), 16)
+                (cdiv(hidden_dim, block_size) * elem_size), elem_size)
+        elif per_act_token_quant:
+            hidden_scale_bytes = hidden_dim * elem_size
+        else:
+            hidden_scale_bytes = 4 * elem_size
     else:
         hidden_dim_bytes = hidden_dim * in_dtype.itemsize
         hidden_scale_bytes = 0
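A rough standalone sketch of the sizing logic this hunk introduces, for reference only. `cdiv` and `round_up` are re-implemented here as assumed ceil-division and alignment helpers (the real module imports its own), and the example sizes are made up.

```python
# Sketch of the three branches after this change: block quant, per-token
# quant, and the default 8-bit case. Helper functions are assumptions for
# illustration, not the module's actual imports.
import torch


def cdiv(a: int, b: int) -> int:
    return -(-a // b)          # ceiling division


def round_up(x: int, multiple: int) -> int:
    return cdiv(x, multiple) * multiple


def hidden_scale_bytes_sketch(hidden_dim, per_act_token_quant, block_shape):
    elem_size = torch.float32.itemsize              # 4 bytes per scale
    if block_shape is not None:
        assert not per_act_token_quant
        block_size = block_shape[1]                 # one scale per block of hidden_dim
        return round_up(cdiv(hidden_dim, block_size) * elem_size, elem_size)
    elif per_act_token_quant:
        return hidden_dim * elem_size
    else:
        return 4 * elem_size                        # 4 * sizeof(float32), per the hunk's comment


print(hidden_scale_bytes_sketch(7168, False, [128, 128]))   # 224 (56 blocks * 4 B)
print(hidden_scale_bytes_sketch(7168, True, None))          # 28672
print(hidden_scale_bytes_sketch(7168, False, None))         # 16
```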
@@ -101,25 +105,21 @@ def prepare(
             a1, (None if self.per_act_token_quant else a1_scale), self.quant_dtype,
             self.per_act_token_quant, self.block_shape)

-        # pplx requires 2-d scales even for scalars
         if a1q_scale is not None:
+            scalar_scales = a1q_scale.numel() == 1
+
+            # pplx requires 2-d scales even for scalar scales
             if a1q_scale.dim() <= 1:
-                assert a1q_scale.numel() == 1
+                assert scalar_scales
                 a1q_scale = a1q_scale.view(1, 1)

-            #print(f"ORIG {a1q_scale.shape}, {a1q_scale}")
-
-            orig_scale = a1q_scale
-            orig_a1q_scale_shape = a1q_scale.shape
+            # pad out scales if needed. TODO (bnell): do for non-scalar scales?
+            if scalar_scales:
+                a1q_scale = a1q_scale.repeat(a1q.shape[1], torch.float32.itemsize)

-            # pad out scales if needed
-            if a1q_scale.numel() == 1:
-                a1q_scale = a1q_scale.repeat(a1q.shape[1], 4)
-
-            assert a1q_scale.shape[0] == a1q.shape[1]
-
-            #print(f"FINAL {a1q_scale.shape}, {a1q_scale}")
+            orig_a_scale_block_shape = a1q_scale.shape[-1]

+            #assert a1_scale is None or a1_scale.shape[0] == a1q.shape[1], f"{a1_scale.shape}, {a1q_scale.shape}"

         assert a1q_scale is None or a1q_scale.ndim == 2, \
             f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}"
@@ -146,26 +146,20 @@ def prepare(
         expert_x_scale: Optional[torch.Tensor] = None
         if a1q.dtype.itemsize == 1:
             float32_size = torch.float32.itemsize
-            block_size = (self.block_shape[0] if self.block_shape is not None
-                          else 1) * float32_size
+            block_size = (self.block_shape[1] if self.block_shape is not None else 1) * float32_size

             expert_x_scale_shape = (
                 num_local_experts,
                 expert_x.size(1),
-                #(expert_x.size(2) + block_size - 1) // block_size,
-                orig_a1q_scale_shape[-1],
+                (expert_x.size(2) + block_size - 1) // block_size if not scalar_scales else 1,
             )

-            #print(f"XXXXXXXXXX {block_size} {expert_x_scale_shape}")
-
             expert_x_scale = torch.zeros(
                 expert_x_scale_shape,
                 dtype=torch.float32,
                 device=expert_x.device,
             )

-            #print(f"YYYYYYYYYYYYYYY {expert_x.shape}")
-
         # This argument is optional, defaults to indices.size(0)
         # There's not much point setting this unless it is != indices.size(0)
         bound_m: Optional[torch.Tensor] = None
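A sketch of the scale-buffer shape arithmetic above, with made-up sizes: the expert count, token capacity, hidden size, and `block_shape` are all hypothetical stand-ins for `num_local_experts`, `expert_x.size(1)`, `expert_x.size(2)`, and `self.block_shape`.

```python
# For block quantization the last dim holds one float32 scale per block of
# hidden-dim elements (block_size here is expressed in bytes, hence the
# float32_size factor); for scalar scales it collapses to a single column.
import torch

num_local_experts = 2
max_tokens_per_expert = 64          # stands in for expert_x.size(1)
hidden_dim = 7168                   # stands in for expert_x.size(2)
float32_size = torch.float32.itemsize

block_shape = [128, 128]            # hypothetical [block_n, block_k]
scalar_scales = False

block_size = (block_shape[1] if block_shape is not None else 1) * float32_size
expert_x_scale_shape = (
    num_local_experts,
    max_tokens_per_expert,
    (hidden_dim + block_size - 1) // block_size if not scalar_scales else 1,
)
print(expert_x_scale_shape)         # (2, 64, 14)
```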
@@ -182,22 +176,9 @@ def prepare(
         if expert_x_scale is not None:
             expert_x_scale = expert_x_scale[:, :, 0:1]

-        #print(f"ZZZZZZZZZZZZZZ {expert_x_scale.shape}")
         if expert_x_scale is not None:
-            expert_x_scale = expert_x_scale[:, :, :orig_a1q_scale_shape[-1]]
-            from math import prod
-            if prod(orig_a1q_scale_shape) == 1:
-                expert_x_scale = expert_x_scale[:, :1, :1]
-                #print(f"EPT {expert_num_tokens.flatten()}")
-                #print(f"SCALARIZING!!! {expert_x_scale.shape}, {expert_x_scale.flatten()}")
-                idx = expert_num_tokens.flatten() != 0
-                assert torch.all(expert_x_scale.flatten()[idx] != 0)
-                #zidx = expert_num_tokens.flatten() == 0
-                #assert torch.all(expert_x_scale.flatten()[zidx] == 0)
-                assert expert_x_scale.ndim == 3
-            #expert_x_scale = orig_scale.view(1)
-
-            assert expert_x_scale.ndim == 1 or expert_x_scale.ndim == 3
+            expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
+            assert expert_x_scale.ndim == 3

         return expert_x, expert_x_scale, expert_num_tokens, None, None
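A small sketch of the final trim, with made-up sizes. Worth noting: for scalar scales the recorded width (4, from the padding repeat) can exceed the allocated last dim (1); like Python lists, PyTorch slicing clamps past-the-end bounds, so the result keeps a single scale column.

```python
# The dispatched scale buffer is sliced back to the width recorded before
# dispatch; slicing a width-1 buffer to width 4 is a no-op because PyTorch
# clamps the slice bound. Shapes are hypothetical.
import torch

expert_x_scale = torch.zeros(2, 64, 1)   # scalar-scale layout: last dim is 1
orig_a_scale_block_shape = 4             # width recorded after the padding repeat

expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
assert expert_x_scale.ndim == 3
print(expert_x_scale.shape)              # torch.Size([2, 64, 1])
```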