
Commit 6987576

Add quantized attn_scores @ v test for intended use in quantized attention
Differential Revision: D71370603
Pull Request resolved: #2008
1 parent: 620356d
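
Before the file diffs, here is a minimal scalar sketch of the computation the new test exercises, distilled from the expected-output loop added in test_utils_quantized_attention.h below. The function name and flat 2D shapes are illustrative only, not part of the kernel API:

#include <cstdint>
#include <vector>

// Illustrative scalar reference (not the NEON kernel): one head's
// s_attn x d output tile, computed as fp32 attn_scores @ dequantized
// int8 V, where V carries one (scale, zero_point) pair per s_v row.
std::vector<float> attn_scores_at_v_reference(
    const std::vector<float>& attn,      // s_attn x s_v, row-major
    const std::vector<int8_t>& v_qvals,  // s_v x d, row-major
    const std::vector<float>& v_scales,  // one scale per s_v row
    const std::vector<int8_t>& v_zeros,  // one zero point per s_v row
    int s_attn, int s_v, int d) {
  std::vector<float> out(s_attn * d, 0.0f);
  for (int m = 0; m < s_attn; m++) {
    for (int k = 0; k < s_v; k++) {
      float a = attn[m * s_v + k];
      for (int n = 0; n < d; n++) {
        // Dequantize V on the fly: scale * (qval - zero_point).
        out[m * d + n] += a * v_scales[k] * (v_qvals[k * d + n] - v_zeros[k]);
      }
    }
  }
  return out;
}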

File tree

3 files changed: +261 -0 lines changed

torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h

Lines changed: 6 additions & 0 deletions

@@ -101,6 +101,12 @@ struct KernelImpl {
       const int rhs_qparams_stride);
 };
 
+/*
+Documenting param meaning:
+rhs_stride_n: Since rhs transposed == false, the expected shape of rhs is
+k x n. Thus rhs_stride_n is the stride of the k dim, i.e., how many bytes
+apart elements in the k dim are.
+*/
 template <>
 struct KernelImpl<true, false, false> {
   static void run(
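
To make the new comment concrete, here is a hypothetical indexing helper (not part of the kernel) for a non-transposed k x n RHS; since the RHS elements are int8 here, bytes and elements coincide:

#include <cstddef>
#include <cstdint>

// Hypothetical helper: element (k_idx, n_idx) of a k x n RHS whose
// consecutive rows along the k dim are rhs_stride_n elements apart.
// For a densely packed row-major k x n matrix, rhs_stride_n == n.
inline int8_t rhs_at(
    const int8_t* rhs, size_t rhs_stride_n, int k_idx, int n_idx) {
  return rhs[k_idx * rhs_stride_n + n_idx];
}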

torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp

Lines changed: 87 additions & 0 deletions

@@ -509,4 +509,91 @@ TEST(
   test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 3, false);
 }
 
+static void test_fp32_attn_scores_at_v_matmul_attention(
+    int b,
+    int s_attn,
+    int s_v,
+    int h,
+    int d,
+    bool transpose_v = true) {
+  auto test_case =
+      torchao::fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case::generate(
+          b, s_attn, s_v, h, d, transpose_v);
+
+  using namespace torchao::kernels::cpu::aarch64::quantized_matmul::
+      fp32_a_input_channelwise_8bit_b_1x16x4_f32;
+
+  size_t attn_b_stride = test_case.b_attn_stride;
+  size_t attn_h_stride = test_case.h_attn_stride;
+  size_t attn_s_q_stride = test_case.s_attn_stride;
+
+  size_t v_b_stride = test_case.b_v_stride;
+  size_t v_h_stride = test_case.h_v_stride;
+  size_t v_s_v_stride = test_case.s_v_stride;
+  size_t v_scale_zp_b_stride = test_case.b_v_qparams_stride;
+  size_t v_scale_zp_h_stride = test_case.h_v_qparams_stride;
+  size_t v_scale_zp_s_stride = test_case.s_v_qparams_stride;
+
+  std::vector<float> output(b * s_attn * h * d);
+  size_t output_b_stride = s_attn * h * d;
+  size_t output_s_attn_stride = h * d;
+  size_t output_h_stride = d;
+
+  for (int b_idx = 0; b_idx < b; b_idx++) {
+    for (int h_idx = 0; h_idx < h; h_idx++) {
+      kernel<true, false, false>(
+          s_attn,
+          d,
+          s_v,
+          test_case.attn_scores.data() + b_idx * attn_b_stride +
+              h_idx * attn_h_stride,
+          attn_s_q_stride /*lhs_stride_m*/,
+          test_case.v_qvals.data() + b_idx * v_b_stride + h_idx * v_h_stride,
+          v_s_v_stride /*rhs_stride_n*/,
+          output.data() + b_idx * output_b_stride + h_idx * output_h_stride,
+          output_s_attn_stride /*out_stride_n*/,
+          test_case.v_zeros.data() + b_idx * v_scale_zp_b_stride +
+              h_idx * v_scale_zp_h_stride,
+          test_case.v_scales.data() + b_idx * v_scale_zp_b_stride +
+              h_idx * v_scale_zp_h_stride,
+          0.0 /*beta*/,
+          v_scale_zp_s_stride /*rhs qparams stride*/);
+    }
+  }
+
+  for (int i = 0; i < b * s_attn * h * d; i++) {
+    EXPECT_NEAR(output[i], test_case.expected_output[i], kTol);
+  }
+}
+
+TEST(test_fp32_attn_scores_at_v_matmul_attention, Basic) {
+  test_fp32_attn_scores_at_v_matmul_attention(1, 16, 16, 8, 16);
+}
+
+TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeHeadsAndHeadDim) {
+  test_fp32_attn_scores_at_v_matmul_attention(1, 8, 8, 7, 33);
+}
+
+TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeSequenceDim) {
+  test_fp32_attn_scores_at_v_matmul_attention(1, 7, 9, 7, 33);
+}
+
+TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeHeadsAndSmallHeadDim) {
+  test_fp32_attn_scores_at_v_matmul_attention(1, 8, 8, 7, 17);
+}
+
+TEST(test_fp32_attn_scores_at_v_matmul_attention, BasicNoTranspose) {
+  test_fp32_attn_scores_at_v_matmul_attention(1, 16, 16, 8, 16, false);
+}
+
+TEST(
+    test_fp32_attn_scores_at_v_matmul_attention,
+    PrimeHeadsAndSmallHeadDimNoTranspose) {
+  test_fp32_attn_scores_at_v_matmul_attention(1, 8, 8, 7, 17, false);
+}
+
+TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeSequenceDimNoTranspose) {
+  test_fp32_attn_scores_at_v_matmul_attention(1, 7, 9, 7, 33, false);
+}
+
 #endif // defined(__aarch64__) || defined(__ARM_NEON)
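
One note on the layout these tests write into: the output buffer is b x s_attn x h x d, so consecutive rows of one head's s_attn x d tile sit h * d floats apart, which is exactly the output_s_attn_stride the test passes as out_stride_n. A hypothetical index helper mirroring that stride arithmetic:

#include <cstddef>

// Hypothetical helper mirroring the test's output strides for layout
// [b, s_attn, h, d]: b stride = s_attn * h * d, s stride = h * d,
// h stride = d.
inline size_t out_index(
    int b_idx, int s_idx, int h_idx, int d_idx, int s_attn, int h, int d) {
  return ((static_cast<size_t>(b_idx) * s_attn + s_idx) * h + h_idx) * d +
      d_idx;
}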

torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h

Lines changed: 168 additions & 0 deletions

@@ -230,6 +230,174 @@ struct channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case {
   }
 };
 
+struct fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case {
+  int b;
+  int s_attn;
+  int s_v;
+  int h;
+  int d;
+  size_t b_attn_stride;
+  size_t h_attn_stride;
+  size_t s_attn_stride;
+  size_t b_v_stride;
+  size_t h_v_stride;
+  size_t s_v_stride;
+  size_t b_v_qparams_stride;
+  size_t h_v_qparams_stride;
+  size_t s_v_qparams_stride;
+
+  std::vector<float> expected_output;
+
+  std::vector<float> attn_scores;
+
+  std::vector<float> v;
+  std::vector<int8_t> v_qvals;
+  std::vector<float> v_scales;
+  std::vector<int8_t> v_zeros;
+
+  fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case(
+      int b_,
+      int s_attn_,
+      int s_v_,
+      int h_,
+      int d_,
+      size_t b_attn_stride_,
+      size_t h_attn_stride_,
+      size_t s_attn_stride_,
+      size_t b_v_stride_,
+      size_t h_v_stride_,
+      size_t s_v_stride_,
+      size_t b_v_qparams_stride_,
+      size_t h_v_qparams_stride_,
+      size_t s_v_qparams_stride_,
+      std::vector<float> expected_output_,
+      std::vector<float> attn_scores_,
+      std::vector<float> v_,
+      std::vector<int8_t> v_qvals_,
+      std::vector<float> v_scales_,
+      std::vector<int8_t> v_zeros_)
+      : b(b_),
+        s_attn(s_attn_),
+        s_v(s_v_),
+        h(h_),
+        d(d_),
+        b_attn_stride(b_attn_stride_),
+        h_attn_stride(h_attn_stride_),
+        s_attn_stride(s_attn_stride_),
+        b_v_stride(b_v_stride_),
+        h_v_stride(h_v_stride_),
+        s_v_stride(s_v_stride_),
+        b_v_qparams_stride(b_v_qparams_stride_),
+        h_v_qparams_stride(h_v_qparams_stride_),
+        s_v_qparams_stride(s_v_qparams_stride_),
+        expected_output(expected_output_),
+        attn_scores(attn_scores_),
+        v(v_),
+        v_qvals(v_qvals_),
+        v_scales(v_scales_),
+        v_zeros(v_zeros_) {
+    assert(expected_output.size() == b * s_attn * h * d);
+    assert(attn_scores.size() == b * h * s_attn * s_v);
+    assert(v.size() == b * h * s_v * d);
+    assert(v_qvals.size() == b * h * s_v * d);
+    assert(v_scales.size() == b * h * s_v);
+    assert(v_zeros.size() == b * h * s_v);
+  }
+
+  static fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case
+  generate(int b, int s_attn, int s_v, int h, int d, bool transposed_v = true) {
+    // Generate activations
+    auto lhs = get_random_vector(b * h * s_attn * s_v, -1.0, 1.0);
+
+    auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] =
+        torchao::test_utils::generate_per_token_quantized_tensor(
+            b * h * s_v, d);
+    // The function above produces an n x k matrix; to produce k x n you need
+    // transposed = true. We use !rhs_is_transposed because when
+    // rhs_is_transposed = true the shape should be n x k instead of k x n.
+
+    size_t b_attn_stride = h * s_attn * s_v;
+    size_t h_attn_stride = s_attn * s_v;
+    size_t s_attn_stride = s_v;
+
+    size_t b_v_stride = h * s_v * d;
+    size_t h_v_stride = s_v * d;
+    size_t s_v_stride = d;
+
+    size_t b_v_qparams_stride = h * s_v;
+    size_t h_v_qparams_stride = s_v;
+    size_t s_v_qparams_stride = 1;
+
+    if (!transposed_v) {
+      h_v_stride = d;
+      s_v_stride = h * d;
+
+      s_v_qparams_stride = h;
+      h_v_qparams_stride = 1;
+    }
+
+    // Compute expected output.
+    // Note that while the inputs are in shapes b x h x s_attn x s_v
+    // and b x h x s_v x d, the output is not in b x h x s_attn x d
+    // but rather b x s_attn x h x d. The output of SDPA would normally
+    // be in b x h x s_attn x d, but we want to avoid any transposes,
+    // so we aim to write the output in b x s_attn x h x d instead.
+    // This is just for testing purposes. The kernel can actually write
+    // output in [B, H, S, D] if needed.
+    std::vector<float> expected_output(b * s_attn * h * d);
+    size_t b_out_stride = s_attn * h * d;
+    size_t s_attn_out_stride = h * d;
+    size_t h_out_stride = d;
+
+    for (int b_idx = 0; b_idx < b; b_idx++) {
+      for (int s_attn_idx = 0; s_attn_idx < s_attn; s_attn_idx++) {
+        for (int h_idx = 0; h_idx < h; h_idx++) {
+          for (int d_idx = 0; d_idx < d; d_idx++) {
+            float res = 0.0;
+            for (int s_v_idx = 0; s_v_idx < s_v; s_v_idx++) {
+              int lhs_idx = b_idx * b_attn_stride +
+                  s_attn_idx * s_attn_stride + h_idx * h_attn_stride +
+                  s_v_idx;
+              int rhs_idx = b_idx * b_v_stride + h_idx * h_v_stride + d_idx +
+                  s_v_idx * s_v_stride;
+              int rhs_scales_zp_idx = b_idx * b_v_qparams_stride +
+                  h_idx * h_v_qparams_stride + s_v_idx * s_v_qparams_stride;
+              float rhs_dequant = rhs_scales[rhs_scales_zp_idx] *
+                  (rhs_qvals[rhs_idx] - rhs_zeros[rhs_scales_zp_idx]);
+
+              res += lhs[lhs_idx] * rhs_dequant;
+            }
+            expected_output
+                [b_idx * b_out_stride + s_attn_idx * s_attn_out_stride +
+                 h_idx * h_out_stride + d_idx] = res;
+          }
+        }
+      }
+    }
+
+    // Return test case
+    return fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case(
+        b,
+        s_attn,
+        s_v,
+        h,
+        d,
+        b_attn_stride,
+        h_attn_stride,
+        s_attn_stride,
+        b_v_stride,
+        h_v_stride,
+        s_v_stride,
+        b_v_qparams_stride,
+        h_v_qparams_stride,
+        s_v_qparams_stride,
+        expected_output,
+        lhs,
+        rhs,
+        rhs_qvals,
+        rhs_scales,
+        rhs_zeros);
+  }
+};
 } // namespace torchao
 
 #endif // defined(__aarch64__) || defined(__ARM_NEON)
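
To summarize the two V layouts generate() supports, here is a minimal sketch (a hypothetical helper condensing the stride assignments above) of the stride selection; with transposed_v, V is stored [B, H, S, D], otherwise [B, S, H, D]:

#include <cstddef>

// Hypothetical condensation of generate()'s V stride logic. The qparams
// strides swap with the layout because scales/zeros are per s_v row.
struct VStrides {
  size_t h_v_stride;
  size_t s_v_stride;
  size_t h_qparams_stride;
  size_t s_qparams_stride;
};

inline VStrides v_strides(int h, int s_v, int d, bool transposed_v) {
  if (transposed_v) {
    // [B, H, S, D]: each head owns a contiguous s_v x d tile.
    return {static_cast<size_t>(s_v) * d, static_cast<size_t>(d),
            static_cast<size_t>(s_v), 1};
  }
  // [B, S, H, D]: heads interleave within each sequence position.
  return {static_cast<size_t>(d), static_cast<size_t>(h) * d, 1,
          static_cast<size_t>(h)};
}

For example, v_strides(8, 16, 16, true) returns {256, 16, 16, 1}, matching the strides produced for the Basic test shapes.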
