
Commit eb49ded

fixing CI
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent bd9f7d9 commit eb49ded

6 files changed (+33, -17 lines)

torchao/_models/mixtral-moe/README.md

Lines changed: 6 additions & 1 deletion
@@ -1,3 +1,8 @@
-This is the benchmarking setup primarily used for testing quantized moe. You can reproduce the above numbers by running
+## Mixtral-MoE
+
+This folder contains code and scripts for benchmarking the Mixtral-MoE model.
+
+Running
 
 `sh scripts/prepare.sh`
+
+should download the model and `sh run.sh` will run the benchmarks.

torchao/_models/mixtral-moe/generate.py

Lines changed: 9 additions & 5 deletions
@@ -208,7 +208,6 @@ def main(
     assert checkpoint_path.is_file(), checkpoint_path
     tokenizer_path = checkpoint_path.parent / "tokenizer.model"
     assert tokenizer_path.is_file(), str(tokenizer_path)
-    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
     print(f"Using device={device}")
     precision = torch.bfloat16
     is_chat = "chat" in str(checkpoint_path)
@@ -220,10 +219,10 @@ def main(
 
     print("Loading model ...")
     t0 = time.time()
-    model = _load_model(checkpoint_path, device, precision)
+    model = _load_model(checkpoint_path, "cpu", precision)
 
-    device_sync(device=device) # MKG
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
+    t0 = time.time()
 
     tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
     encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
@@ -299,7 +298,12 @@ def main(
 
     if config is not None:
         quantize_(model, config, filter_fn=cond_ffn_filter)
-        torch.cuda.reset_peak_memory_stats()
+        print(f"Time to apply quantization to model: {time.time() - t0:.02f} seconds")
+
+    model.to(device=device)
+    device_sync(device=device)
+
+    print(f"C: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
 
     if compile:
         # moe quant + compile causes repeated warnings
@@ -382,7 +386,7 @@ def callback(x):
 
         if not interactive:
            pass
-            print(tokenizer.decode(y[0].tolist()))
+            # print(tokenizer.decode(y[0].tolist()))
         else:
            print()
         tokens_generated = y.size(-1) - prompt_length
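In the change above, the checkpoint is loaded on CPU (`_load_model(checkpoint_path, "cpu", precision)`), quantized in place with `quantize_(model, config, filter_fn=cond_ffn_filter)`, and only then moved to the target device. Below is a minimal, self-contained sketch of that load-on-CPU, quantize, then move flow; the toy two-layer model, the `isinstance` filter, and `Int8WeightOnlyConfig` are assumptions standing in for the script's `_load_model`, `cond_ffn_filter`, and the selected `--moe_quant` config.

```python
import torch
import torch.nn as nn
from torchao.quantization import Int8WeightOnlyConfig, quantize_

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for _load_model(checkpoint_path, "cpu", precision): build/load the model
# on CPU so the full-precision weights never have to be resident on the GPU.
model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024))

# Quantize while still on CPU. The real script passes cond_ffn_filter so only the
# expert feed-forward modules are quantized; a plain isinstance check stands in here.
quantize_(model, Int8WeightOnlyConfig(), filter_fn=lambda m, fqn: isinstance(m, nn.Linear))

# Only now move the (already quantized) model to the target device and sync,
# mirroring model.to(device=device) / device_sync(device=device) in the diff.
model = model.to(device=device)
if device == "cuda":
    torch.cuda.synchronize()
    print(f"Peak GPU memory reserved: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
```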

torchao/_models/mixtral-moe/model.py

Lines changed: 11 additions & 3 deletions
@@ -395,9 +395,17 @@ def forward(
             .to(torch.int64)
         ) # [T]
 
-        num_tokens_per_expert = torch.histc(
-            expert_indices, bins=self.num_experts + 1, min=-1, max=self.num_experts
-        ) # [E+1] (added leading 0 so can be used for indexing)
+        if not expert_indices.is_cuda: # histc doesn't work on cpu for integers
+            num_tokens_per_expert = torch.bincount(
+                expert_indices.view(-1) + 1, minlength=self.num_experts + 1
+            )
+        else:
+            num_tokens_per_expert = torch.histc(
+                expert_indices,
+                bins=self.num_experts + 1,
+                min=-1,
+                max=self.num_experts,
+            ) # [E+1] (added leading 0 so can be used for indexing)
         cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to(
             torch.int64
         ) # [E+1]
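The fallback above exists because `torch.histc` rejects integer inputs on CPU; `torch.bincount` over the indices shifted by +1 produces the same `[E+1]` counts, with bin 0 kept as a leading zero so the cumulative sum can be used for indexing. A small sketch of that equivalence follows; the expert count and sample indices are made-up values, and the `.float()` cast on the CUDA path is only there to keep the sketch portable (the real code passes the integer tensor directly).

```python
import torch

num_experts = 8
# One expert id per (token, top-k slot); values lie in [0, num_experts).
expert_indices = torch.tensor([0, 2, 2, 5, 7, 0, 3])

# CPU path: shift ids by +1 so bin 0 stays an intentionally empty leading bin.
cpu_counts = torch.bincount(expert_indices.view(-1) + 1, minlength=num_experts + 1)

if torch.cuda.is_available():
    # CUDA path (the original histc formulation): num_experts + 1 bins over
    # [-1, num_experts] also maps expert id e to bin e + 1.
    cuda_counts = torch.histc(
        expert_indices.cuda().float(),
        bins=num_experts + 1,
        min=-1,
        max=num_experts,
    )
    assert torch.equal(cpu_counts.cuda(), cuda_counts.to(torch.int64))

# The cumulative sum gives per-expert offsets for grouping tokens by expert.
cum_tokens_per_expert = cpu_counts.cumsum(0).to(torch.int64)  # [E+1]
print(cpu_counts, cum_tokens_per_expert)
```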

torchao/_models/mixtral-moe/run.sh

Lines changed: 6 additions & 6 deletions
@@ -1,5 +1,5 @@
 export MODEL_REPO=mistralai/Mixtral-8x7B-Instruct-v0.1
-export CHECKPOINT_PATH=~/checkpoints/
+export CHECKPOINT_PATH=checkpoints/
 
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --compile
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --compile
@@ -16,11 +16,11 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --ba
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int4wo-base --compile
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int4wo-base --compile
 
-# EXPERT CHOICE
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8dq --compile
-# # # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8dq --compile
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8dq-base --compile
-# # # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8dq-base --compile
+# # EXPERT CHOICE
+# # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8dq --compile
+# # # # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8dq --compile
+# # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8dq-base --compile
+# # # # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant int8dq-base --compile
 
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant fp8wo --compile
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 8 --moe_quant fp8wo --compile
torchao/_models/mixtral-moe/scripts/prepare.sh

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 python scripts/download.py --repo_id mistralai/Mixtral-8x7B-Instruct-v0.1
-python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/mistralai/Mixtral-8x7B-v0.1
+python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/mistralai/Mixtral-8x7B-Instruct-v0.1

torchao/quantization/prototype/moe_quant/quantizable_moe_modules.py

Lines changed: 0 additions & 1 deletion
@@ -131,7 +131,6 @@ def forward(
             min=-1,
             max=self.num_experts,
         ) # [E+1] (added leading 0 so can be used for indexing)
-        # num_tokens_per_expert = torch.bincount(expert_indices.view(-1)+1, minlength=self.num_experts+1)
         cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to(
             torch.int64
         ) # [E+1]
