diff --git a/README.md b/README.md index 2210aa22..4133b587 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,7 @@ We depend on the following projects as a source of customized Triton or CUTLASS * (CUDA, HIP) [generative-recommenders](https://github.com/facebookresearch/generative-recommenders) * (CUDA, HIP) [Liger-Kernel](https://github.com/linkedin/Liger-Kernel) * (CUDA, HIP) [tilelang](https://github.com/tile-ai/tilelang) +* (CUDA) [CUTLASS Python DSL](https://github.com/NVIDIA/cutlass) * (CUDA) [xformers](https://github.com/facebookresearch/xformers) * (CUDA) [flash-attention](https://github.com/Dao-AILab/flash-attention) * (CUDA) [FBGEMM](https://github.com/pytorch/FBGEMM) diff --git a/install.py b/install.py index 922f8fde..a7963c14 100644 --- a/install.py +++ b/install.py @@ -14,6 +14,14 @@ pip_install_requirements, ) +from tools.flash_attn.install import install_fa3 +from tools.cutlass.install import install_cutlass +from tools.tk.install import install_tk +from tools.tilelang.install import install_tile +from tools.xformers.install import install_xformers +from tools.aiter.install import install_aiter + + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -128,6 +136,9 @@ def setup_hip(args: argparse.Namespace): parser.add_argument( "--fa3", action="store_true", help="Install optional flash_attention 3 kernels" ) + parser.add_argument( + "--cutlass", action="store_true", help="Install optional CUTLASS Python DSL" + ) parser.add_argument("--jax", action="store_true", help="Install jax nightly") parser.add_argument("--tk", action="store_true", help="Install ThunderKittens") parser.add_argument("--liger", action="store_true", help="Install Liger-kernel") @@ -157,8 +168,6 @@ def setup_hip(args: argparse.Namespace): if args.fa3 or args.all: # we need to install fa3 above all other dependencies logger.info("[tritonbench] installing fa3...") - from tools.flash_attn.install import install_fa3 - install_fa3() if args.fbgemm or args.fbgemm_all or args.all: logger.info("[tritonbench] installing FBGEMM...") @@ -172,25 +181,20 @@ def setup_hip(args: argparse.Namespace): install_jax() if args.tk or args.all: logger.info("[tritonbench] installing thunderkittens...") - from tools.tk.install import install_tk - install_tk() + if args.cutlass or args.all: + logger.info("[tritonbench] installing cutlass Python DSL...") + install_cutlass() if args.tile: logger.info("[tritonbench] installing tilelang...") - from tools.tilelang.install import install_tile - install_tile() if args.liger or args.all: logger.info("[tritonbench] installing liger-kernels...") install_liger() if args.xformers: logger.info("[tritonbench] installing xformers...") - from tools.xformers.install import install_xformers - install_xformers() if args.aiter and is_hip(): logger.info("[tritonbench] installing aiter...") - from tools.aiter.install import install_aiter - install_aiter() logger.info("[tritonbench] installation complete!") diff --git a/submodules/cutlass b/submodules/cutlass index ad7b2f5e..f115c3f8 160000 --- a/submodules/cutlass +++ b/submodules/cutlass @@ -1 +1 @@ -Subproject commit ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e +Subproject commit f115c3f85467d5d9619119d1dbeb9c03c3d73864 diff --git a/tools/cutlass/install.py b/tools/cutlass/install.py new file mode 100644 index 00000000..cd13f423 --- /dev/null +++ b/tools/cutlass/install.py @@ -0,0 +1,11 @@ +import os +import subprocess + +from pathlib import Path + +REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent +CUTLASS_PATH = REPO_PATH.joinpath("submodules", "cutlass") + +def install_cutlass(): + cmd = ["pip", "install", "-e", "."] + subprocess.check_call(cmd, cwd=str(CUTLASS_PATH.resolve()))