@@ -272,18 +272,15 @@ def get_cutlass_build_flags():
272
272
raise ValueError ("No CUDA version found" )
273
273
274
274
major , minor = map (int , cuda_version .split ("." )[:2 ])
275
- build_sm90a = (major , minor ) >= (12 , 6 )
276
- build_sm100a = (major , minor ) >= (12 , 8 )
277
- build_sm120a = (major , minor ) >= (12 , 8 )
275
+ build_sm90a = major > 12 or (major == 12 and minor >= 6 )
276
+ build_sm100a = major > 12 or (major == 12 and minor >= 8 )
278
277
279
278
if build_sm90a :
280
279
print (f"CUDA { cuda_version } : Enabling SM90a CUTLASS kernels" )
281
280
if build_sm100a :
282
281
print (f"CUDA { cuda_version } : Enabling SM100a CUTLASS kernels" )
283
- if build_sm120a :
284
- print (f"CUDA { cuda_version } : Enabling SM120a CUTLASS kernels" )
285
282
286
- return build_sm90a , build_sm100a , build_sm120a
283
+ return build_sm90a , build_sm100a
287
284
except :
288
285
# Fallback to architecture flags
289
286
cuda_arch_flags = _get_cuda_arch_flags ()
@@ -343,11 +340,6 @@ def __init__(
343
340
self .cmake_args = cmake_args
344
341
345
342
346
- def remove_items (a : list , b : list ) -> list :
347
- """Remove items in list b from list a"""
348
- return [x for x in a if x not in b ]
349
-
350
-
351
343
def get_extensions ():
352
344
# Skip building C++ extensions if USE_CPP is set to "0"
353
345
if use_cpp == "0" :
@@ -462,7 +454,7 @@ def get_extensions():
462
454
excluded_sources = list (
463
455
glob .glob (os .path .join (extensions_dir , "cpu/*.cpp" ), recursive = True )
464
456
)
465
- sources = remove_items ( sources , excluded_sources )
457
+ sources = [ s for s in sources if s not in excluded_sources ]
466
458
467
459
# Collect CUDA source files
468
460
extensions_cuda_dir = os .path .join (extensions_dir , "cuda" )
@@ -506,24 +498,22 @@ def get_extensions():
506
498
rocm_sources = list (
507
499
glob .glob (os .path .join (extensions_rocm_dir , "**/*.cpp" ), recursive = True )
508
500
)
509
- sources = remove_items ( sources , rocm_sources )
501
+ sources = [ s for s in sources if s not in rocm_sources ]
510
502
511
- use_cutlass = use_cuda and not IS_WINDOWS
503
+ use_cutlass = False
512
504
cutlass_90a_sources = None
513
505
cutlass_100a_sources = None
514
- cutlass_120a_sources = None
515
506
build_for_sm90a = False
516
507
build_for_sm100a = False
517
- build_for_sm120a = False
518
-
519
- if use_cutlass :
508
+ if use_cuda and not IS_WINDOWS :
509
+ use_cutlass = True
520
510
cutlass_dir = os .path .join (third_party_path , "cutlass" )
521
511
cutlass_include_dir = os .path .join (cutlass_dir , "include" )
522
512
cutlass_tools_include_dir = os .path .join (
523
513
cutlass_dir , "tools" , "util" , "include"
524
514
)
525
515
cutlass_extensions_include_dir = os .path .join (cwd , extensions_cuda_dir )
526
-
516
+ if use_cutlass :
527
517
extra_compile_args ["nvcc" ].extend (
528
518
[
529
519
"-DTORCHAO_USE_CUTLASS" ,
@@ -543,7 +533,7 @@ def get_extensions():
543
533
]
544
534
)
545
535
546
- build_for_sm90a , build_for_sm100a , build_for_sm120a = get_cutlass_build_flags ()
536
+ build_for_sm90a , build_for_sm100a = get_cutlass_build_flags ()
547
537
# Define sm90a sources
548
538
cutlass_90a_sources = [
549
539
os .path .join (
@@ -567,40 +557,40 @@ def get_extensions():
567
557
"rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu" ,
568
558
)
569
559
)
570
- sources = remove_items (sources , cutlass_90a_sources )
560
+ # Always remove sm90a sources from main sources
561
+ sources = [s for s in sources if s not in cutlass_90a_sources ]
571
562
572
563
# Always compile mx_fp_cutlass_kernels.cu ONLY with sm100a architecture
573
564
cutlass_100a_sources = [
574
565
os .path .join (
575
566
extensions_cuda_dir ,
576
567
"mx_kernels" ,
577
- "mx_fp_cutlass_kernels_sm100a .cu" ,
568
+ "mx_fp_cutlass_kernels .cu" ,
578
569
),
579
570
]
580
- sources = remove_items (sources , cutlass_100a_sources )
581
-
582
- # Always compile mx_fp_cutlass_kernels.cu ONLY with sm120a architecture
583
- cutlass_120a_sources = [
584
- os .path .join (
585
- extensions_cuda_dir ,
586
- "mx_kernels" ,
587
- "mx_fp_cutlass_kernels_sm120a.cu" ,
588
- ),
571
+ # Remove from main sources to prevent compilation with other architectures
572
+ sources = [
573
+ s for s in sources if os .path .basename (s ) != "mx_fp_cutlass_kernels.cu"
589
574
]
590
- sources = remove_items (sources , cutlass_120a_sources )
591
575
592
576
else :
593
- # Remove CUTLASS-based kernels from the sources list. An assumption is that
594
- # these files will have "cutlass" in its name.
577
+ # Remove CUTLASS-based kernels from the sources list. An
578
+ # assumption is that each of these files has "cutlass" in its
579
+ # name.
595
580
cutlass_sources = list (
596
581
glob .glob (
597
582
os .path .join (extensions_cuda_dir , "**/*cutlass*.cu" ), recursive = True
598
583
)
599
584
)
600
- sources = remove_items ( sources , cutlass_sources )
585
+ sources = [ s for s in sources if s not in cutlass_sources ]
601
586
602
587
ext_modules = []
603
588
if len (sources ) > 0 :
589
+ # Double-check to ensure mx_fp_cutlass_kernels.cu is not in sources
590
+ sources = [
591
+ s for s in sources if os .path .basename (s ) != "mx_fp_cutlass_kernels.cu"
592
+ ]
593
+
604
594
ext_modules .append (
605
595
extension (
606
596
"torchao._C" ,
@@ -653,27 +643,6 @@ def get_extensions():
653
643
)
654
644
)
655
645
656
- # Only build the cutlass_120a extension if sm120a is in the architecture flags
657
- if (
658
- cutlass_120a_sources is not None
659
- and len (cutlass_120a_sources ) > 0
660
- and build_for_sm120a
661
- ):
662
- cutlass_120a_extra_compile_args = copy .deepcopy (extra_compile_args )
663
- # Only use sm120a architecture for these sources, ignoring cuda_arch_flags
664
- cutlass_120a_extra_compile_args ["nvcc" ].append (
665
- "-gencode=arch=compute_120a,code=sm_120a"
666
- )
667
- ext_modules .append (
668
- extension (
669
- "torchao._C_cutlass_120a" ,
670
- cutlass_120a_sources ,
671
- py_limited_api = True ,
672
- extra_compile_args = cutlass_120a_extra_compile_args ,
673
- extra_link_args = extra_link_args ,
674
- )
675
- )
676
-
677
646
# Build CMakeLists from /torchao/experimental - additional options become available : TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND
678
647
if build_macos_arm_auto or os .getenv ("BUILD_TORCHAO_EXPERIMENTAL" ) == "1" :
679
648
build_options = BuildOptions ()
0 commit comments