
Commit addd937

wip hacking
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 31b66d8 commit addd937

3 files changed: 33 additions and 42 deletions


vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 2 additions & 0 deletions
@@ -922,6 +922,8 @@ def apply(
 
         intermediate_cache1.fill_(0)
 
+        #print(f"A1_SCALES {a1q_scale.shape}")
+
         # MM1
         invoke_moe_batched_triton_kernel(A=hidden_states,
                                          B=w1,

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 9 additions & 1 deletion
@@ -340,8 +340,16 @@ def init_prepare_finalize(self, moe: MoEConfig,
 
             input_activations = get_quant_config_input_activations(
                 quant_config)
+            block_shape = quant_config.weight_block_size if quant_config is not None else None
 
             logger.debug("PplxPrepareAndFinalize")
+
+            # XXXXXXXXXXXXXXXXXXXXXXXXX TODO
+            # Remove quant flags from PrepareAndFinalize ctor and
+            # pass them in as arguments to prepare(). Get them
+            # from the FusedExperts as attributes or arguments
+
+
             prepare_finalize = PplxPrepareAndFinalize(
                 handle,
                 max_num_tokens=moe.max_num_tokens,
@@ -353,7 +361,7 @@ def init_prepare_finalize(self, moe: MoEConfig,
                 per_act_token=(input_activations.strategy
                                == QuantizationStrategy.TOKEN
                                if input_activations is not None else False),
-                block_shape=None, # TODO (bnell): quantization
+                block_shape=None, #block_shape
             )
         elif moe.use_deepep_ht_kernels:
             assert moe.dp_size == all2all_manager.dp_world_size
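
Note on the layer.py hunks above: the new block_shape local is read from the quantization config but is not yet forwarded to PplxPrepareAndFinalize (the call still passes block_shape=None, now annotated with #block_shape). Below is a minimal sketch of how that value could be resolved, assuming a quant config object exposing a weight_block_size attribute such as [128, 128]; the helper name is illustrative, not part of the commit.

from typing import Optional

def resolve_block_shape(quant_config) -> Optional[list[int]]:
    # weight_block_size is e.g. [128, 128] for block-quantized fp8 checkpoints;
    # None means per-tensor or per-token activation scaling.
    if quant_config is None:
        return None
    return getattr(quant_config, "weight_block_size", None)

# Per the TODO in the hunk, such quantization flags would eventually be passed
# to prepare() rather than to the PplxPrepareAndFinalize constructor.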

vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py

Lines changed: 22 additions & 41 deletions
@@ -22,13 +22,17 @@ def pplx_hidden_dim_scale_bytes(
     # ceil_div(hidden_dim, block_size) * sizeof(float32)
     # For per-token: set to 4 * sizeof(float32) (x4 for alignment)
     if quant_dtype is not None and quant_dtype.itemsize == 1:
-        block_size = block_shape[0] if block_shape is not None else 128
         hidden_dim_bytes = hidden_dim * quant_dtype.itemsize
-        if per_act_token_quant:
-            hidden_scale_bytes = 4 * torch.float32.itemsize #?
-        else:
+        elem_size = torch.float32.itemsize
+        if block_shape is not None:
+            assert not per_act_token_quant
+            block_size = block_shape[1]
             hidden_scale_bytes = round_up(
-                (cdiv(hidden_dim, block_size) * torch.float32.itemsize), 16)
+                (cdiv(hidden_dim, block_size) * elem_size), elem_size)
+        elif per_act_token_quant:
+            hidden_scale_bytes = hidden_dim * elem_size
+        else:
+            hidden_scale_bytes = 4 * elem_size
     else:
         hidden_dim_bytes = hidden_dim * in_dtype.itemsize
         hidden_scale_bytes = 0
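
The reworked fp8 branch in pplx_hidden_dim_scale_bytes now sizes the scale portion of each dispatched row for three cases: block quantization (one float32 scale per block_shape[1] columns of the hidden dimension, rounded up to the element size), per-token quantization (hidden_dim float32 values, as written in this commit), and per-tensor quantization (4 float32 values for alignment). A self-contained sketch of the arithmetic with an illustrative hidden_dim of 2048; cdiv and round_up below simply mirror the helpers used in this file.

import torch

def cdiv(a: int, b: int) -> int:
    return -(-a // b)

def round_up(x: int, y: int) -> int:
    return cdiv(x, y) * y

hidden_dim = 2048
elem_size = torch.float32.itemsize  # 4 bytes per float32 scale

# Block quantization, e.g. block_shape == [128, 128]:
block_scale_bytes = round_up(cdiv(hidden_dim, 128) * elem_size, elem_size)
assert block_scale_bytes == 64       # 16 blocks * 4 bytes

# Per-token activation quantization (sizing as written in this commit):
per_token_scale_bytes = hidden_dim * elem_size   # 8192 bytes

# Per-tensor quantization:
per_tensor_scale_bytes = 4 * elem_size           # 16 bytes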
@@ -101,25 +105,21 @@ def prepare(
             a1, (None if self.per_act_token_quant else a1_scale), self.quant_dtype,
             self.per_act_token_quant, self.block_shape)
 
-        # pplx requires 2-d scales even for scalars
         if a1q_scale is not None:
+            scalar_scales = a1q_scale.numel() == 1
+
+            # pplx requires 2-d scales even for scalar scales
             if a1q_scale.dim() <= 1:
-                assert a1q_scale.numel() == 1
+                assert scalar_scales
                 a1q_scale = a1q_scale.view(1, 1)
 
-            #print(f"ORIG {a1q_scale.shape}, {a1q_scale}")
-
-            orig_scale = a1q_scale
-            orig_a1q_scale_shape = a1q_scale.shape
+            # pad out scales if needed. TODO (bnell): do for non-scalar scales?
+            if scalar_scales:
+                a1q_scale = a1q_scale.repeat(a1q.shape[1], torch.float32.itemsize)
 
-            # pad out scales if needed
-            if a1q_scale.numel() == 1:
-                a1q_scale = a1q_scale.repeat(a1q.shape[1], 4)
-
-            assert a1q_scale.shape[0] == a1q.shape[1]
-
-            #print(f"FINAL {a1q_scale.shape}, {a1q_scale}")
+            orig_a_scale_block_shape = a1q_scale.shape[-1]
 
+        #assert a1_scale is None or a1_scale.shape[0] == a1q.shape[1], f"{a1_scale.shape}, {a1q_scale.shape}"
 
         assert a1q_scale is None or a1q_scale.ndim == 2, \
             f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}"
@@ -146,26 +146,20 @@ def prepare(
         expert_x_scale: Optional[torch.Tensor] = None
         if a1q.dtype.itemsize == 1:
             float32_size = torch.float32.itemsize
-            block_size = (self.block_shape[0] if self.block_shape is not None
-                          else 1) * float32_size
+            block_size = (self.block_shape[1] if self.block_shape is not None else 1) * float32_size
 
             expert_x_scale_shape = (
                 num_local_experts,
                 expert_x.size(1),
-                #(expert_x.size(2) + block_size - 1) // block_size,
-                orig_a1q_scale_shape[-1],
+                (expert_x.size(2) + block_size - 1) // block_size if not scalar_scales else 1,
             )
 
-            #print(f"XXXXXXXXXX {block_size} {expert_x_scale_shape}")
-
             expert_x_scale = torch.zeros(
                 expert_x_scale_shape,
                 dtype=torch.float32,
                 device=expert_x.device,
             )
 
-            #print(f"YYYYYYYYYYYYYYY {expert_x.shape}")
-
         # This argument is optional, defaults to indices.size(0)
         # There's not much point setting this unless it is != indices.size(0)
         bound_m: Optional[torch.Tensor] = None
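
In the hunk above, the third dimension of the receive-side expert_x_scale buffer is now derived from self.block_shape[1] (scaled by the float32 element size, exactly as the code does) instead of the original scale tensor's last dimension, collapsing to a single column when the scales are scalar. A small sketch of the shape arithmetic with illustrative sizes (none of these numbers come from the commit):

import torch

num_local_experts = 4
max_tokens_per_expert = 64
hidden_dim = 2048                        # plays the role of expert_x.size(2)
float32_size = torch.float32.itemsize    # 4
block_shape = [128, 128]
scalar_scales = False

# Mirrors the hunk: the float32 element size is folded into block_size.
block_size = (block_shape[1] if block_shape is not None else 1) * float32_size

expert_x_scale_shape = (
    num_local_experts,
    max_tokens_per_expert,
    (hidden_dim + block_size - 1) // block_size if not scalar_scales else 1,
)
assert expert_x_scale_shape == (4, 64, 4)    # ceil(2048 / 512) = 4 columns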
@@ -182,22 +176,9 @@ def prepare(
         if expert_x_scale is not None:
             expert_x_scale = expert_x_scale[:, :, 0:1]
 
-        #print(f"ZZZZZZZZZZZZZZ {expert_x_scale.shape}")
         if expert_x_scale is not None:
-            expert_x_scale = expert_x_scale[:, :, :orig_a1q_scale_shape[-1]]
-            from math import prod
-            if prod(orig_a1q_scale_shape) == 1:
-                expert_x_scale = expert_x_scale[:, :1, :1]
-                #print(f"EPT {expert_num_tokens.flatten()}")
-                #print(f"SCALARIZING!!! {expert_x_scale.shape}, {expert_x_scale.flatten()}")
-                idx = expert_num_tokens.flatten() != 0
-                assert torch.all(expert_x_scale.flatten()[idx] != 0)
-                #zidx = expert_num_tokens.flatten() == 0
-                #assert torch.all(expert_x_scale.flatten()[zidx] == 0)
-                assert expert_x_scale.ndim == 3
-            #expert_x_scale = orig_scale.view(1)
-
-            assert expert_x_scale.ndim == 1 or expert_x_scale.ndim == 3
+            expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
+            assert expert_x_scale.ndim == 3
 
         return expert_x, expert_x_scale, expert_num_tokens, None, None
