
Commit 65829a2

cthi authored and facebook-github-bot committed
Fix scaled_mm_rowwise in quantize_bench (pytorch#4551)
Summary:
X-link: facebookresearch/FBGEMM#1594

When looking at this with samanamp recently, I noticed that scaled_mm was not actually using rowwise scaling; I think this dates from before it was properly supported. We also add support for torch.compile, which will be useful for testing.

Reviewed By: jianyuh

Differential Revision: D78844879
1 parent a7774fe commit 65829a2
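
For context, here is a minimal sketch of the rowwise-scaling contract that torch._scaled_mm expects, assuming a recent PyTorch build and an fp8-capable GPU; the tensor names and sizes are illustrative and not taken from this change. The key point is that scale_a carries one scale per row of the left operand and scale_b one scale per column of the right operand, instead of a single dummy scalar.

import torch

M, N, K = 16, 32, 64
# fp8 operands: A is row-major (M, K); B must be column-major, hence the transpose of an (N, K) tensor.
xq = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
wq = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
# Rowwise scales in float32: (M, 1) for A, (1, N) for B.
x_scale = torch.rand(M, 1, device="cuda", dtype=torch.float32)
w_scale = torch.rand(1, N, device="cuda", dtype=torch.float32)
out = torch._scaled_mm(
    xq,
    wq.t(),
    scale_a=x_scale,
    scale_b=w_scale,
    out_dtype=torch.bfloat16,
)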

2 files changed: +30 −9 lines changed


fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py

Lines changed: 13 additions & 0 deletions
@@ -169,6 +169,7 @@ def benchmark_grouped(
     trace: bool = False,
     num_iters: int = 1,
     fast_accum: bool = True,
+    torch_compile: bool = False,
 ) -> Dict[str, Any]:
     num_groups = len(m)
     # Create input tensors.
@@ -197,6 +198,8 @@ def benchmark_grouped(
     # Set fast accum mode if applicable.
     if hasattr(quantize_op, "fast_accum"):
         quantize_op.fast_accum = fast_accum
+    if hasattr(quantize_op, "torch_compile"):
+        quantize_op.torch_compile = torch_compile
     # Get the quantized tensors for this operator.
     preprocessed_args = quantize_op.preprocess(A, B)
     quantized_vals = quantize_op.quantize(*preprocessed_args)
@@ -282,6 +285,7 @@ def benchmark(
     trace: bool = False,
     num_iters: int = 1,
     fast_accum: bool = True,
+    torch_compile: bool = False,
 ) -> Dict[str, Any]:
     # Create input tensors.
     if b > 1:
@@ -301,6 +305,8 @@ def benchmark(
     # Set fast accum mode if applicable.
     if hasattr(quantize_op, "fast_accum"):
         quantize_op.fast_accum = fast_accum
+    if hasattr(quantize_op, "torch_compile"):
+        quantize_op.torch_compile = torch_compile
     # Preprocess data if needed.
     preprocessed_args = quantize_op.preprocess(A, B)
     # Get the quantized tensors for this operator.
@@ -495,6 +501,7 @@ def main(args: Any):
                 args.trace,
                 args.num_iters,
                 not args.disable_fast_accum,
+                args.torch_compile,
             )
             benchmark_results.append(quantize_measurements)
             if args.export_csv or args.plot:
@@ -625,6 +632,12 @@ def invoke_main() -> None:
         action="store_true",
         help="If set, disable fast accumulation for FP8 implementations.",
     )
+    parser.add_argument(
+        "--torch_compile",
+        default=False,
+        action="store_true",
+        help="If set, torch.compile will be used for scaled_mm backed ops.",
+    )
 
     args = parser.parse_args()
     main(args)
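
To make the plumbing above concrete, here is a small self-contained sketch of the pattern the harness uses to forward the new flag; QuantizeOp below is a stand-in written for illustration, not the real QuantizeOpBase.

import argparse


class QuantizeOp:
    # Stand-in for a quantize op that opts into both knobs.
    def __init__(self):
        self.fast_accum = True
        self.torch_compile = False


parser = argparse.ArgumentParser()
parser.add_argument(
    "--torch_compile",
    default=False,
    action="store_true",
    help="If set, torch.compile will be used for scaled_mm backed ops.",
)
args = parser.parse_args(["--torch_compile"])  # simulates passing the flag

quantize_op = QuantizeOp()
# Mirrors the existing fast_accum handling: only ops exposing the attribute opt in.
if hasattr(quantize_op, "torch_compile"):
    quantize_op.torch_compile = args.torch_compile
print(quantize_op.torch_compile)  # True

On the command line the new option is passed as --torch_compile alongside the existing flags such as --disable_fast_accum.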

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 17 additions & 9 deletions
@@ -345,27 +345,35 @@ def cuda(self) -> bool:
 class ScaledMMRowwise(QuantizeOpBase):
     def __init__(self):
         self.fast_accum = True
+        self.torch_compile = False
 
     def quantize(self, x, w):
         xq, x_scale = quantize_fp8_row(x)
         wq, w_scale = quantize_fp8_row(w)
-        dummy_scale = torch.tensor([1.0], device=x.device, dtype=torch.float32)
-        return xq, wq.t(), x_scale, w_scale, dummy_scale
+        return xq, wq.t(), x_scale.unsqueeze(1), w_scale.unsqueeze(0)
 
-    def compute(self, xq, wq, x_scale, w_scale, dummy_scale):
-        output = torch._scaled_mm(
+    def compute(self, xq, wq, x_scale, w_scale):
+        if self.torch_compile:
+            f = torch.compile(
+                torch._scaled_mm,
+                options={
+                    "max_autotune": True,
+                    "max_autotune_gemm_backends": "TRITON,CK,CUTLASS,ATEN",
+                },
+            )
+        else:
+            f = torch._scaled_mm
+
+        return f(
             xq,
             wq,
             bias=None,
             out_dtype=torch.bfloat16,
-            scale_a=dummy_scale,
-            scale_b=dummy_scale,
+            scale_a=x_scale,
+            scale_b=w_scale,
             scale_result=None,
             use_fast_accum=self.fast_accum,
         )
-        # Apply separate rowwise scaling.
-        output = scale_fp8_row(output, x_scale, w_scale)
-        return output
 
     def quantize_and_compute(self, x, w):
         return self.compute(*self.quantize(x, w))