Skip to content

Commit 2efcfbb

Browse files
committed
Unify default_wint2x_mma.
Change-Id: I9e77b0e8e6cecab01fedc0b24b536ee0a1a89ff7
1 parent fbd86c8 commit 2efcfbb

File tree

3 files changed

+186
-103
lines changed

3 files changed

+186
-103
lines changed

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h

Lines changed: 22 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,12 @@
1818

1919
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
2020
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
21+
#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h"
2122
#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h"
2223

23-
namespace cutlass
24-
{
25-
namespace gemm
26-
{
27-
namespace threadblock
28-
{
24+
namespace cutlass {
25+
namespace gemm {
26+
namespace threadblock {
2927

3028
////////////////////////////////////////////////////////////////////////////////
3129

@@ -378,38 +376,23 @@ template <
378376
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
379377
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
380378
{
381-
static cutlass::arch::CacheOperation::Kind const CacheOpA =
382-
((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
383-
: cutlass::arch::CacheOperation::Always;
384-
385-
static cutlass::arch::CacheOperation::Kind const CacheOpB =
386-
((sizeof_bits<uint2b_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
387-
: cutlass::arch::CacheOperation::Always;
379+
private:
380+
using Mma = DefaultWint2xMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
381+
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
382+
WarpShape, InstructionShape, 2, Operator>;
388383

384+
public:
389385
// Define the MmaCore components
390-
using MmaCore =
391-
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
392-
LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
393-
false, CacheOpA, CacheOpB>;
386+
using MmaCore = typename Mma::MmaCore;
394387

395388
// Define iterators over tiles from the A operand
396-
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
397-
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
398-
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
399-
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
400-
AccessTypeA>;
389+
using IteratorA = typename Mma::IteratorA;
401390

402391
// Define iterators over tiles from the B operand
403-
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
404-
using AccessTypeB = cutlass::Array<uint2b_t, kAlignmentB>;
405-
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
406-
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, uint2b_t, LayoutB, 0, ThreadMapB,
407-
AccessTypeB>;
392+
using IteratorB = typename Mma::IteratorB;
408393

409394
// Define the threadblock-scoped multistage matrix multiply
410-
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
411-
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
412-
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
395+
using ThreadblockMma = typename Mma::ThreadblockMma;
413396
};
414397

415398
template <
@@ -441,38 +424,23 @@ struct DefaultMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
441424
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
442425
false, SharedMemoryClear>
443426
{
444-
static cutlass::arch::CacheOperation::Kind const CacheOpA =
445-
((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
446-
: cutlass::arch::CacheOperation::Always;
447-
448-
static cutlass::arch::CacheOperation::Kind const CacheOpB =
449-
((sizeof_bits<uint2b_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
450-
: cutlass::arch::CacheOperation::Always;
427+
private:
428+
using Mma = DefaultWint2xMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
429+
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
430+
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
451431

432+
public:
452433
// Define the MmaCore components
453-
using MmaCore =
454-
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
455-
LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
456-
false, CacheOpA, CacheOpB>;
434+
using MmaCore = typename Mma::MmaCore;
457435

458436
// Define iterators over tiles from the A operand
459-
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
460-
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
461-
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
462-
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
463-
AccessTypeA>;
437+
using IteratorA = typename Mma::IteratorA;
464438

465439
// Define iterators over tiles from the B operand
466-
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
467-
using AccessTypeB = cutlass::Array<uint2b_t, kAlignmentB>;
468-
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
469-
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, uint2b_t, LayoutB, 0, ThreadMapB,
470-
AccessTypeB>;
440+
using IteratorB = typename Mma::IteratorB;
471441

472442
// Define the threadblock-scoped multistage matrix multiply
473-
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
474-
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
475-
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
443+
using ThreadblockMma = typename Mma::ThreadblockMma;
476444
};
477445

478446
} // namespace threadblock

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h

Lines changed: 19 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#include "cutlass/gemm/threadblock/default_mma.h"
2020
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
2121
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
22-
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"
22+
#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h"
2323

2424
namespace cutlass {
2525
namespace gemm {
@@ -379,38 +379,23 @@ template <
379379
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
380380
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
381381
{
382-
static cutlass::arch::CacheOperation::Kind const CacheOpA =
383-
((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
384-
: cutlass::arch::CacheOperation::Always;
385-
386-
static cutlass::arch::CacheOperation::Kind const CacheOpB =
387-
((sizeof_bits<uint2b_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
388-
: cutlass::arch::CacheOperation::Always;
382+
private:
383+
using Mma = DefaultWint2xMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
384+
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
385+
WarpShape, InstructionShape, 2, Operator>;
389386

387+
public:
390388
// Define the MmaCore components
391-
using MmaCore =
392-
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
393-
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
394-
false, CacheOpA, CacheOpB>;
389+
using MmaCore = typename Mma::MmaCore;
395390

396391
// Define iterators over tiles from the A operand
397-
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
398-
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
399-
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
400-
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
401-
AccessTypeA>;
392+
using IteratorA = typename Mma::IteratorA;
402393

403394
// Define iterators over tiles from the B operand
404-
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
405-
using AccessTypeB = cutlass::Array<uint2b_t, kAlignmentB>;
406-
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
407-
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, uint2b_t, LayoutB, 0, ThreadMapB,
408-
AccessTypeB>;
395+
using IteratorB = typename Mma::IteratorB;
409396

410397
// Define the threadblock-scoped multistage matrix multiply
411-
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
412-
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
413-
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
398+
using ThreadblockMma = typename Mma::ThreadblockMma;
414399
};
415400

416401
template <
@@ -442,38 +427,23 @@ struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmen
442427
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
443428
false, SharedMemoryClear>
444429
{
445-
static cutlass::arch::CacheOperation::Kind const CacheOpA =
446-
((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
447-
: cutlass::arch::CacheOperation::Always;
448-
449-
static cutlass::arch::CacheOperation::Kind const CacheOpB =
450-
((sizeof_bits<uint2b_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
451-
: cutlass::arch::CacheOperation::Always;
430+
private:
431+
using Mma = DefaultWint2xMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
432+
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
433+
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
452434

435+
public:
453436
// Define the MmaCore components
454-
using MmaCore =
455-
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
456-
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
457-
false, CacheOpA, CacheOpB>;
437+
using MmaCore = typename Mma::MmaCore;
458438

459439
// Define iterators over tiles from the A operand
460-
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
461-
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
462-
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
463-
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
464-
AccessTypeA>;
440+
using IteratorA = typename Mma::IteratorA;
465441

466442
// Define iterators over tiles from the B operand
467-
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
468-
using AccessTypeB = cutlass::Array<uint2b_t, kAlignmentB>;
469-
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
470-
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, uint2b_t, LayoutB, 0, ThreadMapB,
471-
AccessTypeB>;
443+
using IteratorB = typename Mma::IteratorB;
472444

473445
// Define the threadblock-scoped multistage matrix multiply
474-
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
475-
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
476-
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
446+
using ThreadblockMma = typename Mma::ThreadblockMma;
477447
};
478448

479449
} // namespace threadblock
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
#pragma once
19+
20+
#include "cutlass_extensions/arch/mma.h"
21+
#include "cutlass_extensions/interleaved_numeric_conversion.h"
22+
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"
23+
24+
namespace cutlass {
25+
namespace gemm {
26+
namespace threadblock {
27+
28+
////////////////////////////////////////////////////////////////////////////////
29+
30+
template <
31+
/// Element type for A matrix operand
32+
typename ElementA_,
33+
/// Layout type for A matrix operand
34+
typename LayoutA_,
35+
/// Access granularity of A matrix in units of elements
36+
int kAlignmentA,
37+
/// Element type for B matrix operand
38+
typename ElementB_,
39+
/// Layout type for B matrix operand
40+
typename LayoutB_,
41+
/// Access granularity of B matrix in units of elements
42+
int kAlignmentB,
43+
/// Element type for internal accumulation
44+
typename ElementAccumulator_,
45+
/// Layout type for C and D matrix operands
46+
typename LayoutC_,
47+
/// Operator class tag
48+
typename OperatorClass_,
49+
/// Tag indicating architecture to tune for
50+
typename ArchTag_,
51+
/// Threadblock-level tile size (concept: GemmShape)
52+
typename ThreadblockShape_,
53+
/// Warp-level tile size (concept: GemmShape)
54+
typename WarpShape_,
55+
/// Instruction-level tile size (concept: GemmShape)
56+
typename InstructionShape_,
57+
/// Number of stages used in the pipelined mainloop
58+
int Stages,
59+
/// Operation performed by GEMM
60+
typename Operator_,
61+
/// Use zfill or predicate for out-of-bound cp.async
62+
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone>
63+
struct DefaultWint2xMma;
64+
65+
////////////////////////////////////////////////////////////////////////////////
66+
67+
template <
68+
/// Type for element A
69+
typename ElementA,
70+
/// Layout type for A matrix operand
71+
typename LayoutA,
72+
/// Access granularity of A matrix in units of elements
73+
int kAlignmentA,
74+
/// Type for element B
75+
typename ElementB,
76+
/// Layout type for B matrix operand
77+
typename LayoutB,
78+
/// Access granularity of B matrix in units of elements
79+
int kAlignmentB,
80+
/// Element type for internal accumulation
81+
typename ElementAccumulator,
82+
/// Operator class tag
83+
typename OperatorClass,
84+
/// Tag indicating architecture to tune for
85+
typename ArchTag,
86+
/// Threadblock-level tile size (concept: GemmShape)
87+
typename ThreadblockShape,
88+
/// Warp-level tile size (concept: GemmShape)
89+
typename WarpShape,
90+
/// Instruction-level tile size (concept: GemmShape)
91+
typename InstructionShape,
92+
/// Stages in GEMM
93+
int kStages,
94+
/// Operator performed by GEMM
95+
typename Operator,
96+
/// Use zfill or predicate for out-of-bound cp.async
97+
SharedMemoryClearOption SharedMemoryClear>
98+
struct DefaultWint2xMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator,
99+
layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
100+
kStages, Operator, SharedMemoryClear>
101+
{
102+
103+
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
104+
"Element A must be fp16 or bf16");
105+
106+
static_assert(platform::is_same<ElementB, uint2b_t>::value,
107+
"Element B must be uint2b_t");
108+
109+
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
110+
? cutlass::arch::CacheOperation::Global
111+
: cutlass::arch::CacheOperation::Always;
112+
113+
static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
114+
? cutlass::arch::CacheOperation::Global
115+
: cutlass::arch::CacheOperation::Always;
116+
117+
// Define the MmaCore components
118+
// Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created
119+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
120+
ElementA, LayoutA, ElementA, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, std::max(kStages, 3),
121+
Operator, false, CacheOpA, CacheOpB>;
122+
123+
// Define iterators over tiles from the A operand
124+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
125+
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
126+
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
127+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
128+
AccessTypeA>;
129+
130+
// Define iterators over tiles from the B operand
131+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
132+
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
133+
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
134+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, LayoutB, 0, ThreadMapB,
135+
AccessTypeB>;
136+
137+
// Define the threadblock-scoped multistage matrix multiply
138+
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
139+
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
140+
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
141+
};
142+
143+
} // namespace threadblock
144+
} // namespace gemm
145+
} // namespace cutlass

0 commit comments

Comments
 (0)