Skip to content

Commit c2660fd

Browse files
committed
fixing CI
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent f316a1c commit c2660fd

File tree

5 files changed

+13
-4
lines changed

5 files changed

+13
-4
lines changed

torchao/_models/mixtral-moe/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
This is the benchmarking setup primarily used for testing quantized MoE. You can reproduce the above numbers by running
2+
3+
`sh scripts/prepare.sh`

torchao/_models/mixtral-moe/scripts/download.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -
3737
parser.add_argument(
3838
"--repo_id",
3939
type=str,
40-
default="checkpoints/mistralai/Mixtral-8x7B-Instruct-v0.1",
40+
default="mistralai/Mixtral-8x7B-Instruct-v0.1",
4141
help="Repository ID to download from.",
4242
)
4343
parser.add_argument(
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
python scripts/download.py --repo_id mistralai/Mixtral-8x7B-Instruct-v0.1
2+
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/mistralai/Mixtral-8x7B-v0.1

torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,9 @@ def test_moe_quant_intx(self):
646646
from torchao.quantization.utils import compute_error
647647

648648
with torch.device("cpu"):
649-
model = MOEFeedForwardAOQuantizable(512, 256, 8, 2, empty_init=False).to(torch.bfloat16)
649+
model = MOEFeedForwardAOQuantizable(512, 256, 8, 2, empty_init=False).to(
650+
torch.bfloat16
651+
)
650652
x = torch.randn(8, 512, dtype=torch.bfloat16)
651653

652654
out = model(x).clone()

torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,10 @@ def forward(
120120
ordered_token_indices = (
121121
ordered_token_activations.div(top_k).floor().to(torch.int64)
122122
) # [T]
123-
if not expert_indices.is_cuda: # histc doesn't work on cpu for integers
124-
num_tokens_per_expert = torch.bincount(expert_indices.view(-1)+1, minlength=self.num_experts+1)
123+
if not expert_indices.is_cuda: # histc doesn't work on cpu for integers
124+
num_tokens_per_expert = torch.bincount(
125+
expert_indices.view(-1) + 1, minlength=self.num_experts + 1
126+
)
125127
else:
126128
num_tokens_per_expert = torch.histc(
127129
expert_indices,

0 commit comments

Comments
 (0)