Skip to content

Commit c024f5d

Browse files
committed
Fix CI: add a CPU path for integer token counting (histc is CUDA-only for this case) and a CPU int8 weight-only MoE quantization test
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 59b3fab commit c024f5d

File tree

3 files changed

+40
-9
lines changed

3 files changed

+40
-9
lines changed

test/quantization/test_moe_quant.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,25 @@ def test_int8wo_base(self, name, num_tokens, fullgraph):
169169
fullgraph=fullgraph,
170170
)
171171

172+
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+")
173+
@parameterized.expand(
174+
[
175+
("single_token", 1, True),
176+
("multiple_tokens", 8, False),
177+
]
178+
)
179+
def test_int8wo_base_cpu(self, name, num_tokens, fullgraph):
180+
config = Int8WeightOnlyConfig()
181+
tensor_impl_class = PlainAQTTensorImpl
182+
183+
self._test_impl_moe_quant(
184+
config=config,
185+
num_tokens=num_tokens,
186+
tensor_impl_class=tensor_impl_class,
187+
fullgraph=fullgraph,
188+
device="cpu",
189+
)
190+
172191
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
173192
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "Test only enabled for 2.5+")
174193
@parameterized.expand(

torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -646,8 +646,8 @@ def test_moe_quant_intx(self):
646646
from torchao.quantization.utils import compute_error
647647

648648
with torch.device("cpu"):
649-
model = MOEFeedForwardAOQuantizable(512, 256, 8, 2).to(torch.bfloat16)
650-
x = torch.randn(1, 512, dtype=torch.bfloat16)
649+
model = MOEFeedForwardAOQuantizable(512, 256, 8, 2, empty_init=False).to(torch.bfloat16)
650+
x = torch.randn(8, 512, dtype=torch.bfloat16)
651651

652652
out = model(x).clone()
653653

@@ -661,7 +661,15 @@ def test_moe_quant_intx(self):
661661
out_q = model(x).clone()
662662
assert isinstance(model.experts.w1, FakeExtraDimTensor)
663663

664-
assert compute_error(out_q, out) > 30, "error bad accuracy but everything ran"
664+
mod_c = torch.compile(model, mode="reduce-overhead")
665+
666+
mod_c(x)
667+
mod_c(x)
668+
669+
out_qc = mod_c(x).clone()
670+
671+
self.assertGreater(compute_error(out_q, out), 30)
672+
self.assertGreater(compute_error(out_qc, out), 30)
665673

666674

667675
if __name__ == "__main__":

torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,16 @@ def forward(
120120
ordered_token_indices = (
121121
ordered_token_activations.div(top_k).floor().to(torch.int64)
122122
) # [T]
123-
num_tokens_per_expert = torch.histc(
124-
expert_indices,
125-
bins=self.num_experts + 1,
126-
min=-1,
127-
max=self.num_experts,
128-
) # [E+1] (added leading 0 so can be used for indexing)
123+
if not expert_indices.is_cuda: # histc doesn't work on cpu for integers
124+
num_tokens_per_expert = torch.bincount(expert_indices.view(-1)+1, minlength=self.num_experts+1)
125+
else:
126+
num_tokens_per_expert = torch.histc(
127+
expert_indices,
128+
bins=self.num_experts + 1,
129+
min=-1,
130+
max=self.num_experts,
131+
) # [E+1] (added leading 0 so can be used for indexing)
132+
# num_tokens_per_expert = torch.bincount(expert_indices.view(-1)+1, minlength=self.num_experts+1)
129133
cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to(
130134
torch.int64
131135
) # [E+1]

0 commit comments

Comments
 (0)