Skip to content

Commit 89ec74b

Browse files
committed
Fix generate.py quantization config selection and device handling; fix op membership check in float8_layout.py
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 16bc60d commit 89ec74b

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

torchao/_models/mixtral-moe/generate.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -271,10 +271,10 @@ def main(
271271
config = Int4WeightOnlyConfig()
272272

273273
elif "int4wo" in moe_quant:
274-
config = MoEQuantConfig(Float8WeightOnlyConfig())
274+
config = MoEQuantConfig(Int4WeightOnlyConfig())
275275

276276
elif "fp8wo-base" in moe_quant:
277-
config = Int4WeightOnlyConfig()
277+
config = Float8WeightOnlyConfig()
278278

279279
elif "fp8wo" in moe_quant:
280280
config = MoEQuantConfig(Float8WeightOnlyConfig())
@@ -297,7 +297,7 @@ def main(
297297
)
298298

299299
if config is not None:
300-
quantize_(model, config, filter_fn=cond_ffn_filter)
300+
quantize_(model, config, filter_fn=cond_ffn_filter, device=device)
301301
print(
302302
f"Time to apply quantization to model: {time.time() - t0:.02f} seconds"
303303
)
@@ -392,10 +392,10 @@ def callback(x):
392392
tokens_generated = y.size(-1) - prompt_length
393393
tokens_sec = tokens_generated / t
394394
aggregate_metrics["tokens_per_sec"].append(tokens_sec)
395-
print(
396-
f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
397-
)
398-
print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
395+
# print(
396+
# f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
397+
# )
398+
# print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
399399

400400
if i == 0 and device == "cuda" and memory_profile is not None:
401401
snapshot = torch.cuda.memory._snapshot()

torchao/dtypes/floatx/float8_layout.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -159,7 +159,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
159159
raise ValueError(
160160
f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
161161
)
162-
elif func in [aten.select.int, func is aten.index.Tensor]:
162+
elif func in [aten.select.int, aten.index.Tensor]:
163163
return return_and_correct_aliasing(
164164
func,
165165
args,

0 commit comments

Comments
 (0)