Commit 3ca7041

Fix remaining flake8 and lint in titan repo (#1382)
`pre-commit run --all-files` is failing due to a few files. This cleans them up so the entire repo passes lint checks.
1 parent 5375abb commit 3ca7041

6 files changed: +24 -42 lines

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ We actively welcome your pull requests.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
-5. Make sure your code lints (`pre-commit run --files $(git diff --name-only HEAD~1)`).
+5. Make sure your code lints (`pre-commit run --all-files`).
6. If you haven't already, complete the Contributor License Agreement ("CLA").

### Contributor License Agreement ("CLA")

torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py

Lines changed: 0 additions & 1 deletion
@@ -7,7 +7,6 @@
# pyre-unsafe
import logging

-import numpy as np
import torch

from reference_utils import (

torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py

Lines changed: 7 additions & 11 deletions
@@ -8,26 +8,21 @@
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm

# pyre-unsafe
-import functools
import logging

import os
import sys
-from typing import Any, Dict, Optional, Tuple
+from typing import Tuple

import torch

import triton
import triton.language as tl
-from triton import Config as TConfig
-
-from triton.runtime import driver  # @manual

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from tma_autotuning import (
    _NV_CONFIGS,
-    ALIGN_SIZE_M,
    CudaUtils,
    early_config_prune,
    TmaDescriptorHelper,
@@ -727,6 +722,7 @@ def grouped_gemm_forward(
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    tma_size: int = 128,
+    using_fp8: bool = False,
) -> torch.Tensor:
    """
    M*G style grouped GEMM with TMA and Float8 support.
@@ -892,7 +888,7 @@ def grouped_gemm_backward(

    # Compute grad_x using flat linear implementation
    try:
-        logging.info(f"Computing grad_x with flat linear kernel")
+        logging.info("Computing grad_x with flat linear kernel")

        # Use TMA-optimized implementation
        grad_x = grouped_gemm_dx_tma(
@@ -909,7 +905,7 @@ def grouped_gemm_backward(

    # Compute grad_w using flat linear style implementation
    try:
-        logging.info(f"Computing grad_w with flat linear kernel")
+        logging.info("Computing grad_w with flat linear kernel")

        grad_w = grouped_gemm_dw_tma(
            x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size
@@ -1203,14 +1199,14 @@ def grid(META):
# ======== PyTorch wrapper functions ========


-class GroupedGEMM_mg(torch.autograd.Function):
+class GroupedGemmMg(torch.autograd.Function):
    """
    Autograd function for GroupedGEMM with M*G grouping.
    Supports both standard and FP8 quantized operations.
    """

    @staticmethod
-    def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128):
+    def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128, using_fp8=False):
        """
        Forward pass of GroupedGEMM.

@@ -1301,4 +1297,4 @@ def mg_grouped_gemm(
    Returns:
        Output tensor, shape [M_total, N]
    """
-    return GroupedGEMM_mg.apply(x, w, m_sizes, use_tma, tma_size, using_fp8)
+    return GroupedGemmMg.apply(x, w, m_sizes, use_tma, tma_size, using_fp8)
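The only functional change above is the new `using_fp8` flag threaded through `grouped_gemm_forward`, `GroupedGemmMg.forward`, and the `mg_grouped_gemm` wrapper; the rest is import pruning, f-string fixes, and the CamelCase rename. As a rough sketch of how the public entry point might be called after this commit (the `[N * G, K]` weight layout, the group-major row ordering of `x`, and the keyword defaults are assumptions, not taken from this diff):

```python
# Hedged sketch, not taken from the repo: shapes and the [N * G, K] weight
# layout are assumptions; only the signature and the GroupedGemmMg.apply(...)
# routing are visible in the diff above.
import torch

from mg_grouped_gemm import mg_grouped_gemm

G, N, K = 4, 1024, 2048
# One row count per group; rows of x are assumed to be laid out group-by-group.
m_sizes = torch.tensor([512, 512, 256, 256], dtype=torch.int32, device="cuda")
M_total = int(m_sizes.sum())

x = torch.randn(M_total, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(N * G, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Internally this routes through GroupedGemmMg.apply(x, w, m_sizes, use_tma,
# tma_size, using_fp8), so leaving using_fp8=False keeps the BF16 path.
out = mg_grouped_gemm(x, w, m_sizes, use_tma=True, tma_size=128, using_fp8=False)
assert out.shape == (M_total, N)  # docstring: output is [M_total, N]
out.sum().backward()
```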

torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py

Lines changed: 6 additions & 9 deletions
@@ -8,17 +8,15 @@
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm

# pyre-unsafe
-import functools

import os
import sys
-from typing import Any, Dict, Optional, Tuple
+from typing import Dict

import torch

import triton
import triton.language as tl
-from triton import Config as TConfig

from triton.runtime import driver  # @manual

@@ -54,7 +52,11 @@ def get_num_sms() -> int:


class TmaDescriptorHelper:
-    """Helper class for managing TMA descriptors in Triton kernels."""
+    """Helper class for managing TMA descriptors in Triton kernels.
+
+    Args:
+        tma_size: Size of the TMA descriptor in bytes
+    """

    class KernelParamWrapper:
        """Wrapper to implement the TmaDescKernelParam interface."""
@@ -67,11 +69,6 @@ def tma_desc_cpu_ptr(self) -> int:
            return self.desc.data_ptr()

    def __init__(self, tma_size: int = 128):
-        """Initialize the TMA descriptor helper.
-
-        Args:
-            tma_size: Size of the TMA descriptor in bytes
-        """
        if not CudaUtils.verify_tma():
            raise RuntimeError(
                "TMA not supported on this device (requires Hopper or newer)"

torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_backwards.py

Lines changed: 7 additions & 15 deletions
@@ -5,20 +5,12 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
-import logging
import unittest
from typing import Tuple

import torch
-import torch.nn as nn
-
-from mg_grouped_gemm import (
-    grouped_gemm_backward,
-    grouped_gemm_dw_tma,
-    grouped_gemm_dx_tma,
-    grouped_gemm_forward,
-    mg_grouped_gemm,
-)
+
+from mg_grouped_gemm import grouped_gemm_backward, grouped_gemm_forward

from reference_utils import (
    analyze_tensor_differences,
@@ -27,7 +19,7 @@
)


-class TestMG_GroupedGEMM_Backward(unittest.TestCase):
+class TestMgGroupedGemmBackward(unittest.TestCase):
    def setUp(self) -> None:
        torch.manual_seed(2020)  # Set seed for reproducibility

@@ -81,7 +73,7 @@ def _run_grouped_gemm_backward_test(
        self.assertTrue(grad_a_close)
        self.assertTrue(grad_b_close)

-    def test_MG_grouped_gemm_backward_bf16(self) -> None:
+    def test_mg_grouped_gemm_backward_bf16(self) -> None:
        for G in (1, 8, 16):
            for M in (512, 1024):
                print(f"Testing BF16 M*G GroupGeMM Backward with G={G}, M={M}")
@@ -93,7 +85,7 @@ def test_MG_grouped_gemm_backward_bf16(self) -> None:
                rtol=1e-2,
            )

-    def test_MG_grouped_gemm_backward_deepseek_shapes(self) -> None:
+    def test_mg_grouped_gemm_backward_deepseek_shapes(self) -> None:
        """Test backward pass with shapes from Deepseek model."""
        deepseek_shapes = [
            (4, 2048, 4096, 7168),  # G, M, N, K
@@ -113,7 +105,7 @@ def test_MG_grouped_gemm_backward_deepseek_shapes(self) -> None:
                shape, device, dtype=torch.float16, atol=1e-2, rtol=1e-2
            )

-    def test_MG_dx(self) -> None:
+    def test_mg_dx(self) -> None:
        """Test specifically the dx (gradient w.r.t. input) computation."""
        G, M, N, K = 4, 512, 1024, 2048
        device = torch.device("cuda")
@@ -143,7 +135,7 @@ def test_MG_dx(self) -> None:
        dx_close = analyze_tensor_differences(grad_a, expected_grad_a, "grad_a (dx)")
        self.assertTrue(dx_close)

-    def test_MG_dw(self) -> None:
+    def test_mg_dw(self) -> None:
        """Test specifically the dw (gradient w.r.t. weights) computation."""
        G, M, N, K = 4, 512, 1024, 2048
        device = torch.device("cuda")
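The renames above are purely stylistic (CamelCase class, snake_case methods); unittest discovery is unaffected because `TestLoader` matches methods on the lowercase `test` prefix. A placeholder sketch of the pattern, with a trivial body standing in for the real `grouped_gemm_forward`/`grouped_gemm_backward` checks:

```python
# Placeholder sketch: snake_case test_* methods are still collected by
# unittest discovery; the real tests exercise the grouped GEMM kernels here.
import unittest


class TestNamingConvention(unittest.TestCase):
    def test_mg_style_snake_case_method(self) -> None:
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()
```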

torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py

Lines changed: 3 additions & 5 deletions
@@ -5,17 +5,15 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
-import logging
import unittest
from typing import Tuple

import torch
-import torch.nn as nn

from mg_grouped_gemm import grouped_gemm_forward


-class TestMG_GroupedGEMM(unittest.TestCase):
+class TestMgGroupedGemm(unittest.TestCase):
    def setUp(self) -> None:
        torch.manual_seed(2020)

@@ -51,7 +49,7 @@ def _run_grouped_gemm_test(
        result = result.to(dtype)
        torch.testing.assert_close(result, expected_result, atol=atol, rtol=rtol)

-    def test_MG_grouped_gemm_bf16(self) -> None:
+    def test_mg_grouped_gemm_bf16(self) -> None:
        for G in (1, 4, 16):
            for M in (128, 512, 1024):
                print(f"Testing BF16 M*G GroupGeMM with G={G}, M={M}")
@@ -63,7 +61,7 @@ def test_MG_grouped_gemm_bf16(self) -> None:
                rtol=1.6e-2,
            )

-    def test_MG_grouped_gemm_deepseek_shapes(self) -> None:
+    def test_mg_grouped_gemm_deepseek_shapes(self) -> None:
        """Test with shapes from Deepseek model."""
        deepseek_shapes = [
            (4, 2048, 4096, 7168),  # G, M, N, K
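For orientation, `_run_grouped_gemm_test` compares the kernel's `result` against an `expected_result` via `torch.testing.assert_close`. A rough sketch of a loop-based reference that such an expected result could come from is below; the `[N * G, K]` weight layout and the group-major row ordering of `x` are assumptions, not details taken from this diff:

```python
# Hedged sketch of a per-group reference GEMM; layouts are assumptions.
import torch


def reference_grouped_gemm(x: torch.Tensor, w: torch.Tensor, m_sizes: torch.Tensor) -> torch.Tensor:
    """Loop-based reference: one plain matmul per group, concatenated over rows."""
    # Assumed layouts: x is [M_total, K] with rows grouped contiguously,
    # w is [N * G, K], m_sizes holds the per-group row counts.
    G = m_sizes.numel()
    N = w.shape[0] // G
    outputs, start = [], 0
    for g in range(G):
        rows = int(m_sizes[g])
        x_g = x[start : start + rows]   # [rows, K]
        w_g = w[g * N : (g + 1) * N]    # [N, K]
        outputs.append(x_g @ w_g.t())   # [rows, N]
        start += rows
    return torch.cat(outputs, dim=0)    # [M_total, N]
```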
