|
18 | 18 |
|
19 | 19 | #include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
|
20 | 20 | #include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
|
| 21 | +#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h" |
21 | 22 | #include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h"
|
22 | 23 |
|
23 |
| -namespace cutlass |
24 |
| -{ |
25 |
| -namespace gemm |
26 |
| -{ |
27 |
| -namespace threadblock |
28 |
| -{ |
| 24 | +namespace cutlass { |
| 25 | +namespace gemm { |
| 26 | +namespace threadblock { |
29 | 27 |
|
30 | 28 | ////////////////////////////////////////////////////////////////////////////////
|
31 | 29 |
|
@@ -378,38 +376,23 @@ template <
|
378 | 376 | struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
379 | 377 | layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
|
380 | 378 | {
|
381 |
| - static cutlass::arch::CacheOperation::Kind const CacheOpA = |
382 |
| - ((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global |
383 |
| - : cutlass::arch::CacheOperation::Always; |
384 |
| - |
385 |
| - static cutlass::arch::CacheOperation::Kind const CacheOpB = |
386 |
| - ((sizeof_bits<uint2b_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global |
387 |
| - : cutlass::arch::CacheOperation::Always; |
| 379 | +private: |
| 380 | + using Mma = DefaultWint2xMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, |
| 381 | + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, |
| 382 | + WarpShape, InstructionShape, 2, Operator>; |
388 | 383 |
|
| 384 | +public: |
389 | 385 | // Define the MmaCore components
|
390 |
| - using MmaCore = |
391 |
| - typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t, |
392 |
| - LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator, |
393 |
| - false, CacheOpA, CacheOpB>; |
| 386 | + using MmaCore = typename Mma::MmaCore; |
394 | 387 |
|
395 | 388 | // Define iterators over tiles from the A operand
|
396 |
| - using ThreadMapA = typename MmaCore::IteratorThreadMapA; |
397 |
| - using AccessTypeA = cutlass::Array<half_t, kAlignmentA>; |
398 |
| - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< |
399 |
| - cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA, |
400 |
| - AccessTypeA>; |
| 389 | + using IteratorA = typename Mma::IteratorA; |
401 | 390 |
|
402 | 391 | // Define iterators over tiles from the B operand
|
403 |
| - using ThreadMapB = typename MmaCore::IteratorThreadMapB; |
404 |
| - using AccessTypeB = cutlass::Array<uint2b_t, kAlignmentB>; |
405 |
| - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< |
406 |
| - cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, uint2b_t, LayoutB, 0, ThreadMapB, |
407 |
| - AccessTypeB>; |
| 392 | + using IteratorB = typename Mma::IteratorB; |
408 | 393 |
|
409 | 394 | // Define the threadblock-scoped multistage matrix multiply
|
410 |
| - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA, |
411 |
| - typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, |
412 |
| - MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>; |
| 395 | + using ThreadblockMma = typename Mma::ThreadblockMma; |
413 | 396 | };
|
414 | 397 |
|
415 | 398 | template <
|
@@ -441,38 +424,23 @@ struct DefaultMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
|
441 | 424 | layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
|
442 | 425 | false, SharedMemoryClear>
|
443 | 426 | {
|
444 |
| - static cutlass::arch::CacheOperation::Kind const CacheOpA = |
445 |
| - ((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global |
446 |
| - : cutlass::arch::CacheOperation::Always; |
447 |
| - |
448 |
| - static cutlass::arch::CacheOperation::Kind const CacheOpB = |
449 |
| - ((sizeof_bits<uint2b_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global |
450 |
| - : cutlass::arch::CacheOperation::Always; |
| 427 | +private: |
| 428 | + using Mma = DefaultWint2xMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, |
| 429 | + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, |
| 430 | + WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>; |
451 | 431 |
|
| 432 | +public: |
452 | 433 | // Define the MmaCore components
|
453 |
| - using MmaCore = |
454 |
| - typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t, |
455 |
| - LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator, |
456 |
| - false, CacheOpA, CacheOpB>; |
| 434 | + using MmaCore = typename Mma::MmaCore; |
457 | 435 |
|
458 | 436 | // Define iterators over tiles from the A operand
|
459 |
| - using ThreadMapA = typename MmaCore::IteratorThreadMapA; |
460 |
| - using AccessTypeA = cutlass::Array<half_t, kAlignmentA>; |
461 |
| - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< |
462 |
| - cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA, |
463 |
| - AccessTypeA>; |
| 437 | + using IteratorA = typename Mma::IteratorA; |
464 | 438 |
|
465 | 439 | // Define iterators over tiles from the B operand
|
466 |
| - using ThreadMapB = typename MmaCore::IteratorThreadMapB; |
467 |
| - using AccessTypeB = cutlass::Array<uint2b_t, kAlignmentB>; |
468 |
| - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< |
469 |
| - cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, uint2b_t, LayoutB, 0, ThreadMapB, |
470 |
| - AccessTypeB>; |
| 440 | + using IteratorB = typename Mma::IteratorB; |
471 | 441 |
|
472 | 442 | // Define the threadblock-scoped multistage matrix multiply
|
473 |
| - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA, |
474 |
| - typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, |
475 |
| - MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>; |
| 443 | + using ThreadblockMma = typename Mma::ThreadblockMma; |
476 | 444 | };
|
477 | 445 |
|
478 | 446 | } // namespace threadblock
|
|
0 commit comments