Commit 63f2e51

[reland2][ROCm] preshuffled weight mm (#2207)
* [reland2][ROCm] preshuffled weight mm
* remove debug print statements
* remove duplicate registrations caused by patch fuzzing
* lint
* ruff
1 parent b0cfeec commit 63f2e51

File tree

6 files changed: +1173 -52 lines changed
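
Context for the change (not part of the commit): the new op is exposed as torch.ops.torchao.swizzle_mm and is exercised by the test added in test/test_ops.py below. A minimal sketch of a direct call, assuming a ROCm (gfx942) build of torchao with the new extension compiled; the shapes mirror the test and are illustrative only:

import torch
import torchao  # importing torchao loads the compiled extension and registers torch.ops.torchao.*

# Same operands as the new test; on a ROCm build the "cuda" device is the HIP GPU.
mat1 = torch.randint(0, 16, dtype=torch.float, size=(16, 32), device="cuda")
mat2 = torch.randint(0, 16, dtype=torch.float, size=(32, 16), device="cuda")

# The two trailing booleans mirror the (mat1, mat2, False, False) call in the test;
# they appear to indicate whether each operand is already preshuffled (swizzled).
out = torch.ops.torchao.swizzle_mm(mat1, mat2, False, False)
print(out.shape)  # expected: torch.Size([16, 16])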

setup.py

Lines changed: 75 additions & 40 deletions
@@ -83,8 +83,6 @@ def use_debug_mode():
     _get_cuda_arch_flags,
 )
 
-IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)
-
 
 class BuildOptions:
     def __init__(self):
@@ -280,30 +278,35 @@ def get_extensions():
     if debug_mode:
         print("Compiling in debug mode")
 
-    if not torch.version.cuda:
-        print(
-            "PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
-        )
-    if (CUDA_HOME is None and ROCM_HOME is None) and torch.version.cuda:
-        print(
-            "CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions"
-        )
+    if CUDA_HOME is None and torch.version.cuda:
+        print("CUDA toolkit is not available. Skipping compilation of CUDA extensions")
         print(
             "If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
         )
-
-    use_cuda = torch.version.cuda and (CUDA_HOME is not None or ROCM_HOME is not None)
-    extension = CUDAExtension if use_cuda else CppExtension
+    if ROCM_HOME is None and torch.version.hip:
+        print("ROCm is not available. Skipping compilation of ROCm extensions")
+        print("If you'd like to compile ROCm extensions locally please install ROCm")
+
+    use_cuda = torch.version.cuda and CUDA_HOME is not None
+    use_rocm = torch.version.hip and ROCM_HOME is not None
+    extension = CUDAExtension if (use_cuda or use_rocm) else CppExtension
+
+    nvcc_args = [
+        "-DNDEBUG" if not debug_mode else "-DDEBUG",
+        "-O3" if not debug_mode else "-O0",
+        "-t=0",
+        "-std=c++17",
+    ]
+    rocm_args = [
+        "-DNDEBUG" if not debug_mode else "-DDEBUG",
+        "-O3" if not debug_mode else "-O0",
+        "-std=c++17",
+    ]
 
     extra_link_args = []
     extra_compile_args = {
         "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
-        "nvcc": [
-            "-DNDEBUG" if not debug_mode else "-DDEBUG",
-            "-O3" if not debug_mode else "-O0",
-            "-t=0",
-            "-std=c++17",
-        ],
+        "nvcc": nvcc_args if use_cuda else rocm_args,
     }
 
     if not IS_WINDOWS:
@@ -341,6 +344,34 @@ def get_extensions():
             extra_compile_args["nvcc"].append("-g")
             extra_link_args.append("/DEBUG")
 
+    rocm_sparse_marlin_supported = False
+    if use_rocm:
+        # naive search for hipblalst.h, if any found contain HIPBLASLT_ORDER_COL16 and VEC_EXT
+        found_col16 = False
+        found_vec_ext = False
+        print("ROCM_HOME", ROCM_HOME)
+        hipblaslt_headers = list(
+            glob.glob(os.path.join(ROCM_HOME, "include", "hipblaslt", "hipblaslt.h"))
+        )
+        print("hipblaslt_headers", hipblaslt_headers)
+        for header in hipblaslt_headers:
+            with open(header) as f:
+                text = f.read()
+                if "HIPBLASLT_ORDER_COL16" in text:
+                    found_col16 = True
+                if "HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT" in text:
+                    found_vec_ext = True
+        if found_col16:
+            extra_compile_args["cxx"].append("-DHIPBLASLT_HAS_ORDER_COL16")
+            print("hipblaslt found extended col order enums")
+        else:
+            print("hipblaslt does not have extended col order enums")
+        if found_vec_ext:
+            extra_compile_args["cxx"].append("-DHIPBLASLT_VEC_EXT")
+            print("hipblaslt found vec ext")
+        else:
+            print("hipblaslt does not have vec ext")
+
     # Get base directory and source paths
     curdir = os.path.dirname(os.path.curdir)
     extensions_dir = os.path.join(curdir, "torchao", "csrc")
@@ -354,42 +385,46 @@ def get_extensions():
     )
     sources = [s for s in sources if s not in excluded_sources]
 
+    # Collect CUDA source files
     extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
     cuda_sources = list(
         glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
     )
 
-    # Define HIP source directories
-    hip_source_dirs = [
+    # Define ROCm source directories
+    rocm_source_dirs = [
+        os.path.join(extensions_dir, "rocm", "swizzle"),
         os.path.join(extensions_dir, "cuda", "tensor_core_tiled_layout"),
-        # TODO: Add sparse_marlin back in once we have a ROCm build for it
-        # os.path.join(extensions_dir, "cuda", "sparse_marlin")
     ]
-
-    # Collect all HIP sources from the defined directories
-    hip_sources = []
-    for hip_dir in hip_source_dirs:
-        hip_sources.extend(glob.glob(os.path.join(hip_dir, "*.cu"), recursive=True))
-
-    # Collect CUDA source files if needed
-    if not IS_ROCM and use_cuda:
+    if rocm_sparse_marlin_supported:
+        rocm_source_dirs.extend([os.path.join(extensions_dir, "cuda", "sparse_marlin")])
+
+    # Collect all ROCm sources from the defined directories
+    rocm_sources = []
+    for rocm_dir in rocm_source_dirs:
+        rocm_sources.extend(glob.glob(os.path.join(rocm_dir, "*.cu"), recursive=True))
+        rocm_sources.extend(glob.glob(os.path.join(rocm_dir, "*.hip"), recursive=True))
+        rocm_sources.extend(glob.glob(os.path.join(rocm_dir, "*.cpp"), recursive=True))
+
+    # Add CUDA source files if needed
+    if use_cuda:
         sources += cuda_sources
 
     # TOOD: Remove this and use what CUDA has once we fix all the builds.
-    if IS_ROCM and use_cuda:
+    if use_rocm:
         # Add ROCm GPU architecture check
-        gpu_arch = torch.cuda.get_device_properties(0).name
-        if gpu_arch != "gfx942":
+        gpu_arch = None
+        if torch.cuda.is_available():
+            gpu_arch = torch.cuda.get_device_properties(0).name
+        if gpu_arch and gpu_arch != "gfx942":
             print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
-            print(
-                "Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
-            )
-        else:
-            sources += hip_sources
+            print("Currently only gfx942 is supported. Compiling only for gfx942.")
+        extra_compile_args["nvcc"].append("--offload-arch=gfx942")
+        sources += rocm_sources
 
     use_cutlass = False
     cutlass_90a_sources = None
-    if use_cuda and not IS_ROCM and not IS_WINDOWS:
+    if use_cuda and not IS_WINDOWS:
         use_cutlass = True
         cutlass_dir = os.path.join(third_party_path, "cutlass")
         cutlass_include_dir = os.path.join(cutlass_dir, "include")
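
To see which build path the new logic in get_extensions() will take on a given machine, the gating conditions above can be evaluated standalone. This is a small sketch (not part of the commit) that reuses the same torch.utils.cpp_extension constants as setup.py:

import torch
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME

# Mirrors the commit's gating: CUDA needs torch.version.cuda plus CUDA_HOME,
# ROCm needs torch.version.hip plus ROCM_HOME.
use_cuda = torch.version.cuda is not None and CUDA_HOME is not None
use_rocm = torch.version.hip is not None and ROCM_HOME is not None
print("use_cuda:", use_cuda, "use_rocm:", use_rocm)

if use_rocm and torch.cuda.is_available():
    # Same architecture probe as above; the commit compiles ROCm sources only for gfx942.
    print("GPU architecture:", torch.cuda.get_device_properties(0).name)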

test/test_ops.py

Lines changed: 33 additions & 11 deletions
@@ -31,8 +31,8 @@
     compute_max_diff,
 )
 
-if torch.version.hip is not None:
-    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
+IS_CUDA = torch.cuda.is_available() and torch.version.cuda
+IS_ROCM = torch.cuda.is_available() and torch.version.hip
 
 try:
     import torchao.ops
@@ -58,7 +58,7 @@ def _create_floatx_inputs(
         fp16_act = torch.rand(BS, IC).to(dtype) + 0.5
         return floatx_weight.to(device), scale.to(device), fp16_act.to(device)
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
     @parametrize("ebits,mbits", [(3, 2), (2, 2)])
     @parametrize("dtype", [torch.half, torch.bfloat16])
     def test_quant_llm_linear(self, ebits, mbits, dtype):
@@ -88,7 +88,7 @@ def test_quant_llm_linear(self, ebits, mbits, dtype):
             test_utils=test_utils,
         )
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
     @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
     @parametrize("ebits,mbits", [(3, 2), (2, 2)])
     @parametrize("dtype", [torch.half, torch.bfloat16])
@@ -278,7 +278,7 @@ def make_test_id(param):
     return f"tiles_{param}"
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id)
 def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
@@ -296,7 +296,7 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
 
 
 # TODO: Fix "test_aot_dispatch_dynamic" test failure
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id)
 def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
@@ -342,7 +342,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
     return dq.reshape(n, k)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize(
     "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -410,7 +410,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(
 
 
 # This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize(
     "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -476,7 +476,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(
     assert diff_op_ao < 1e-1
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize(
     "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -587,7 +587,7 @@ def reshape_w(w):
     )
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
 @pytest.mark.parametrize(
     "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors",
     MARLIN_TEST_PARAMS,
@@ -677,7 +677,7 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto
     )
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
 @pytest.mark.parametrize(
     "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors",
     MARLIN_TEST_PARAMS,
@@ -756,5 +756,27 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact
     )
 
 
+@pytest.mark.skipif(not IS_ROCM, reason="ROCm not available")
+def test_swizzle_mm():
+    test_utils = [
+        "test_schema",
+        "test_autograd_registration",
+        "test_faketensor",
+    ]
+
+    # TODO: Figure out why test fails unless torch >= 2.5
+    if TORCH_VERSION_AT_LEAST_2_5:
+        test_utils.append("test_aot_dispatch_dynamic")
+
+    mat1 = torch.randint(0, 16, dtype=torch.float, size=(16, 32), device="cuda")
+    mat2 = torch.randint(0, 16, dtype=torch.float, size=(32, 16), device="cuda")
+
+    opcheck(
+        torch.ops.torchao.swizzle_mm,
+        (mat1, mat2, False, False),
+        test_utils=test_utils,
+    )
+
+
 if __name__ == "__main__":
     pytest.main(sys.argv)
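
To run only the new ROCm test, an invocation along these lines should work (a suggested command, not part of the commit; assumes a ROCm build of torchao and mirrors the file's own pytest.main entry point):

# From the repository root:
import pytest

pytest.main(["test/test_ops.py", "-k", "test_swizzle_mm", "-v"])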

torchao/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -43,13 +43,14 @@
     quantize_,
 )
 
-from . import dtypes, optim, quantization, testing
+from . import dtypes, optim, quantization, swizzle, testing
 
 __all__ = [
     "dtypes",
     "autoquant",
     "optim",
     "quantize_",
+    "swizzle",
     "testing",
     "ops",
     "quantization",
