From 82129f3d0042e6da4bb3ca6f874227a783644698 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 13 May 2025 17:07:56 -0400 Subject: [PATCH 1/2] install cutlass python dsl --- README.md | 1 + install.py | 21 +++++++++++---------- submodules/cutlass | 2 +- tools/cutlass/install.py | 11 +++++++++++ 4 files changed, 24 insertions(+), 11 deletions(-) create mode 100644 tools/cutlass/install.py diff --git a/README.md b/README.md index 2210aa22..4133b587 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,7 @@ We depend on the following projects as a source of customized Triton or CUTLASS * (CUDA, HIP) [generative-recommenders](https://github.com/facebookresearch/generative-recommenders) * (CUDA, HIP) [Liger-Kernel](https://github.com/linkedin/Liger-Kernel) * (CUDA, HIP) [tilelang](https://github.com/tile-ai/tilelang) +* (CUDA) [CUTLASS Python DSL](https://github.com/NVIDIA/cutlass) * (CUDA) [xformers](https://github.com/facebookresearch/xformers) * (CUDA) [flash-attention](https://github.com/Dao-AILab/flash-attention) * (CUDA) [FBGEMM](https://github.com/pytorch/FBGEMM) diff --git a/install.py b/install.py index 922f8fde..2ae2bccf 100644 --- a/install.py +++ b/install.py @@ -14,6 +14,14 @@ pip_install_requirements, ) +from tools.flash_attn.install import install_fa3 +from tools.cutlass.install import install_cutlass +from tools.tk.install import install_tk +from tools.tilelang.install import install_tile +from tools.xformers.install import install_xformers +from tools.aiter.install import install_aiter + + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -157,8 +165,6 @@ def setup_hip(args: argparse.Namespace): if args.fa3 or args.all: # we need to install fa3 above all other dependencies logger.info("[tritonbench] installing fa3...") - from tools.flash_attn.install import install_fa3 - install_fa3() if args.fbgemm or args.fbgemm_all or args.all: logger.info("[tritonbench] installing FBGEMM...") @@ -172,25 +178,20 @@ def setup_hip(args: argparse.Namespace): install_jax() if args.tk or args.all: logger.info("[tritonbench] installing thunderkittens...") - from tools.tk.install import install_tk - install_tk() + if args.cutlass or args.all: + logger.info("[tritonbench] installing cutlass Python DSL...") + install_cutlass() if args.tile: logger.info("[tritonbench] installing tilelang...") - from tools.tilelang.install import install_tile - install_tile() if args.liger or args.all: logger.info("[tritonbench] installing liger-kernels...") install_liger() if args.xformers: logger.info("[tritonbench] installing xformers...") - from tools.xformers.install import install_xformers - install_xformers() if args.aiter and is_hip(): logger.info("[tritonbench] installing aiter...") - from tools.aiter.install import install_aiter - install_aiter() logger.info("[tritonbench] installation complete!") diff --git a/submodules/cutlass b/submodules/cutlass index ad7b2f5e..f115c3f8 160000 --- a/submodules/cutlass +++ b/submodules/cutlass @@ -1 +1 @@ -Subproject commit ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e +Subproject commit f115c3f85467d5d9619119d1dbeb9c03c3d73864 diff --git a/tools/cutlass/install.py b/tools/cutlass/install.py new file mode 100644 index 00000000..6cc80b52 --- /dev/null +++ b/tools/cutlass/install.py @@ -0,0 +1,11 @@ +import os +import subprocess + +from pathlib import Path + +REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent +CUTLASS_PATH = REPO_PATH.joinpath("submodules", "cutlass", "python") + +def install_cutlass(): + cmd = ["pip", "install", "-e", "."] + subprocess.check_call(cmd, cwd=str(CUTLASS_PATH.resolve())) From 1b390223be8eb0ea94b3a1521e988678afd7705e Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 13 May 2025 14:11:22 -0700 Subject: [PATCH 2/2] install cutlass dsl --- install.py | 3 +++ tools/cutlass/install.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/install.py b/install.py index 2ae2bccf..a7963c14 100644 --- a/install.py +++ b/install.py @@ -136,6 +136,9 @@ def setup_hip(args: argparse.Namespace): parser.add_argument( "--fa3", action="store_true", help="Install optional flash_attention 3 kernels" ) + parser.add_argument( + "--cutlass", action="store_true", help="Install optional CUTLASS Python DSL" + ) parser.add_argument("--jax", action="store_true", help="Install jax nightly") parser.add_argument("--tk", action="store_true", help="Install ThunderKittens") parser.add_argument("--liger", action="store_true", help="Install Liger-kernel") diff --git a/tools/cutlass/install.py b/tools/cutlass/install.py index 6cc80b52..cd13f423 100644 --- a/tools/cutlass/install.py +++ b/tools/cutlass/install.py @@ -4,7 +4,7 @@ from pathlib import Path REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent -CUTLASS_PATH = REPO_PATH.joinpath("submodules", "cutlass", "python") +CUTLASS_PATH = REPO_PATH.joinpath("submodules", "cutlass") def install_cutlass(): cmd = ["pip", "install", "-e", "."]