Skip to content

Commit 5b6af5e

Browse files
authored
[backends] Add Quack (#289)
1 parent 4ec37bd commit 5b6af5e

File tree

7 files changed

+61
-0
lines changed

7 files changed

+61
-0
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@
1919
[submodule "submodules/aiter"]
2020
path = submodules/aiter
2121
url = https://github.com/ROCm/aiter.git
22+
[submodule "submodules/quack"]
23+
path = submodules/quack
24+
url = https://github.com/Dao-AILab/quack.git

install.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,11 @@ def setup_hip(args: argparse.Namespace):
198198
if args.liger or args.all:
199199
logger.info("[tritonbench] installing liger-kernels...")
200200
install_liger()
201+
if args.quack or args.all:
202+
logger.info("[tritonbench] installing quack...")
203+
from tools.quack.install import install_quack
204+
205+
install_quack()
201206
if args.xformers:
202207
logger.info("[tritonbench] installing xformers...")
203208
from tools.xformers.install import install_xformers

submodules/quack

Submodule quack added at a42fef7

tools/quack/install.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import os
2+
import subprocess
3+
4+
from pathlib import Path
5+
6+
7+
REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent
8+
CURRENT_DIR = Path(os.path.abspath(__file__)).parent
9+
QUACK_PATH = REPO_PATH.joinpath("submodules", "quack")
10+
11+
12+
def install_quack():
13+
cmd = ["pip", "install", "-e", "."]
14+
subprocess.check_call(cmd, cwd=QUACK_PATH)

tritonbench/operators/rms_norm/operator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323
except ModuleNotFoundError:
2424
HAS_AITER = False
2525

26+
try:
27+
from .quack import QuackRMSNorm
28+
except ModuleNotFoundError:
29+
QuackRMSNorm = None
30+
2631

2732
# Reference: https://github.com/linkedin/Liger-Kernel/
2833
# blob/main/benchmark/scripts/benchmark_rms_norm.py
@@ -72,6 +77,11 @@ def liger_rms(self, H, input) -> Callable:
7277
self.liger_rms_op = LigerRMSNorm(hidden_size=H, eps=self.eps).to(self.device)
7378
return lambda: self.liger_rms_op(input)
7479

80+
@register_benchmark(enabled=QuackRMSNorm)
81+
def quack_rms(self, H, input) -> Callable:
82+
self.quack_rms_op = QuackRMSNorm(hidden_size=H, eps=self.eps).to(self.device)
83+
return lambda: self.quack_rms_op(input)
84+
7585
@register_benchmark()
7686
def inductor_rms(self, H, input) -> Callable:
7787
if self.llama_rms_op is None:
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import torch
2+
3+
from quack import rmsnorm as quack_rmsnorm
4+
5+
6+
class QuackRMSNorm(torch.nn.Module):
7+
def __init__(self, hidden_size, eps=1e-6):
8+
"""
9+
AITerRMSNorm
10+
"""
11+
super().__init__()
12+
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
13+
self.variance_epsilon = eps
14+
15+
def forward(self, hidden_states):
16+
return quack_rmsnorm(hidden_states, self.weight, self.variance_epsilon)

tritonbench/operators/softmax/operator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@
1515
register_metric,
1616
)
1717

18+
try:
19+
from quack.softmax import softmax as quack_softmax
20+
21+
HAS_QUACK = True
22+
except ImportError:
23+
HAS_QUACK = False
24+
1825

1926
class Operator(BenchmarkOperator):
2027
is_compute_bound = False
@@ -105,6 +112,11 @@ def _inner():
105112

106113
return _inner
107114

115+
@register_benchmark(enabled=HAS_QUACK)
116+
def quack(self, x):
117+
inner = lambda: quack_softmax(x)
118+
return inner
119+
108120
def get_input_iter(self):
109121
M = 4096
110122
shapes = [(M, 128 * i) for i in range(2, 100)]

0 commit comments

Comments
 (0)