Commit 22e19a4

Enabling MOE Quantization using linear decomposition [WIP]
Summary: This PR is a first step toward optimizing MoE inference with torchao. The goal of this step is to let existing quantization kernels and workflows work for MoE quantization by decomposing the grouped gemm into a sequence of unbalanced linear ops that can reuse the existing quantized kernels. To enable this, we had to add support for quantizing these 3D tensors as well as for slicing and indexing them. Tests are currently running locally and will be added once they pass. int8wo and int8dq currently work for multi-token and single-token MoE inference, while int4wo is being finished up. TODO: move the test set into ao, move the quantizable MoE module code into ao, and test on an HF model definition.

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent a81322e commit 22e19a4
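
For illustration only (not part of this commit): a minimal sketch of the decomposition described in the summary, with made-up sizes and plain float tensors. The grouped expert computation is replaced by one ordinary F.linear per expert over that expert's (unbalanced) batch of tokens, which is the form the existing 2D quantized linear kernels can serve once the stacked 3D weight supports indexing/select.

import torch
import torch.nn.functional as F

# hypothetical sizes: E experts, T routed tokens, D model dim, I intermediate dim
E, T, D, I = 8, 16, 64, 128
w = torch.randn(E, I, D)                 # stacked expert weight, like the w1/w2/w3 parameters further down
tokens = torch.randn(T, D)
expert_ids = torch.randint(0, E, (T,))   # expert chosen for each token

# grouped-gemm-as-a-loop: each expert runs a plain (unbalanced) linear over its tokens
out = torch.empty(T, I)
for e in range(E):
    mask = expert_ids == e
    if mask.any():
        out[mask] = F.linear(tokens[mask], w[e])  # w[e] is where quantized select/index support is needed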

File tree

7 files changed: +231 -43 lines changed


torchao/dtypes/affine_quantized_tensor_ops.py

Lines changed: 40 additions & 0 deletions
@@ -477,6 +477,46 @@ def _(func, types, args, kwargs):
     )
     return return_and_correct_aliasing(func, args, kwargs, new)
 
+@implements(aten.index.Tensor)
+def _(func, types, args, kwargs):
+    self, indices = args
+    assert len(indices) == 1, f"op {func} currently only implemented for single dimensional indexing but got indices: {indices}"
+
+    new_tensor_impl = aten.index.Tensor(self.tensor_impl, indices)
+    shape = tuple([indices[0].numel(), *self.shape[1:]])
+
+    block_size = self.block_size
+    new = self.__class__(
+        new_tensor_impl,
+        block_size,
+        shape,
+        self.quant_min,
+        self.quant_max,
+        self.zero_point_domain,
+        dtype=self.dtype,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+@implements(aten.select.int)
+def _(func, types, args, kwargs):
+    self, dim, index = fill_defaults(args, 3, [0, 0])
+    assert dim == 0, f"op {func} currently only implemented for dim=0 but got dim={dim}"
+    assert self.dim() == 3, f"op {func} currently only implemented for 3 dimensional tensors but got shape={self.shape}"
+
+    new_tensor_impl = aten.select.int(self.tensor_impl, dim, index)
+
+    shape = self.shape[1:]
+    block_size = self.block_size[1:]
+    new = self.__class__(
+        new_tensor_impl,
+        block_size,
+        shape,
+        self.quant_min,
+        self.quant_max,
+        self.zero_point_domain,
+        dtype=self.dtype,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
 
 # this is needed for DTensor.from_local() and for flattening tensor
 @implements(aten.view.default)
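
For illustration (not from the commit): a plain-tensor sketch of the shape bookkeeping the two overrides above implement; the quantized versions additionally carry block_size and the quantization parameters through the same slicing.

import torch

E, I, D = 4, 8, 16
w = torch.randn(E, I, D)            # stand-in for a quantized 3D expert weight
idx = torch.tensor([0, 2])

# aten.index.Tensor: leading dim becomes the number of gathered experts
assert w[idx].shape == (idx.numel(), *w.shape[1:])   # (2, 8, 16)

# aten.select.int with dim=0: leading dim (and its block_size entry) is dropped
assert w[1].shape == tuple(w.shape[1:])              # (8, 16)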

torchao/dtypes/uintx/plain_layout.py

Lines changed: 11 additions & 0 deletions
@@ -154,6 +154,17 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             )
             return return_and_correct_aliasing(func, args, kwargs, new)
 
+
+        elif func in [aten.select.int, aten.index.Tensor]:
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                args[0]._apply_fn_to_data(
+                    lambda x: func(x, *args[1:], **kwargs)
+                ),
+            )
+
         elif func is aten.slice.Tensor:
             self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
             if dim == 0:
torchao/dtypes/uintx/tensor_core_tiled_layout.py

Lines changed: 79 additions & 35 deletions
@@ -75,7 +75,6 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias):
         f"need input_tensor shape: {input_tensor.shape} final"
         f"dim to match weight_tensor shape: {weight_tensor.shape} second dim "
     )
-
     # TODO: check groupsize quantization
     # avoid circular dep, TODO: move this to a common util.py
     act_mat = input_tensor
@@ -97,7 +96,6 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias):
     y = torch.ops.aten._weight_int4pack_mm(
         act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
     )
-
     # remove out_feature padding
     orig_out_features = weight_tensor.shape[-2]
     y = y[:, :orig_out_features]
@@ -119,7 +117,7 @@ class TensorCoreTiledLayout(Layout):
     inner_k_tiles: int = 8
 
     def pre_process(self, input: torch.Tensor) -> torch.Tensor:
-        orig_out_features, orig_in_features = input.shape
+        orig_out_features, orig_in_features = input.shape[-2:]
         in_features = find_multiple(orig_in_features, 1024)
         out_features = find_multiple(orig_out_features, 8)
         input = torch.nn.functional.pad(
@@ -160,18 +158,18 @@ def post_process(
         zero_point: torch.Tensor,
         block_size: Tuple[int, ...],
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        orig_out_features, orig_in_features = input.shape
+        orig_out_features, orig_in_features = input.shape[-2:]
         in_features = find_multiple(orig_in_features, 1024)
         out_features = find_multiple(orig_out_features, 8)
         input = torch.nn.functional.pad(
             input,
             (0, in_features - orig_in_features, 0, out_features - orig_out_features),
         )
         assert (
-            len(block_size) == 2
-        ), f"TensorCoreTiledLayout only supports len(block_size) == 2, got: {block_size}"
-        scale_pad_dim_0 = (out_features - orig_out_features) // block_size[0]
-        scale_pad_dim_1 = (in_features - orig_in_features) // block_size[1]
+            len(block_size) == 2 or len(block_size) == 3
+        ), f"TensorCoreTiledLayout only supports len(block_size) == 2 or 3, got: {block_size}"
+        scale_pad_dim_0 = (out_features - orig_out_features) // block_size[-2]
+        scale_pad_dim_1 = (in_features - orig_in_features) // block_size[-1]
         scale = torch.nn.functional.pad(scale, (0, scale_pad_dim_1, 0, scale_pad_dim_0))
         zero_point = torch.nn.functional.pad(
             zero_point, (0, scale_pad_dim_1, 0, scale_pad_dim_0)
@@ -272,11 +270,22 @@ def from_plain(
         assert (
             int_data.dtype == torch.int32
         ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype"
-        packed_weight = torch.ops.aten._convert_weight_to_int4pack(
-            int_data, _layout.inner_k_tiles
-        )
-        scale = scale.reshape(int_data.shape[0], -1)
-        zero_point = zero_point.reshape(int_data.shape[0], -1)
+        def quant_2d(mat):
+            return torch.ops.aten._convert_weight_to_int4pack(
+                mat, _layout.inner_k_tiles
+            )
+        if int_data.dim() == 3:  # for moe quant
+            num_experts = int_data.shape[0]
+            packed_weight_list = []
+            for expert in range(num_experts):
+                packed_weight_list.append(quant_2d(int_data[expert]).unsqueeze(0))
+            packed_weight = torch.cat(packed_weight_list, dim=0)
+            scale = scale.reshape(int_data.shape[0], int_data.shape[-2], -1)
+            zero_point = zero_point.reshape(int_data.shape[0], int_data.shape[-2], -1)
+        else:
+            packed_weight = quant_2d(int_data)
+            scale = scale.reshape(int_data.shape[0], -1)
+            zero_point = zero_point.reshape(int_data.shape[0], -1)
         from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
 
         scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype)
@@ -336,6 +345,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
             )
 
+        if func in [aten.select.int, aten.index.Tensor]:
+            assert not (func is aten.select.int and args[1] != 0), "aten.select.int currently only has support for dim=0"
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                args[0]._apply_fn_to_data(
+                    lambda x: func(x, *args[1:], **kwargs)
+                ),
+            )
+
+
         if func is aten.t.default:
             """we don't need to repack the weight and just rely on external
             shape being changed and record the status of transpose/no-transpose
@@ -386,11 +407,15 @@ def block_size(self):
 
         scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
         cur_shape = self.shape
-        assert len(cur_shape) == 4
+        if len(cur_shape) == 5:
+            ones = [1, 1]
+            cur_shape = cur_shape[1:]
+        elif len(cur_shape) == 4:
+            ones = [1]
         inner_k_tiles = cur_shape[-1] * 2
         original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
         groupsize = int(original_shape[1] / scale.shape[-2])
-        return (1, groupsize)
+        return tuple([*ones, groupsize])
 
     def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         from torchao.quantization.quant_primitives import (
@@ -399,35 +424,54 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         )
         from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros
 
+        def dequant_4d(self):
+            cur_shape = self.shape
+            scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
+            assert len(cur_shape) == 4
+            inner_k_tiles = cur_shape[-1] * 2
+            original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
+            eye_shape = original_shape[1]
+            groupsize = int(original_shape[1] / scale.shape[-2])
+            block_size = (1, groupsize)
+            original_dtype = torch.bfloat16
+            assert len(block_size) == 2 and block_size[0] == 1
+            dequantized = torch.ops.aten._weight_int4pack_mm(
+                torch.eye(eye_shape, device=self.device, dtype=original_dtype),
+                self.packed_weight,
+                groupsize,
+                self.scale_and_zero,
+            )
+            dequantized = dequantized.t().contiguous()
+            return dequantized
+
+        cur_shape = self.shape
+
+        if len(cur_shape) == 4:
+            dequantized = dequant_4d(self)
+        else:
+            assert len(cur_shape) == 5
+            num_experts = cur_shape[0]
+            dequantized_list = []
+            for expert in range(num_experts):
+                dequantized_list.append(dequant_4d(self[expert]).unsqueeze(0))
+            dequantized = torch.cat(dequantized_list, dim=0)
+
         scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
+        # TODO: move this to `unpack_tinygemm_scales_and_zeros`?
+        scale = scale.reshape(scale.shape[:-1]).contiguous()
+        zero = zero.reshape(zero.shape[:-1]).contiguous()
 
-        cur_shape = self.shape
-        assert len(cur_shape) == 4
-        inner_k_tiles = cur_shape[-1] * 2
-        original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
-        eye_shape = original_shape[1]
-        groupsize = int(original_shape[1] / scale.shape[-2])
-        block_size = (1, groupsize)
         device = self.device
-        original_dtype = torch.bfloat16
+
         target_dtype = torch.int32
         quant_min = 0
         quant_max = 15
         zero_point_domain = ZeroPointDomain.FLOAT
-        assert len(block_size) == 2 and block_size[0] == 1
-        dequantized = torch.ops.aten._weight_int4pack_mm(
-            torch.eye(eye_shape, device=device, dtype=original_dtype),
-            self.packed_weight,
-            groupsize,
-            self.scale_and_zero,
-        )
-        dequantized = dequantized.t().contiguous()
-        # TODO: move this to `unpack_tinygemm_scales_and_zeros`?
-        scale = scale.reshape(scale.shape[:-1]).contiguous()
-        zero = zero.reshape(zero.shape[:-1]).contiguous()
         int_data = quantize_affine(
             dequantized,
-            block_size,
+            self.block_size,
             scale,
             zero,
             target_dtype,
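
For illustration (not from the commit): the from_plain/get_plain changes above reuse the existing 2D tinygemm pack/dequant path per expert instead of adding a 3D kernel. A toy sketch of that per-expert pattern, with a made-up stand-in for the 2D packing kernel (the real _convert_weight_to_int4pack needs real int4 data and a supported device):

import torch

def apply_per_expert(fn_2d, stacked):  # stacked: (num_experts, out, in)
    # apply an existing 2D kernel expert-by-expert, then re-stack along dim 0
    return torch.cat([fn_2d(stacked[e]).unsqueeze(0) for e in range(stacked.shape[0])], dim=0)

fn_2d = lambda m: m.t().contiguous()   # toy stand-in for the 2D packing kernel
packed = apply_per_expert(fn_2d, torch.randn(4, 8, 16))
assert packed.shape == (4, 16, 8)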
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+class MOEFeedForwardAOQuantizable(nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
+        self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
+        self.cond_ffn = ConditionalFeedForwardAOQuantizable(config)
+        self.dim = config.dim
+        self.num_activated_experts = config.num_activated_experts
+
+    def forward(self, x: Tensor) -> Tensor:
+        batch_size = x.shape[0]
+        x = x.view(-1, self.dim)  # x: [T, D]
+        scores = self.gate(x)  # [T, E]
+        expert_weights = F.softmax(scores, dim=-1)
+        expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1)  # [T, A], [T, A]
+        expert_weights /= expert_weights.sum(dim=-1, keepdim=True).to(x.dtype)  # [T, A]
+        out = self.cond_ffn(x, expert_indices, expert_weights, self.num_activated_experts)
+        return out.reshape(batch_size, -1, self.dim)
+
+
+class ConditionalFeedForwardAOQuantizable(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))  # E, I, D
+        self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size))  # E, D, I
+        self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))  # E, I, D
+        self.num_experts = config.num_experts
+
+    def forward(
+        self,
+        x: Tensor,  # T, D
+        expert_indices: Tensor,  # T, A
+        expert_weights: Tensor,  # T, A
+        num_activated_experts: int,
+    ) -> Tensor:
+        num_tokens, dim = x.shape
+        num_token_activations = num_tokens * num_activated_experts
+
+        if x.shape[0] == 1:  # only 1 token (can be done without graph breaks when compiled)
+            outs = []
+            expert_indices = expert_indices.squeeze()
+            # collect used experts
+            w1 = self.w1[expert_indices]
+            w2 = self.w2[expert_indices]
+            w3 = self.w3[expert_indices]
+
+            # run token through each expert
+            for index in range(num_activated_experts):
+                cur_out = F.linear(F.silu(F.linear(x, w1[index])) * F.linear(x, w3[index]), w2[index])
+                outs.append(cur_out)
+
+            # combine outputs
+            final_out = (torch.cat(outs, dim=0) * expert_weights.view(-1, 1)).sum(dim=0).unsqueeze(-1)
+            return final_out
+        else:
+            expert_list = [x for x in range(self.num_experts)]
+
+            # shuffle tokens into groups for each expert
+            ordered_token_activations = expert_indices.view(-1).argsort(stable=True)  # [A]
+            ordered_token_indices = ordered_token_activations.div(num_activated_experts).floor().to(torch.int64)  # [T]
+
+            num_tokens_per_expert = torch.histc(expert_indices, bins=self.num_experts + 1, min=-1, max=self.num_experts)  # [E+1] (added leading 0 so can be used for indexing)
+            cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to(torch.int64)  # [E+1]
+
+            # without quant this is compilable, with quant it throws an error.
+            # Even without quant there's a graph break here so not a huge loss
+            @torch._dynamo.disable()
+            def group_tokens_by_expert(ordered_token_indices, cum_tokens_per_expert, expert_list):
+                token_indices_per_expert = [ordered_token_indices[cum_tokens_per_expert[expert]:cum_tokens_per_expert[expert + 1]] for expert in expert_list]  # [T'(e1)], [T'(e2)] ...
+                return token_indices_per_expert
+
+            token_indices_per_expert = group_tokens_by_expert(ordered_token_indices, cum_tokens_per_expert, expert_list)
+            tokens_grouped_by_expert = [x[indices] for indices in token_indices_per_expert]
+
+            # calculate outputs for each expert
+            outs = []
+            for cur_x, expert in zip(tokens_grouped_by_expert, expert_list):
+                w1 = self.w1[expert]  # I, D
+                w2 = self.w2[expert]  # D, I
+                w3 = self.w3[expert]  # I, D
+
+                cur_out = F.linear(F.silu(F.linear(cur_x, w1)) * F.linear(cur_x, w3), w2)  # [T'(e), D]
+                outs.append(cur_out)
+
+            # weigh outputs
+            ordered_outs = torch.cat(outs, dim=0)  # [T*A, D]
+            ordered_token_activation_weights = expert_weights.view(-1, 1)[ordered_token_activations].view(-1, 1)  # [T*A, 1]
+            weighted_ordered_outs = ordered_outs * ordered_token_activation_weights  # [T*A, D]
+
+            # sum weighted token-activation outputs together for each token
+            final_out = torch.zeros_like(x)  # [T, D]
+            final_out = final_out.scatter_add(dim=0, index=ordered_token_indices.unsqueeze(-1).expand(num_token_activations, dim).to(torch.int64), src=weighted_ordered_outs)
+            return final_out
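
For illustration (not from the commit): a hypothetical smoke test for the module above, assuming the two classes are in scope together with the imports the file implies (torch, torch.nn as nn, torch.nn.functional as F, Tensor). ToyConfig is made up and only carries the fields the module reads; a single token exercises the graph-break-free path.

from dataclasses import dataclass

import torch
import torch.nn as nn

@dataclass
class ToyConfig:
    dim: int = 64
    intermediate_size: int = 128
    num_experts: int = 8
    num_activated_experts: int = 2

config = ToyConfig()
moe = MOEFeedForwardAOQuantizable(config)
for p in moe.parameters():
    nn.init.normal_(p, std=0.02)   # w1/w2/w3 are created with torch.empty

x = torch.randn(1, 1, config.dim)  # [batch, seq, dim]; flattened to tokens internally
y = moe(x)
assert y.shape == (1, 1, config.dim)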

torchao/quantization/quant_api.py

Lines changed: 1 addition & 1 deletion
@@ -300,7 +300,7 @@ def _replace_with_custom_fn_if_matches_filter(
             device,
             extra_args,
         )
-        if new_child is not child:
+        if new_child is not child and new_child is not None:
             setattr(model, name, new_child)
         if device is not None:
             model.to(device=device)  # move parent module to device

torchao/quantization/utils.py

Lines changed: 7 additions & 6 deletions
@@ -366,22 +366,23 @@ def get_groupwise_affine_qparams(
 def pack_tinygemm_scales_and_zeros(scales, zeros, dtype=torch.bfloat16):
     guard_dtype_size(scales, "scales", dtype=dtype, size=zeros.size())
     guard_dtype_size(zeros, "zeros", dtype=dtype)
+    dim = scales.dim()
     return (
         torch.cat(
             [
-                scales.reshape(scales.size(0), scales.size(1), 1),
-                zeros.reshape(zeros.size(0), zeros.size(1), 1),
+                scales.unsqueeze(-1),
+                zeros.unsqueeze(-1),
             ],
-            2,
+            dim,
         )
-        .transpose(0, 1)
+        .transpose(-3, -2)
         .contiguous()
     )
 
 
 def unpack_tinygemm_scales_and_zeros(scales_and_zeros):
-    assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
-    return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
+    assert scales_and_zeros.shape[-1] == 2
+    return torch.split(scales_and_zeros.transpose(-3, -2), 1, -1)
 
 
 def convert_weight_to_int4pack_xpu(weight, zero_point_domain_is_int=False):
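
For illustration (not from the commit): a self-contained sketch of the now shape-agnostic pack/unpack behaviour above. Scales/zeros of shape (..., out_features, n_groups) pack to (..., n_groups, out_features, 2) and unpack back, for both the 2D case and the 3D stacked-expert case.

import torch

def pack(scales, zeros):
    # mirrors pack_tinygemm_scales_and_zeros without the dtype guards
    return (
        torch.cat([scales.unsqueeze(-1), zeros.unsqueeze(-1)], dim=scales.dim())
        .transpose(-3, -2)
        .contiguous()
    )

def unpack(scales_and_zeros):
    # mirrors unpack_tinygemm_scales_and_zeros
    assert scales_and_zeros.shape[-1] == 2
    return torch.split(scales_and_zeros.transpose(-3, -2), 1, -1)

for shape in [(32, 8), (4, 32, 8)]:      # 2D weight vs 3D stacked-expert weight
    s, z = torch.rand(shape), torch.rand(shape)
    s2, z2 = unpack(pack(s, z))
    assert torch.equal(s2.squeeze(-1), s) and torch.equal(z2.squeeze(-1), z)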
