
[MoE training] torch.compile support for ScaledGroupedMMTensor #2509


Draft · wants to merge 1 commit into main
9 changes: 8 additions & 1 deletion test/prototype/moe_training/test_training.py
@@ -35,7 +35,10 @@
         ["does.not.exist"],
     ],
 )
-def test_moe_float8_training(target_fqns: list[str]):
+@pytest.mark.parametrize(
+    "compile", [False, True]
+)
+def test_moe_float8_training(target_fqns: list[str], compile: bool):
     model_args = TransformerModelArgs(
         moe_enabled=True,
         num_experts=8,
@@ -73,6 +76,10 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         target_fqns=target_fqns,
     )
 
+    if compile:
+        model = torch.compile(model, fullgraph=False)
+        ref_model = torch.compile(ref_model, fullgraph=False)
+
     # inputs
     batch, seq, dim = 8, 2048, 256
     ref_x = torch.randn(
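For context, here is a minimal sketch (not part of this PR) of the compile-vs-eager check pattern the updated test follows. The helper name check_compiled_matches_eager is hypothetical; the actual test builds a float8 MoE model and a bf16 reference model and compares losses, which is elided here.

import torch


def check_compiled_matches_eager(model: torch.nn.Module, x: torch.Tensor) -> None:
    """Hypothetical helper: run a module eagerly and under torch.compile, then compare."""
    ref_out = model(x)
    # fullgraph=False tolerates graph breaks while the prototype matures.
    compiled_model = torch.compile(model, fullgraph=False)
    out = compiled_model(x)
    torch.testing.assert_close(out, ref_out)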
10 changes: 7 additions & 3 deletions torchao/prototype/moe_training/kernels/jagged_float8_scales.py
@@ -42,7 +42,11 @@
     for block_size_cols in block_sizes
 ]
 
+from torch.library import triton_op, wrap_triton
+
+
 
+@triton_op("torchao::triton_fp8_row_major_jagged_rowwise_scales", mutates_args={})
 def triton_fp8_row_major_jagged_rowwise_scales(
     hp_tensor: torch.Tensor,
     offsets: torch.Tensor,
@@ -90,7 +94,7 @@ def triton_fp8_row_major_jagged_rowwise_scales(
         triton.cdiv(m, meta["BLOCK_SIZE_ROWS"]),
         offsets.numel(),
     )
-    _triton_fp8_row_major_jagged_rowwise_scales[grid](
+    wrap_triton(_triton_fp8_row_major_jagged_rowwise_scales)[grid](
         hp_tensor,
         offsets,
         output_buffer,
@@ -203,7 +207,7 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
     )
     tl.store(out_ptr + out_offs, fp8_data, mask=block_mask)
 
-
+@triton_op("torchao::triton_fp8_col_major_jagged_colwise_scales", mutates_args={})
 def triton_fp8_col_major_jagged_colwise_scales(
     hp_tensor: torch.Tensor,
     offsets: torch.Tensor,
@@ -251,7 +255,7 @@ def triton_fp8_col_major_jagged_colwise_scales(
         triton.cdiv(n, meta["BLOCK_SIZE_COLS"]),
         offsets.numel(),
     )
-    _triton_fp8_col_major_jagged_colwise_scales[grid](
+    wrap_triton(_triton_fp8_col_major_jagged_colwise_scales)[grid](
         hp_tensor,
         offsets,
         output_buffer,
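For readers unfamiliar with torch.library.triton_op, here is a minimal sketch (not part of this PR) of the wrapping pattern the two kernel launches above adopt. The toy kernel _add_one_kernel, the wrapper add_one, and the mylib::add_one op name are hypothetical; only the decorate-and-wrap shape mirrors the diff.

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


@triton.jit
def _add_one_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + 1.0, mask=mask)


@triton_op("mylib::add_one", mutates_args={})
def add_one(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    # wrap_triton lets torch.compile trace the kernel launch as a custom op
    # instead of graph-breaking on the raw Triton call.
    wrap_triton(_add_one_kernel)[grid](x, out, n_elements, BLOCK_SIZE=1024)
    return out

The diff follows the same shape: the Python wrappers are registered as custom ops under the torchao namespace via @triton_op, and the actual launches go through wrap_triton.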
2 changes: 0 additions & 2 deletions torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -40,8 +40,6 @@ def _scaled_grouped_mm(
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
     """
-    # TODO: Remove once prototype is more mature. This is currently very useful for development and debugging.
-    logger.info("Using scaled_grouped_mm")
     return _Float8GroupedMM.apply(
         A,
         B_t,
4 changes: 3 additions & 1 deletion torchao/prototype/moe_training/tensor.py
@@ -123,7 +123,9 @@ def __repr__(self):
         return f"ScaledGroupedMMTensor(data={self._data})"
 
     def __tensor_flatten__(self):
-        return ["_data"]
+        # Metadata is empty but needed to make the subclass traceable for torch.compile.
+        metadata = {}
+        return ["_data"], metadata
 
     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
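For context, a minimal sketch (not part of this PR) of the flatten/unflatten contract that makes a wrapper tensor subclass traceable under torch.compile. The WrapperTensor class below is hypothetical and omits the __torch_dispatch__ logic a real subclass such as ScaledGroupedMMTensor needs; only the two-element return from __tensor_flatten__ mirrors the change above.

import torch


class WrapperTensor(torch.Tensor):
    """Hypothetical wrapper subclass illustrating the flatten/unflatten protocol."""

    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data: torch.Tensor):
        self._data = data

    def __tensor_flatten__(self):
        # Return (names of inner tensor attributes, metadata). torch.compile
        # expects the two-element form even when the metadata dict is empty.
        return ["_data"], {}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        # Rebuild the subclass from its inner tensors; metadata is unused here.
        return WrapperTensor(inner_tensors["_data"])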