Skip to content

[cutlass] Add CUTLASS Python DSL #223

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ We depend on the following projects as a source of customized Triton or CUTLASS
* (CUDA, HIP) [generative-recommenders](https://github.com/facebookresearch/generative-recommenders)
* (CUDA, HIP) [Liger-Kernel](https://github.com/linkedin/Liger-Kernel)
* (CUDA, HIP) [tilelang](https://github.com/tile-ai/tilelang)
* (CUDA) [CUTLASS Python DSL](https://github.com/NVIDIA/cutlass)
* (CUDA) [xformers](https://github.com/facebookresearch/xformers)
* (CUDA) [flash-attention](https://github.com/Dao-AILab/flash-attention)
* (CUDA) [FBGEMM](https://github.com/pytorch/FBGEMM)
Expand Down
24 changes: 14 additions & 10 deletions install.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@
pip_install_requirements,
)

from tools.flash_attn.install import install_fa3
from tools.cutlass.install import install_cutlass
from tools.tk.install import install_tk
from tools.tilelang.install import install_tile
from tools.xformers.install import install_xformers
from tools.aiter.install import install_aiter


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -128,6 +136,9 @@ def setup_hip(args: argparse.Namespace):
parser.add_argument(
"--fa3", action="store_true", help="Install optional flash_attention 3 kernels"
)
parser.add_argument(
"--cutlass", action="store_true", help="Install optional CUTLASS Python DSL"
)
parser.add_argument("--jax", action="store_true", help="Install jax nightly")
parser.add_argument("--tk", action="store_true", help="Install ThunderKittens")
parser.add_argument("--liger", action="store_true", help="Install Liger-kernel")
Expand Down Expand Up @@ -157,8 +168,6 @@ def setup_hip(args: argparse.Namespace):
if args.fa3 or args.all:
# we need to install fa3 above all other dependencies
logger.info("[tritonbench] installing fa3...")
from tools.flash_attn.install import install_fa3

install_fa3()
if args.fbgemm or args.fbgemm_all or args.all:
logger.info("[tritonbench] installing FBGEMM...")
Expand All @@ -172,25 +181,20 @@ def setup_hip(args: argparse.Namespace):
install_jax()
if args.tk or args.all:
logger.info("[tritonbench] installing thunderkittens...")
from tools.tk.install import install_tk

install_tk()
if args.cutlass or args.all:
logger.info("[tritonbench] installing cutlass Python DSL...")
install_cutlass()
if args.tile:
logger.info("[tritonbench] installing tilelang...")
from tools.tilelang.install import install_tile

install_tile()
if args.liger or args.all:
logger.info("[tritonbench] installing liger-kernels...")
install_liger()
if args.xformers:
logger.info("[tritonbench] installing xformers...")
from tools.xformers.install import install_xformers

install_xformers()
if args.aiter and is_hip():
logger.info("[tritonbench] installing aiter...")
from tools.aiter.install import install_aiter

install_aiter()
logger.info("[tritonbench] installation complete!")
2 changes: 1 addition & 1 deletion submodules/cutlass
Submodule cutlass updated 299 files
11 changes: 11 additions & 0 deletions tools/cutlass/install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import os
import subprocess

from pathlib import Path

REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent
CUTLASS_PATH = REPO_PATH.joinpath("submodules", "cutlass")

def install_cutlass():
cmd = ["pip", "install", "-e", "."]
subprocess.check_call(cmd, cwd=str(CUTLASS_PATH.resolve()))
Loading