Commit 06e69f6

support W4A8 Marlin kernel (#1113)
support Marlin W4A8 kernel
1 parent 9a9ea25 commit 06e69f6
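
The new kernel is exercised through the existing quantize_ API with the new MarlinQQQLayout. A minimal usage sketch, mirroring the tests and generate.py changes in this commit (assumes a CUDA device, a float16 model, and a torchao build with the Marlin QQQ extension compiled):

import torch
from torch import nn

from torchao.dtypes import MarlinQQQLayout
from torchao.quantization.quant_api import int8_dynamic_activation_int4_weight, quantize_
from torchao.quantization.quant_primitives import MappingType

# W4A8: dynamic int8 activation quantization, int4 weights in the Marlin QQQ layout.
model = nn.Sequential(nn.Linear(8192, 8192)).half().cuda()
quantize_(
    model,
    int8_dynamic_activation_int4_weight(
        group_size=128,  # -1 uses per-channel weight scales only
        mapping_type=MappingType.SYMMETRIC,
        act_mapping_type=MappingType.SYMMETRIC,
        layout=MarlinQQQLayout(),
    ),
)
out = model(torch.randn(16, 8192, dtype=torch.float16, device="cuda"))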

File tree

16 files changed: +2832 -9 lines changed


benchmarks/benchmark_marlin_qqq.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
import torch
import pandas as pd
from torchao.utils import benchmark_torch_function_in_microseconds
from torchao.ops import marlin_qqq_gemm
from torchao.quantization.marlin_qqq import marlin_qqq_workspace, pack_to_marlin_qqq
from tqdm import tqdm


def get_problem(m, n, k, groupsize=-1):
    if groupsize == -1:
        groupsize = k
    dev = torch.device("cuda")
    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
    B_ref = torch.randn((k, n), dtype=torch.half, device=dev)

    A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev)
    B = torch.randint(low=-(2**31), high=2**31, size=(k, n), device=dev)
    s_tok = torch.ones((m, 1), dtype=torch.float, device=dev)
    # group_size == -1 means per-channel weight scales only, so s_group stays empty.
    if groupsize == k:
        s_group = torch.tensor([], dtype=torch.half, device=dev)
    else:
        s_group = torch.ones((k // groupsize, n), dtype=torch.half, device=dev)
    s_channel = torch.ones((1, n), dtype=torch.float, device=dev)
    B, s_group, s_channel = pack_to_marlin_qqq(
        B, s_group, s_channel, num_bits=4, group_size=groupsize
    )
    qqq_workspace = marlin_qqq_workspace(n)
    return A, B, A_ref, B_ref, s_tok, s_channel, s_group, qqq_workspace


def benchmark(m: int, k: int, n: int, group_size: int):
    A, B, A_ref, B_ref, s_tok, s_channel, s_group, qqq_workspace = get_problem(
        m, n, k, group_size
    )

    fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, A_ref, B_ref)
    marlin_qqq_w4a8_time = benchmark_torch_function_in_microseconds(
        marlin_qqq_gemm, A, B, s_tok, s_channel, s_group, qqq_workspace, m, n, k
    )

    return {
        "m": m,
        "k": k,
        "n": n,
        "group_size": group_size,
        "fp16_latency (us)": fp16_time,
        "marlin_qqq_w4a8_latency (us)": marlin_qqq_w4a8_time,
        "speedup (d/s)": fp16_time / marlin_qqq_w4a8_time,
    }


if __name__ == "__main__":
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (8192, 10240, 57344, 8192)

    results = []
    for group_size in tqdm([-1, 128]):
        for m in tqdm([1 << i for i in range(10)]):
            for n, k in zip(n_vals, k_vals):
                results.append(benchmark(m, k, n, group_size))

    df = pd.DataFrame(results)
    df.to_csv("marlin_qqq_w4a8_llm_benchmark_results.csv", index=False)
    print(df.to_markdown(index=False))
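
For a quick one-off check, the helper above can also be called directly for a single LLM-style shape. A sketch, assuming it is run from the benchmarks directory on a CUDA machine with the marlin_qqq extension built:

from benchmark_marlin_qqq import benchmark

# One shape from the sweep above: fp16 matmul vs. the W4A8 Marlin QQQ kernel.
print(benchmark(m=16, k=8192, n=8192, group_size=128))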

test/quantization/test_marlin_qqq.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
import copy

import pytest
import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests

from torchao.dtypes import MarlinQQQLayout
from torchao.quantization.marlin_qqq import (
    pack_to_marlin_qqq,
    unpack_from_marlin_qqq,
)
from torchao.quantization.quant_api import (
    int8_dynamic_activation_int4_weight,
    quantize_,
)
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_and_quantize_affine_qqq,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5


class MarlinQQQ(TestCase):
    def setUp(self):
        super().setUp()
        torch.manual_seed(0)

        self.input = torch.randn((64, 32, 8192), dtype=torch.float16, device="cuda")
        self.model = (
            nn.Sequential(
                nn.Linear(8192, 21504),
                nn.Linear(21504, 8192),
                nn.ReLU(),
                nn.Linear(8192, 21504),
                nn.Linear(21504, 8192),
            )
            .half()
            .cuda()
        )

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_marlin_qqq(self):
        output_ref = self.model(self.input)
        for group_size in [-1, 128]:
            modelq = copy.deepcopy(self.model)
            quantize_(
                modelq,
                int8_dynamic_activation_int4_weight(
                    group_size=group_size,
                    mapping_type=MappingType.SYMMETRIC,
                    act_mapping_type=MappingType.SYMMETRIC,
                    layout=MarlinQQQLayout(),
                ),
            )
            output = modelq(self.input)

            assert torch.allclose(
                output, output_ref, atol=1e-1
            ), "Results are not close"

    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_marlin_qqq_compile(self):
        model_copy = copy.deepcopy(self.model)
        model_copy.forward = torch.compile(model_copy.forward, fullgraph=True)
        output_ref = model_copy(self.input)

        for group_size in [-1, 128]:
            modelq = copy.deepcopy(self.model)
            quantize_(
                modelq,
                int8_dynamic_activation_int4_weight(
                    group_size=group_size,
                    mapping_type=MappingType.SYMMETRIC,
                    act_mapping_type=MappingType.SYMMETRIC,
                    layout=MarlinQQQLayout(),
                ),
            )
            modelq.forward = torch.compile(modelq.forward, fullgraph=True)
            output = modelq(self.input)

            assert torch.allclose(
                output, output_ref, atol=1e-1
            ), "Results are not close"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_pack_unpack_equivalence(self):
        num_bits = 4
        shape = (11008, 4096)

        w = torch.rand(shape, dtype=torch.float16, device="cuda")

        for group_size in [-1, 128]:
            # Quantize weights
            q_w, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq(
                w, num_bits, group_size
            )

            q_w = q_w.t()
            s_group = s_group.t()
            s_channel = s_channel.t()

            # Test pack/unpack equivalence
            q_w_comp, packed_s_group, packed_s_channel = pack_to_marlin_qqq(
                q_w, s_group, s_channel, num_bits, group_size
            )
            unpacked_q_w, unpacked_s_group, unpacked_s_channel = unpack_from_marlin_qqq(
                q_w_comp,
                packed_s_group,
                packed_s_channel,
                q_w.shape,
                num_bits,
                group_size,
            )

            assert torch.equal(
                q_w, unpacked_q_w
            ), "Unpacked weights do not match original weights"
            assert torch.equal(
                s_channel, unpacked_s_channel
            ), "Unpacked s_channel does not match original s_channel"
            assert torch.equal(
                s_group, unpacked_s_group
            ), "Unpacked s_group does not match original s_group"


if __name__ == "__main__":
    run_tests()
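
The suite covers three paths: the eager quantize_ flow, the same flow under torch.compile with fullgraph=True, and a pack/unpack round trip through the Marlin QQQ layout. A convenience snippet (not part of the commit) for running just this file, assuming a CUDA device and a torchao build with the kernel:

import sys

import pytest

# Run only the new Marlin QQQ tests; -q keeps the output short.
sys.exit(pytest.main(["-q", "test/quantization/test_marlin_qqq.py"]))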

test/test_ops.py

Lines changed: 109 additions & 0 deletions
@@ -13,6 +13,11 @@
 from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5, compute_max_diff
 from torchao.dtypes.floatx import from_scaled_tc_floatx
 from torchao.sparsity.marlin import marlin_24_workspace, pack_to_marlin_24, inject_24
+from torchao.quantization.marlin_qqq import (
+    marlin_qqq_workspace,
+    pack_to_marlin_qqq,
+)
+from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_qqq
 import pytest
 
 if is_fbcode():

@@ -426,5 +431,109 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors):
     )
 
 
+MARLIN_QQQ_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
+MARLIN_QQQ_K_CHUNKS = [128]
+MARLIN_QQQ_N_CHUNKS = [64, 128, 256]
+MNK_FACTORS = [
+    (1, 1, 1),
+    (1, 4, 8),
+    (1, 7, 5),
+    (13, 17, 67),
+    (26, 37, 13),
+    (67, 13, 11),
+]
+MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
+MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
+
+MARLIN_TEST_PARAMS = list(
+    itertools.product(
+        MARLIN_QQQ_BATCH_SIZE,
+        MARLIN_QQQ_K_CHUNKS,
+        MARLIN_QQQ_N_CHUNKS,
+        MARLIN_QQQ_SUPPORTED_NUM_BITS,
+        MARLIN_QQQ_SUPPORTED_GROUP_SIZES,
+        MNK_FACTORS,
+    )
+)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize(
+    "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors",
+    MARLIN_TEST_PARAMS,
+    ids=str,
+)
+def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors):
+    int8_traits = torch.iinfo(torch.int8)
+    m_factor, n_factor, k_factor = mnk_factors
+
+    size_m = m_factor
+    size_k = k_chunk * k_factor
+    size_n = n_chunk * n_factor
+
+    a_input = torch.randn(
+        (batch_size, size_m, size_k), dtype=torch.float16, device="cuda"
+    )
+    b_weight = torch.rand((size_n, size_k), dtype=torch.float16, device="cuda")
+
+    # Reshape input into 2D tensor
+    input_2d = a_input.view(-1, a_input.shape[-1])
+    a_input_in, a_input_out = input_2d.shape
+
+    # Quantize activations
+    s_a = (
+        input_2d.abs()
+        .max(dim=-1, keepdim=True)[0]
+        .div(int8_traits.max)
+        .to(torch.float32)
+    )
+    q_a = (
+        (input_2d / s_a).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8)
+    )
+
+    # Quantize weights
+    q_w, s_group, s_channel, w_ref = choose_qparams_and_quantize_affine_qqq(
+        b_weight, num_bits, group_size
+    )
+    q_w = q_w.t()
+    s_group = s_group.t()
+    s_channel = s_channel.t()
+    w_ref = w_ref.t()
+    marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = pack_to_marlin_qqq(
+        q_w, s_group, s_channel, num_bits, group_size
+    )
+
+    workspace = marlin_qqq_workspace(size_n)
+
+    # Obtains reference output
+    output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref)
+    output_ref = output_ref.reshape(a_input.shape[:-1] + (size_n,))
+
+    fn_inputs = (
+        q_a,
+        marlin_qqq_q_w,
+        s_a,
+        marlin_qqq_s_channel,
+        marlin_qqq_s_group,
+        workspace,
+        a_input_in,
+        size_n,
+        a_input_out,
+    )
+    output = torchao.ops.marlin_qqq_gemm(*fn_inputs)
+    output = output.reshape(a_input.shape[:-1] + (size_n,))
+
+    max_diff = compute_max_diff(output, output_ref)
+    assert max_diff < 0.04
+
+    # Performs opcheck
+    test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"]
+    opcheck(
+        torch.ops.torchao.marlin_qqq_gemm,
+        fn_inputs,
+        test_utils=test_utils,
+    )
+
+
 if __name__ == "__main__":
     run_tests()
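
For readers decoding fn_inputs above, the calling convention of the new op can be summarized with a thin wrapper. This is a reading aid based on the test and the benchmark script; the wrapper name is illustrative, not new torchao API:

import torchao


def marlin_qqq_w4a8_matmul(q_a, packed_w, s_tok, s_channel, s_group, workspace, m, n, k):
    # q_a:       int8 activations of shape (m, k)
    # packed_w:  int4 weights packed by pack_to_marlin_qqq
    # s_tok:     per-token activation scales, (m, 1), float32
    # s_channel: per-channel weight scales
    # s_group:   per-group weight scales (empty tensor when group_size == -1)
    # workspace: scratch buffer from marlin_qqq_workspace(n)
    return torchao.ops.marlin_qqq_gemm(
        q_a, packed_w, s_tok, s_channel, s_group, workspace, m, n, k
    )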

torchao/_models/llama/generate.py

Lines changed: 17 additions & 3 deletions
@@ -14,6 +14,7 @@
 import torch._dynamo.config
 import torch._inductor.config
 from torchao.utils import get_model_size_in_bytes
+from torchao.quantization.quant_primitives import MappingType
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 def device_sync(device):

@@ -211,6 +212,7 @@ def main(
         int8_weight_only,
         int8_dynamic_activation_int8_weight,
         int4_weight_only,
+        int8_dynamic_activation_int4_weight,
         fpx_weight_only,
         uintx_weight_only,
         autoquant,

@@ -235,8 +237,20 @@ def main(
             assert group_size in [32,64,128,256], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
             quantize_(model, int4_weight_only(group_size=group_size))
         if "marlin" in quantization:
-            from torchao.dtypes import MarlinSparseLayout
-            quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
+            if "qqq" in quantization:
+                from torchao.dtypes import MarlinQQQLayout
+                quantize_(
+                    model,
+                    int8_dynamic_activation_int4_weight(
+                        group_size=128,
+                        mapping_type=MappingType.SYMMETRIC,
+                        act_mapping_type=MappingType.SYMMETRIC,
+                        layout=MarlinQQQLayout(),
+                    ),
+                )
+            else:
+                from torchao.dtypes import MarlinSparseLayout
+                quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
         if "fp6" in quantization:
             quantize_(model, fpx_weight_only(3, 2))
         if "embed-int8wo" in quantization:

@@ -474,7 +488,7 @@ def callback(x):
        help=(
            'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
            +'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, '
-           +'embed-int8wo'
+           +'embed-int8wo, marlin_qqq'
        )
    )
    parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples")
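
With this change, a --quantization string that contains both "marlin" and "qqq" (for example the marlin_qqq option now listed in the help text) is routed to the W4A8 path with a fixed group_size of 128, while plain sparse-marlin keeps the existing MarlinSparseLayout behavior.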
