
Commit 8e8472c (parent b49f23c)

Add quantized q @ k test for intended use in quantized attention

Differential Revision: D71370604
Pull Request resolved: #2006
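For context on what "quantized q @ k" means here: both operands are quantized to 8 bits per token (channelwise), so each row of Q and each row of K carries its own scale and zero point, and the matmul has to dequantize each element pair before accumulating. The sketch below is an illustrative scalar reference for a single head, not code from this commit; all names in it are hypothetical.

#include <cstdint>
#include <vector>

// Dequantize one int8 value with its per-token scale and zero point.
inline float dequantize(int8_t q, float scale, int8_t zero) {
  return scale * (static_cast<float>(q) - static_cast<float>(zero));
}

// Scalar reference for q @ k^T on a single head: q_qvals is s_q x d,
// k_qvals is s_k x d, and row i of each matrix owns scales[i] / zeros[i].
std::vector<float> q_at_k_reference(
    const std::vector<int8_t>& q_qvals,
    const std::vector<float>& q_scales,
    const std::vector<int8_t>& q_zeros,
    const std::vector<int8_t>& k_qvals,
    const std::vector<float>& k_scales,
    const std::vector<int8_t>& k_zeros,
    int s_q,
    int s_k,
    int d) {
  std::vector<float> out(s_q * s_k, 0.0f);
  for (int i = 0; i < s_q; i++) {
    for (int j = 0; j < s_k; j++) {
      float acc = 0.0f;
      for (int c = 0; c < d; c++) {
        acc += dequantize(q_qvals[i * d + c], q_scales[i], q_zeros[i]) *
            dequantize(k_qvals[j * d + c], k_scales[j], k_zeros[j]);
      }
      out[i * s_k + j] = acc;
    }
  }
  return out;
}

The test added below performs exactly this kind of check at scale: it runs the NEON kernel over every (batch, head) pair and compares the result against a reference computed this way.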

File tree

3 files changed: +334 −0 lines changed

torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp

Lines changed: 98 additions & 0 deletions

@@ -12,6 +12,7 @@
#include <gtest/gtest.h>
#include <torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h>
#include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h>
#include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h>

float kTol = 0.0001;

@@ -411,4 +412,101 @@ INSTANTIATE_TEST_SUITE_P(
    FP32A_QuantizedB_FP32C_Test,
    ::testing::Values(0.0, 1.0, 2.69));

static void test_8bit_per_token_q_at_k_matmul_attention(
    int b,
    int s_q,
    int s_k,
    int h,
    int d,
    bool transpose = true) {
  auto test_case = torchao::
      channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case::
          generate(b, s_q, s_k, h, d, transpose);

  using namespace torchao::kernels::cpu::aarch64::quantized_matmul::
      channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot;

  size_t q_b_stride = test_case.b_q_stride;
  size_t q_h_stride = test_case.h_q_stride;
  size_t q_s_q_stride = test_case.s_q_stride;
  size_t q_scale_zp_b_stride = test_case.b_q_qparams_stride;
  size_t q_scale_zp_h_stride = test_case.h_q_qparams_stride;
  size_t q_scale_zp_s_stride = test_case.s_q_qparams_stride;

  size_t k_b_stride = test_case.b_k_stride;
  size_t k_h_stride = test_case.h_k_stride;
  size_t k_s_k_stride = test_case.s_k_stride;
  size_t k_scale_zp_b_stride = test_case.b_k_qparams_stride;
  size_t k_scale_zp_h_stride = test_case.h_k_qparams_stride;
  size_t k_scale_zp_s_stride = test_case.s_k_qparams_stride;

  std::vector<float> output(b * h * s_q * s_k);
  size_t output_b_stride = h * s_q * s_k;
  size_t output_h_stride = s_q * s_k;
  size_t output_s_q_stride = s_k;

  for (int b_idx = 0; b_idx < b; b_idx++) {
    for (int h_idx = 0; h_idx < h; h_idx++) {
      kernel<true, true, false, true>(
          s_q,
          s_k,
          d,
          test_case.q_qvals.data() + b_idx * q_b_stride + h_idx * q_h_stride,
          q_s_q_stride /*lhs_stride_m*/,
          test_case.k_qvals.data() + b_idx * k_b_stride + h_idx * k_h_stride,
          k_s_k_stride /*rhs_stride_n*/,
          output.data() + b_idx * output_b_stride + h_idx * output_h_stride,
          output_s_q_stride /*out_stride_n*/,
          test_case.q_zeros.data() + b_idx * q_scale_zp_b_stride +
              h_idx * q_scale_zp_h_stride,
          test_case.k_zeros.data() + b_idx * k_scale_zp_b_stride +
              h_idx * k_scale_zp_h_stride,
          test_case.q_scales.data() + b_idx * q_scale_zp_b_stride +
              h_idx * q_scale_zp_h_stride,
          test_case.k_scales.data() + b_idx * k_scale_zp_b_stride +
              h_idx * k_scale_zp_h_stride,
          q_scale_zp_s_stride /*lhs qparams stride*/,
          k_scale_zp_s_stride /*rhs qparams stride*/);
    }
  }

  for (int i = 0; i < b * h * s_q * s_k; i++) {
    EXPECT_NEAR(output[i], test_case.expected_output[i], kTol);
  }
}

TEST(test_8bit_per_token_q_at_k_matmul_attention, Basic) {
  test_8bit_per_token_q_at_k_matmul_attention(1, 16, 16, 8, 16);
}

TEST(test_8bit_per_token_q_at_k_matmul_attention, PrimeHeadsAndHeadDim) {
  test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 33);
}

TEST(
    test_8bit_per_token_q_at_k_matmul_attention,
    PrimeHeadsAndHeadDimDiffSqSk) {
  test_8bit_per_token_q_at_k_matmul_attention(1, 7, 16, 7, 33);
}

TEST(test_8bit_per_token_q_at_k_matmul_attention, PrimeHeadsAndSmallHeadDim) {
  test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 3);
}

TEST(test_8bit_per_token_q_at_k_matmul_attention, BasicNoTranspose) {
  test_8bit_per_token_q_at_k_matmul_attention(1, 16, 16, 8, 16, false);
}

TEST(
    test_8bit_per_token_q_at_k_matmul_attention,
    PrimeHeadsAndHeadDimDiffSqSkNoTranspose) {
  test_8bit_per_token_q_at_k_matmul_attention(1, 7, 16, 7, 33, false);
}

TEST(
    test_8bit_per_token_q_at_k_matmul_attention,
    PrimeHeadsAndSmallHeadDimNoTranspose) {
  test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 3, false);
}

#endif // defined(__aarch64__) || defined(__ARM_NEON)

torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h

Lines changed: 1 addition & 0 deletions

@@ -86,6 +86,7 @@ inline float get_float_from_bf16(uint16_t bf16) {

namespace test_utils {
auto generate_per_token_quantized_tensor(int m, int n, bool transposed = false);

auto generate_per_token_quantized_tensor(int m, int n, bool transposed) {
  auto activations = get_random_vector(m * n, -1.0, 1.0);
  auto activation_qvals = std::vector<int8_t>(m * n, 0);
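Worth noting about this hunk: the default argument lives on the forward declaration, and the definition omits it. C++ permits a default argument to be specified only once in a given scope, so this split is the standard way to keep the default visible to callers (such as the new attention test utilities) while defining the function below. A minimal illustration of the rule, with hypothetical names:

// Declaration carries the default; callers may write scaled(3).
int scaled(int x, int factor = 2);

// The definition must not repeat the default argument.
int scaled(int x, int factor) {
  return x * factor;
}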
torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h

Lines changed: 235 additions & 0 deletions

@@ -0,0 +1,235 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#if defined(__aarch64__) || defined(__ARM_NEON)

#include <torchao/experimental/kernels/cpu/aarch64/quantization/quantize.h>
#include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h>
#include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h>
#include <cassert>
#include <functional>
#include <random>
#include <vector>

namespace torchao {
struct channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case {
  int b;
  int s_q;
  int s_k;
  int h;
  int d;
  bool transposed;

  size_t b_q_stride;
  size_t h_q_stride;
  size_t s_q_stride;

  size_t b_k_stride;
  size_t h_k_stride;
  size_t s_k_stride;

  size_t b_q_qparams_stride;
  size_t h_q_qparams_stride;
  size_t s_q_qparams_stride;

  size_t b_k_qparams_stride;
  size_t h_k_qparams_stride;
  size_t s_k_qparams_stride;

  std::vector<float> expected_output;

  std::vector<float> q;
  std::vector<int8_t> q_qvals;
  std::vector<float> q_scales;
  std::vector<int8_t> q_zeros;

  std::vector<float> k;
  std::vector<int8_t> k_qvals;
  std::vector<float> k_scales;
  std::vector<int8_t> k_zeros;

  channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case(
      int b_,
      int s_q_,
      int s_k_,
      int h_,
      int d_,
      bool transposed_,
      size_t b_q_stride_,
      size_t h_q_stride_,
      size_t s_q_stride_,
      size_t b_k_stride_,
      size_t h_k_stride_,
      size_t s_k_stride_,
      size_t b_q_qparams_stride_,
      size_t h_q_qparams_stride_,
      size_t s_q_qparams_stride_,
      size_t b_k_qparams_stride_,
      size_t h_k_qparams_stride_,
      size_t s_k_qparams_stride_,
      std::vector<float> expected_output_,
      std::vector<float> q_,
      std::vector<int8_t> q_qvals_,
      std::vector<float> q_scales_,
      std::vector<int8_t> q_zeros_,
      std::vector<float> k_,
      std::vector<int8_t> k_qvals_,
      std::vector<float> k_scales_,
      std::vector<int8_t> k_zeros_)
      : b(b_),
        s_q(s_q_),
        s_k(s_k_),
        h(h_),
        d(d_),
        transposed(transposed_),
        b_q_stride(b_q_stride_),
        h_q_stride(h_q_stride_),
        s_q_stride(s_q_stride_),
        b_k_stride(b_k_stride_),
        h_k_stride(h_k_stride_),
        s_k_stride(s_k_stride_),
        b_q_qparams_stride(b_q_qparams_stride_),
        h_q_qparams_stride(h_q_qparams_stride_),
        s_q_qparams_stride(s_q_qparams_stride_),
        b_k_qparams_stride(b_k_qparams_stride_),
        h_k_qparams_stride(h_k_qparams_stride_),
        s_k_qparams_stride(s_k_qparams_stride_),
        expected_output(expected_output_),
        q(q_),
        q_qvals(q_qvals_),
        q_scales(q_scales_),
        q_zeros(q_zeros_),
        k(k_),
        k_qvals(k_qvals_),
        k_scales(k_scales_),
        k_zeros(k_zeros_) {
    assert(expected_output.size() == b * s_q * h * s_k);
    assert(q.size() == b * s_q * h * d);
    assert(q_qvals.size() == b * s_q * h * d);
    assert(q_scales.size() == b * s_q * h);
    assert(q_zeros.size() == b * s_q * h);
    assert(k.size() == b * s_k * h * d);
    assert(k_qvals.size() == b * s_k * h * d);
    assert(k_scales.size() == b * s_k * h);
    assert(k_zeros.size() == b * s_k * h);
  }

  static channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case
  generate(int b, int s_q, int s_k, int h, int d, bool transposed = true) {
    // Generate activations
    auto [lhs, lhs_qvals, lhs_scales, lhs_zeros] =
        torchao::test_utils::generate_per_token_quantized_tensor(
            b * s_q * h, d);

    auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] =
        torchao::test_utils::generate_per_token_quantized_tensor(
            b * s_k * h, d);
    // generate_per_token_quantized_tensor produces an n x k matrix; to get
    // k x n it must be called with transposed = true. Both calls above use
    // the default (false); the transposed flag of this test case is instead
    // realized through the strides chosen below.

    size_t b_q_stride = h * s_q * d;
    size_t h_q_stride = s_q * d;
    size_t s_q_stride = d;

    size_t b_k_stride = h * s_k * d;
    size_t h_k_stride = s_k * d;
    size_t s_k_stride = d;

    size_t b_q_qparams_stride = h * s_q;
    size_t h_q_qparams_stride = s_q;
    size_t s_q_qparams_stride = 1;

    size_t b_k_qparams_stride = h * s_k;
    size_t h_k_qparams_stride = s_k;
    size_t s_k_qparams_stride = 1;

    if (!transposed) {
      h_q_stride = d;
      s_q_stride = h * d;
      h_k_stride = d;
      s_k_stride = h * d;

      s_q_qparams_stride = h;
      h_q_qparams_stride = 1;

      s_k_qparams_stride = h;
      h_k_qparams_stride = 1;
    }
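// To make the two layouts concrete: with transposed = true (the default),
// the strides describe a [b, h, s, d] tensor in which each head's s x d
// block is contiguous; with transposed = false they describe the usual
// activation layout [b, s, h, d], where heads interleave along the token
// axis. A small sketch of the equivalent flat-index math (illustrative
// helper names, not part of this header):
//
//   // Flat offset of element (b, h, s, c) in a [B, H, S, D] tensor --
//   // the transposed = true layout above.
//   size_t offset_bhsd(size_t b, size_t h, size_t s, size_t c,
//                      size_t H, size_t S, size_t D) {
//     return ((b * H + h) * S + s) * D + c;
//   }
//
//   // Flat offset of the same element in a [B, S, H, D] tensor --
//   // the transposed = false layout above.
//   size_t offset_bshd(size_t b, size_t h, size_t s, size_t c,
//                      size_t H, size_t S, size_t D) {
//     return ((b * S + s) * H + h) * D + c;
//   }
//
// Expanding offset_bhsd reproduces the default strides (b: H*S*D, h: S*D,
// s: D), and offset_bshd reproduces the !transposed overrides (s: H*D,
// h: D); the qparams strides follow the same pattern with D dropped.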
    // Compute expected output
    std::vector<float> expected_output(b * h * s_q * s_k);
    size_t b_out_stride = h * s_q * s_k;
    size_t h_out_stride = s_q * s_k;
    size_t s_q_out_stride = s_k;

    for (int b_idx = 0; b_idx < b; b_idx++) {
      for (int s_q_idx = 0; s_q_idx < s_q; s_q_idx++) {
        for (int h_idx = 0; h_idx < h; h_idx++) {
          for (int s_k_idx = 0; s_k_idx < s_k; s_k_idx++) {
            float res = 0.0;
            for (int d_idx = 0; d_idx < d; d_idx++) {
              int lhs_idx = b_idx * b_q_stride + s_q_idx * s_q_stride +
                  h_idx * h_q_stride + d_idx;
              int rhs_idx = b_idx * b_k_stride + s_k_idx * s_k_stride +
                  h_idx * h_k_stride + d_idx;
              int lhs_scales_zp_idx = b_idx * b_q_qparams_stride +
                  h_idx * h_q_qparams_stride + s_q_idx * s_q_qparams_stride;
              int rhs_scales_zp_idx = b_idx * b_k_qparams_stride +
                  h_idx * h_k_qparams_stride + s_k_idx * s_k_qparams_stride;
              float lhs_dequant = lhs_scales[lhs_scales_zp_idx] *
                  (lhs_qvals[lhs_idx] - lhs_zeros[lhs_scales_zp_idx]);

              float rhs_dequant = rhs_scales[rhs_scales_zp_idx] *
                  (rhs_qvals[rhs_idx] - rhs_zeros[rhs_scales_zp_idx]);

              res += lhs_dequant * rhs_dequant;
            }
            expected_output
                [b_idx * b_out_stride + s_q_idx * s_q_out_stride +
                 h_idx * h_out_stride + s_k_idx] = res;
          }
        }
      }
    }
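// In equation form, the reference loop computes, for batch b, head h, query
// token i, and key token j (per-token scales s and zero points z; the
// notation is mine, matching the code above):
//
//   out[b,h,i,j] = \sum_{c=0}^{d-1} s^Q_{b,h,i} (q_{b,h,i,c} - z^Q_{b,h,i})
//                                 \cdot s^K_{b,h,j} (k_{b,h,j,c} - z^K_{b,h,j})
//                = s^Q_{b,h,i} s^K_{b,h,j}
//                  \sum_{c=0}^{d-1} (q_{b,h,i,c} - z^Q_{b,h,i})
//                                   (k_{b,h,j,c} - z^K_{b,h,j})
//
// Because scales are per token, they factor out of the inner sum, which is
// presumably what lets the neondot kernel accumulate the integer products
// and apply both scales once per output element.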
    // Return test case
    return channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case(
        b,
        s_q,
        s_k,
        h,
        d,
        transposed,
        b_q_stride,
        h_q_stride,
        s_q_stride,
        b_k_stride,
        h_k_stride,
        s_k_stride,
        b_q_qparams_stride,
        h_q_qparams_stride,
        s_q_qparams_stride,
        b_k_qparams_stride,
        h_k_qparams_stride,
        s_k_qparams_stride,
        expected_output,
        lhs,
        lhs_qvals,
        lhs_scales,
        lhs_zeros,
        rhs,
        rhs_qvals,
        rhs_scales,
        rhs_zeros);
  }
};

} // namespace torchao

#endif // defined(__aarch64__) || defined(__ARM_NEON)
