From 567e4fb2e26bcaffffeabdb4675625db190fbf7b Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Tue, 8 Jul 2025 14:34:24 -0700
Subject: [PATCH] torch.compile support for ScaledGroupedMMTensor

---
 test/prototype/moe_training/test_training.py        |  9 ++++++++-
 .../moe_training/kernels/jagged_float8_scales.py    | 10 +++++++---
 torchao/prototype/moe_training/scaled_grouped_mm.py |  2 --
 torchao/prototype/moe_training/tensor.py            |  4 +++-
 4 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/test/prototype/moe_training/test_training.py b/test/prototype/moe_training/test_training.py
index 7087d1d571..5e1dce066b 100644
--- a/test/prototype/moe_training/test_training.py
+++ b/test/prototype/moe_training/test_training.py
@@ -35,7 +35,10 @@
         ["does.not.exist"],
     ],
 )
-def test_moe_float8_training(target_fqns: list[str]):
+@pytest.mark.parametrize(
+    "compile", [False, True]
+)
+def test_moe_float8_training(target_fqns: list[str], compile: bool):
     model_args = TransformerModelArgs(
         moe_enabled=True,
         num_experts=8,
@@ -73,6 +76,10 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         target_fqns=target_fqns,
     )
 
+    if compile:
+        model = torch.compile(model, fullgraph=False)
+        ref_model = torch.compile(ref_model, fullgraph=False)
+
     # inputs
     batch, seq, dim = 8, 2048, 256
     ref_x = torch.randn(
diff --git a/torchao/prototype/moe_training/kernels/jagged_float8_scales.py b/torchao/prototype/moe_training/kernels/jagged_float8_scales.py
index 3a497bf4a6..fe8817dd6f 100644
--- a/torchao/prototype/moe_training/kernels/jagged_float8_scales.py
+++ b/torchao/prototype/moe_training/kernels/jagged_float8_scales.py
@@ -42,7 +42,11 @@
     for block_size_cols in block_sizes
 ]
 
+from torch.library import triton_op, wrap_triton
+
+
+@triton_op("torchao::triton_fp8_row_major_jagged_rowwise_scales", mutates_args={})
 def triton_fp8_row_major_jagged_rowwise_scales(
     hp_tensor: torch.Tensor,
     offsets: torch.Tensor,
@@ -90,7 +94,7 @@ def triton_fp8_row_major_jagged_rowwise_scales(
         triton.cdiv(m, meta["BLOCK_SIZE_ROWS"]),
         offsets.numel(),
     )
-    _triton_fp8_row_major_jagged_rowwise_scales[grid](
+    wrap_triton(_triton_fp8_row_major_jagged_rowwise_scales)[grid](
         hp_tensor,
         offsets,
         output_buffer,
@@ -203,7 +207,7 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
     )
     tl.store(out_ptr + out_offs, fp8_data, mask=block_mask)
 
-
+@triton_op("torchao::triton_fp8_col_major_jagged_colwise_scales", mutates_args={})
 def triton_fp8_col_major_jagged_colwise_scales(
     hp_tensor: torch.Tensor,
     offsets: torch.Tensor,
@@ -251,7 +255,7 @@ def triton_fp8_col_major_jagged_colwise_scales(
         triton.cdiv(n, meta["BLOCK_SIZE_COLS"]),
         offsets.numel(),
     )
-    _triton_fp8_col_major_jagged_colwise_scales[grid](
+    wrap_triton(_triton_fp8_col_major_jagged_colwise_scales)[grid](
         hp_tensor,
         offsets,
         output_buffer,
diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index d9ccdcba03..2ec8ae40ba 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -40,8 +40,6 @@ def _scaled_grouped_mm(
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
     """
-    # TODO: Remove once prototype is more mature. This is currently very useful for development and debugging.
-    logger.info("Using scaled_grouped_mm")
     return _Float8GroupedMM.apply(
         A,
         B_t,
diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py
index d6fce479d4..f3f4a3ce00 100644
--- a/torchao/prototype/moe_training/tensor.py
+++ b/torchao/prototype/moe_training/tensor.py
@@ -123,7 +123,9 @@ def __repr__(self):
         return f"ScaledGroupedMMTensor(data={self._data})"
 
     def __tensor_flatten__(self):
-        return ["_data"]
+        # Metadata is empty but needed to make the subclass traceable for torch.compile.
+        metadata = {}
+        return ["_data"], metadata
 
     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):