@@ -271,10 +271,10 @@ def main(
             config = Int4WeightOnlyConfig()

         elif "int4wo" in moe_quant:
-            config = MoEQuantConfig(Float8WeightOnlyConfig())
+            config = MoEQuantConfig(Int4WeightOnlyConfig())

         elif "fp8wo-base" in moe_quant:
-            config = Int4WeightOnlyConfig()
+            config = Float8WeightOnlyConfig()

         elif "fp8wo" in moe_quant:
             config = MoEQuantConfig(Float8WeightOnlyConfig())
@@ -297,7 +297,7 @@ def main(
         )

         if config is not None:
-            quantize_(model, config, filter_fn=cond_ffn_filter)
+            quantize_(model, config, filter_fn=cond_ffn_filter, device=device)
             print(
                 f"Time to apply quantization to model: {time.time() - t0:.02f} seconds"
             )
@@ -392,10 +392,10 @@ def callback(x):
         tokens_generated = y.size(-1) - prompt_length
         tokens_sec = tokens_generated / t
         aggregate_metrics["tokens_per_sec"].append(tokens_sec)
-        print(
-            f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
-        )
-        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
+        # print(
+        #     f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
+        # )
+        # print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")

         if i == 0 and device == "cuda" and memory_profile is not None:
             snapshot = torch.cuda.memory._snapshot()
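
For context, a minimal sketch of how the corrected dispatch and the quantize_ call fit together after this patch. The import paths and the wrapper name apply_moe_quant are assumptions for illustration; only the changed branches appear in the hunks above, and the config class names and the device= keyword are taken from the diff itself.

# Minimal sketch, not the script's actual structure.
import time

from torchao.quantization import (
    Float8WeightOnlyConfig,
    Int4WeightOnlyConfig,
    quantize_,
)
# Assumed import path for the MoE quant prototype utilities.
from torchao.prototype.moe_quant.utils import MoEQuantConfig, cond_ffn_filter


def apply_moe_quant(model, moe_quant, device="cuda"):
    t0 = time.time()
    config = None
    # Substring checks: the "-base" variant must be tested before the
    # plain name, since "int4wo" is a substring of "int4wo-base".
    if "int4wo-base" in moe_quant:
        config = Int4WeightOnlyConfig()
    elif "int4wo" in moe_quant:
        # Bug fixed above: this branch previously built a Float8 config.
        config = MoEQuantConfig(Int4WeightOnlyConfig())
    elif "fp8wo-base" in moe_quant:
        # Bug fixed above: this branch previously built an Int4 config.
        config = Float8WeightOnlyConfig()
    elif "fp8wo" in moe_quant:
        config = MoEQuantConfig(Float8WeightOnlyConfig())

    if config is not None:
        # device= (added in the patch) tells quantize_ which device to place
        # modules on while quantizing, rather than quantizing wherever the
        # weights currently live.
        quantize_(model, config, filter_fn=cond_ffn_filter, device=device)
        print(f"Time to apply quantization to model: {time.time() - t0:.02f} seconds")
    return model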