@@ -230,6 +230,174 @@ struct channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case {
}
};
 
+struct fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case {
+  int b;
+  int s_attn;
+  int s_v;
+  int h;
+  int d;
+  size_t b_attn_stride;
+  size_t h_attn_stride;
+  size_t s_attn_stride;
+  size_t b_v_stride;
+  size_t h_v_stride;
+  size_t s_v_stride;
+  size_t b_v_qparams_stride;
+  size_t h_v_qparams_stride;
+  size_t s_v_qparams_stride;
+
+  std::vector<float> expected_output;
+
+  std::vector<float> attn_scores;
+
+  std::vector<float> v;
+  std::vector<int8_t> v_qvals;
+  std::vector<float> v_scales;
+  std::vector<int8_t> v_zeros;
+
+  fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case(
+      int b_,
+      int s_attn_,
+      int s_v_,
+      int h_,
+      int d_,
+      size_t b_attn_stride_,
+      size_t h_attn_stride_,
+      size_t s_attn_stride_,
+      size_t b_v_stride_,
+      size_t h_v_stride_,
+      size_t s_v_stride_,
+      size_t b_v_qparams_stride_,
+      size_t h_v_qparams_stride_,
+      size_t s_v_qparams_stride_,
+      std::vector<float> expected_output_,
+      std::vector<float> attn_scores_,
+      std::vector<float> v_,
+      std::vector<int8_t> v_qvals_,
+      std::vector<float> v_scales_,
+      std::vector<int8_t> v_zeros_)
+      : b(b_),
+        s_attn(s_attn_),
+        s_v(s_v_),
+        h(h_),
+        d(d_),
+        b_attn_stride(b_attn_stride_),
+        h_attn_stride(h_attn_stride_),
+        s_attn_stride(s_attn_stride_),
+        b_v_stride(b_v_stride_),
+        h_v_stride(h_v_stride_),
+        s_v_stride(s_v_stride_),
+        b_v_qparams_stride(b_v_qparams_stride_),
+        h_v_qparams_stride(h_v_qparams_stride_),
+        s_v_qparams_stride(s_v_qparams_stride_),
+        expected_output(expected_output_),
+        attn_scores(attn_scores_),
+        v(v_),
+        v_qvals(v_qvals_),
+        v_scales(v_scales_),
+        v_zeros(v_zeros_) {
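+    // Shape conventions used throughout this test case (matching the
+    // size checks below):
+    //   expected_output: [B, S_attn, H, D]
+    //   attn_scores:     [B, H, S_attn, S_v]
+    //   v / v_qvals:     [B, H, S_v, D]
+    //   v_scales/zeros:  one qparam pair per (b, h, s_v) row of v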
+    assert(expected_output.size() == b * s_attn * h * d);
+    assert(attn_scores.size() == b * h * s_attn * s_v);
+    assert(v.size() == b * h * s_v * d);
+    assert(v_qvals.size() == b * h * s_v * d);
+    assert(v_scales.size() == b * h * s_v);
+    assert(v_zeros.size() == b * h * s_v);
+  }
+
+  static fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case
+  generate(int b, int s_attn, int s_v, int h, int d, bool transposed_v = true) {
+    // Generate activations
+    auto lhs = get_random_vector(b * h * s_attn * s_v, -1.0, 1.0);
+
+    auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] =
+        torchao::test_utils::generate_per_token_quantized_tensor(
+            b * h * s_v, d);
+    // The above function produces an nxk matrix; to produce kxn you need
+    // transposed = true. We do !rhs_is_transposed because when
+    // rhs_is_transposed = true the shape should be nxk instead of kxn.
+
+    size_t b_attn_stride = h * s_attn * s_v;
+    size_t h_attn_stride = s_attn * s_v;
+    size_t s_attn_stride = s_v;
+
+    size_t b_v_stride = h * s_v * d;
+    size_t h_v_stride = s_v * d;
+    size_t s_v_stride = d;
+
+    size_t b_v_qparams_stride = h * s_v;
+    size_t h_v_qparams_stride = s_v;
+    size_t s_v_qparams_stride = 1;
+
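+    // Illustration (hypothetical sizes, not part of the test): with h = 2,
+    // s_v = 3, d = 4 the defaults above address v as a contiguous
+    // [B, H, S_v, D] tensor (h_v_stride = 3 * 4 = 12, s_v_stride = 4).
+    // The !transposed_v branch below re-addresses the buffer as
+    // [B, S_v, H, D] (s_v_stride = 2 * 4 = 8, h_v_stride = 4), and the
+    // per-row qparams strides swap accordingly.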
+    if (!transposed_v) {
+      h_v_stride = d;
+      s_v_stride = h * d;
+
+      s_v_qparams_stride = h;
+      h_v_qparams_stride = 1;
+    }
+
+    // Compute expected output
+    // Note that while the inputs can be in shape b x h x s_attn x s_v
+    // and b x h x s_v x d, the output is not in b x h x s_attn x d
+    // but rather b x s_attn x h x d. This is because the output of
+    // SDPA would normally be in b x h x s_attn x d, but we want to
+    // avoid any transposes; thus we aim to output in b x s_attn x h x d.
+    // This is just for testing purposes. The kernel can actually write
+    // the output in [B, H, S, D] if needed.
+    std::vector<float> expected_output(b * s_attn * h * d);
+    size_t b_out_stride = s_attn * h * d;
+    size_t s_attn_out_stride = h * d;
+    size_t h_out_stride = d;
+
+    for (int b_idx = 0; b_idx < b; b_idx++) {
+      for (int s_attn_idx = 0; s_attn_idx < s_attn; s_attn_idx++) {
+        for (int h_idx = 0; h_idx < h; h_idx++) {
+          for (int d_idx = 0; d_idx < d; d_idx++) {
+            float res = 0.0;
+            for (int s_v_idx = 0; s_v_idx < s_v; s_v_idx++) {
+              int lhs_idx = b_idx * b_attn_stride + s_attn_idx * s_attn_stride +
+                  h_idx * h_attn_stride + s_v_idx;
+              int rhs_idx = b_idx * b_v_stride + h_idx * h_v_stride + d_idx +
+                  s_v_idx * s_v_stride;
+              int rhs_scales_zp_idx = b_idx * b_v_qparams_stride +
+                  h_idx * h_v_qparams_stride + s_v_idx * s_v_qparams_stride;
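+              // Channelwise dequantization: each (b, h, s_v) row of v has
+              // one (scale, zero) pair, so dequant(q) = scale * (q - zero).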
+              float rhs_dequant = rhs_scales[rhs_scales_zp_idx] *
+                  (rhs_qvals[rhs_idx] - rhs_zeros[rhs_scales_zp_idx]);
+
+              res += lhs[lhs_idx] * rhs_dequant;
+            }
+            expected_output
+                [b_idx * b_out_stride + s_attn_idx * s_attn_out_stride +
+                 h_idx * h_out_stride + d_idx] = res;
+          }
+        }
+      }
+    }
+
+    // Return test case
+    return fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case(
+        b,
+        s_attn,
+        s_v,
+        h,
+        d,
+        b_attn_stride,
+        h_attn_stride,
+        s_attn_stride,
+        b_v_stride,
+        h_v_stride,
+        s_v_stride,
+        b_v_qparams_stride,
+        h_v_qparams_stride,
+        s_v_qparams_stride,
+        expected_output,
+        lhs,
+        rhs,
+        rhs_qvals,
+        rhs_scales,
+        rhs_zeros);
+  }
+};
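+
+// Usage sketch (illustrative; the kernel call is an assumed placeholder,
+// not an API defined in this file):
+//
+//   auto test_case =
+//       fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case::generate(
+//           /*b=*/1, /*s_attn=*/16, /*s_v=*/16, /*h=*/8, /*d=*/64);
+//   // Run the attn_scores @ V kernel under test on test_case.attn_scores
+//   // and test_case.v_qvals / v_scales / v_zeros using the strides stored
+//   // on the test case, then compare each element of the result against
+//   // test_case.expected_output within a small tolerance.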
} // namespace torchao
#endif // defined(__aarch64__) || defined(__ARM_NEON)