Skip to content

Commit feac956

Browse files
committed
Change wint2 to ColumnMajorTileInterleave.
Change-Id: I593cbe36f991c0c5044989d65f0014087587c624
1 parent 2efcfbb commit feac956

File tree

7 files changed

+72
-51
lines changed

7 files changed

+72
-51
lines changed

custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,18 @@ template <typename TypeA, typename Arch>
133133
template <typename TypeA, typename Arch>
134134
struct LayoutDetailsB<TypeA, uint2b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
135135
{
136-
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
137-
using Layout = layout::ColumnMajor;
138-
static constexpr int ElementsPerAccess = 8; // at least 4-bytes
139-
using Operator = cutlass::arch::OpMultiplyAdd;
136+
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value; // 64
137+
138+
private:
139+
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint2b_t>::value;
140+
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; // 8
141+
142+
public:
143+
// using Layout = layout::ColumnMajor;
144+
// static constexpr int ElementsPerAccess = 16; // at least 4-bytes
145+
using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
146+
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint2b_t>::value; // 64
147+
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
140148
};
141149

142150
template <typename TypeA, typename Arch>

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ struct DefaultWint2xMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlig
106106
static_assert(platform::is_same<ElementB, uint2b_t>::value,
107107
"Element B must be uint2b_t");
108108

109+
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
110+
"Mma multistage must dequantize after ldsm");
111+
109112
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
110113
? cutlass::arch::CacheOperation::Global
111114
: cutlass::arch::CacheOperation::Always;
@@ -117,8 +120,8 @@ struct DefaultWint2xMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlig
117120
// Define the MmaCore components
118121
// Mma core does not depend on stages, so pass in at least 3 here so that mma multistage pieces are created
119122
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
120-
ElementA, LayoutA, ElementA, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, std::max(kStages, 3),
121-
Operator, false, CacheOpA, CacheOpB>;
123+
ElementA, LayoutA, ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
124+
std::max(kStages, 3), Operator, false, CacheOpA, CacheOpB>;
122125

123126
// Define iterators over tiles from the A operand
124127
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
@@ -127,17 +130,39 @@ struct DefaultWint2xMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlig
127130
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
128131
AccessTypeA>;
129132

130-
// Define iterators over tiles from the B operand
133+
private:
134+
static constexpr int kColumnsInterleaved = LayoutB::kColumnsInterleaved;
135+
static constexpr int kRowsPerTile = LayoutB::kRowsPerTile;
136+
static_assert(!(MmaCore::Shape::kN % kColumnsInterleaved), "ThreadblockShape must be divisible by kColumnsInterleaved");
137+
static_assert(kRowsPerTile == MmaCore::Shape::kK, "");
138+
131139
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
140+
using WarpArrangement = typename ThreadMapB::Detail::WarpThreadArrangement;
141+
static_assert(!(WarpArrangement::kStrided % kColumnsInterleaved), "");
142+
143+
using IteratorShapeB = MatrixShape<
144+
MmaCore::Shape::kK * kColumnsInterleaved, MmaCore::Shape::kN / kColumnsInterleaved>;
145+
using InterleavedThreadMapB = transform::PitchLinearWarpRakedThreadMap<
146+
layout::PitchLinearShape<IteratorShapeB::kRow, IteratorShapeB::kColumn>,
147+
ThreadMapB::kThreads,
148+
layout::PitchLinearShape<WarpArrangement::kContiguous * kColumnsInterleaved,
149+
WarpArrangement::kStrided / kColumnsInterleaved>,
150+
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
151+
152+
public:
153+
// Define iterators over tiles from the B operand
132154
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
133155
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
134-
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, LayoutB, 0, ThreadMapB,
156+
IteratorShapeB, ElementB, layout::ColumnMajor, 0, InterleavedThreadMapB,
135157
AccessTypeB>;
136158

137159
// Define the threadblock-scoped multistage matrix multiply
138-
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
139-
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
140-
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
160+
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<
161+
typename MmaCore::Shape,
162+
IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA,
163+
IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB,
164+
ElementAccumulator, layout::RowMajor,
165+
typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
141166
};
142167

143168
} // namespace threadblock

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,9 @@ class Wint2xMmaMultistage :
266266
if (smem_read_stage_idx_ == Base::kStages) {
267267
// Wrap back around to the 'start' of the circular buffer in shared memory
268268
this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
269-
// this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
269+
this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
270270
smem_read_stage_idx_ = 0;
271271
}
272-
this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
273272
}
274273

275274
/// Advance global memory read-iterators and shared memory write-iterators to the stage
@@ -566,16 +565,6 @@ class Wint2xMmaMultistage :
566565
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]);
567566
++this->warp_tile_iterator_A_;
568567

569-
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
570-
// Unpack and dequant the first stage of B.
571-
int unpack_stage = stage - Base::kStages + 2;
572-
//tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
573-
// column_wise_smem_ptr_B_, unpack_stage);
574-
575-
// Copy dequantized data to shared memory used by mma core.
576-
//copy_tiles_and_advance_per_stage_B<false, false>(iterator_B);
577-
}
578-
579568
// Load the next warp-tile's B fragment from shared memory
580569
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
581570
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
@@ -617,13 +606,10 @@ class Wint2xMmaMultistage :
617606
// global->shared fragment copies
618607
if (warp_mma_k < Base::kWarpGemmIterations - 1) {
619608
int group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
609+
int group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
620610

621611
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
622-
623-
if (warp_mma_k == 0) {
624-
tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
625-
column_wise_smem_ptr_B_, stage);
626-
}
612+
copy_tiles_and_advance_B<false>(iterator_B, group_start_iteration_B);
627613
}
628614

629615
// The second-to-last warp-tile also:
@@ -632,8 +618,10 @@ class Wint2xMmaMultistage :
632618
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
633619
// Performs the last warp-tile's share of global->shared fragment copies
634620
int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
621+
int group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
635622

636623
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
624+
copy_tiles_and_advance_B<false>(iterator_B, group_start_iteration_B);
637625

638626
// Inserts a memory fence between stages of cp.async instructions.
639627
cutlass::arch::cp_async_fence();
@@ -648,7 +636,7 @@ class Wint2xMmaMultistage :
648636
// Disable global fetching when done with global fetch iterations
649637
--gemm_k_iterations;
650638
iterator_A.clear_mask(gemm_k_iterations == 0);
651-
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
639+
iterator_B.clear_mask(gemm_k_iterations == 0);
652640
}
653641

654642
// The last warp-tile also converts the shared memory fragments used by
@@ -675,12 +663,8 @@ class Wint2xMmaMultistage :
675663
IteratorB &iterator_B,
676664
TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand in global memory
677665
{
678-
#if 0
679666
PipeState pipe_state;
680667

681-
// Unpack and dequant the first stage of B.
682-
//tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0);
683-
684668
// Disable global fetching if done with global fetch iterations
685669
iterator_A.clear_mask(gemm_k_iterations == 0);
686670
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
@@ -690,9 +674,6 @@ class Wint2xMmaMultistage :
690674
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]);
691675
++this->warp_tile_iterator_A_;
692676

693-
// Copy dequantized data to shared memory used by mma core.
694-
//copy_tiles_and_advance_per_stage_B<false, true>(iterator_B);
695-
696677
// Load first warp-tile's B fragment from shared memory
697678
this->warp_tile_iterator_B_.set_kgroup_index(0);
698679
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]);
@@ -709,6 +690,7 @@ class Wint2xMmaMultistage :
709690
pipe_state.tmp_accum_.clear();
710691
}
711692

693+
#if 0
712694
int stage = Base::kStages - 1;
713695

714696
// Mainloop
@@ -723,6 +705,7 @@ class Wint2xMmaMultistage :
723705
gemm_k_iterations,
724706
stage);
725707
stage += 1;
708+
break;
726709
}
727710

728711
if (Detail::kStagedAccumulation) {
@@ -766,8 +749,7 @@ class Wint2xMmaMultistage :
766749
else
767750
{
768751
this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)});
769-
//this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0});
770-
this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0});
752+
this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0});
771753
}
772754
smem_read_stage_idx_ = smem_write_stage_idx_;
773755
}

custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,9 @@
4141
#include "cutlass_extensions/arch/mma.h"
4242
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
4343

44-
namespace cutlass
45-
{
46-
namespace gemm
47-
{
48-
namespace warp
49-
{
44+
namespace cutlass {
45+
namespace gemm {
46+
namespace warp {
5047

5148
/////////////////////////////////////////////////////////////////////////////////////////////////
5249

@@ -81,7 +78,7 @@ struct DefaultMmaTensorOp<WarpShape_, InstructionShape_, ElementA, LayoutA, Elem
8178
// Shape for computing the FP16s
8279
using ComputeInstructionShape = InstructionShape_;
8380

84-
// Chosen so we get K=16 for int8 and K=32 for int4.
81+
// Chosen so we get K=16 for int8, K=32 for int4, K=64 for int2.
8582
static constexpr int LoadInstructionK = 128 / sizeof_bits<ElementB>::value;
8683

8784
// Shape for loading the narrow data type from shared memory

custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,11 @@ class MmaTensorOpComputeBWithF16
295295
assert(0);
296296
#endif
297297
}
298+
299+
/// Transform the mma operands to the required types
300+
CUTLASS_DEVICE
301+
void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B,
302+
FragmentA const &A, FragmentB const &B) const {}
298303
};
299304

300305
/////////////////////////////////////////////////////////////////////////////////////////////////

custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -715,8 +715,8 @@ void MoeGemmRunner<T, WeightQuantTraits>::run_gemm<EpilogueTag>(
715715
std::vector<CutlassGemmConfig> candidate_configs =
716716
get_candidate_configs(sm_, -1, is_weight_only, only_simt_configs, true);
717717

718-
static constexpr int warm_time = 5;
719-
static constexpr int test_time = 10;
718+
static constexpr int warm_time = 0;
719+
static constexpr int test_time = 1;
720720
auto& gemmConfigManager = GemmConfigManager::Instance();
721721
constexpr GemmDataType dtype = getGemmDataType<T>();
722722
constexpr GemmDataType wdtype = getGemmDataType<WeightType>();
@@ -735,8 +735,10 @@ void MoeGemmRunner<T, WeightQuantTraits>::run_gemm<EpilogueTag>(
735735
std::min(gemmConfigManager.nextPowerOfTwo(actual_total_rows),
736736
gemmConfigManager.getMaxProfileM());
737737
bool find_one = false;
738-
size_t num_candidate_configs_size = candidate_configs.size();
739-
for (size_t ii = 0; ii < num_candidate_configs_size; ++ii) {
738+
size_t num_candidate_configs_size = 2;//candidate_configs.size();
739+
// for (size_t ii = 0; ii < num_candidate_configs_size; ++ii)
740+
{
741+
size_t ii = 1;
740742
try {
741743
for (int i = 0; i < warm_time; i++) {
742744
dispatch_to_arch<EpilogueTag>(A,
@@ -780,7 +782,7 @@ void MoeGemmRunner<T, WeightQuantTraits>::run_gemm<EpilogueTag>(
780782
check_cuda_error(cudaEventElapsedTime(&elapsed, start, stop));
781783
check_cuda_error(cudaEventDestroy(start));
782784
check_cuda_error(cudaEventDestroy(stop));
783-
//std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << " ms" << std::endl;
785+
std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << " ms" << std::endl;
784786
if (elapsed < best_time) {
785787
best_id = ii;
786788
best_time = elapsed;
@@ -801,6 +803,7 @@ void MoeGemmRunner<T, WeightQuantTraits>::run_gemm<EpilogueTag>(
801803
}
802804
}
803805

806+
#if 0
804807
dispatch_to_arch<EpilogueTag>(A,
805808
B,
806809
weight_scales,
@@ -814,6 +817,7 @@ void MoeGemmRunner<T, WeightQuantTraits>::run_gemm<EpilogueTag>(
814817
quant_args_B,
815818
chosen_config,
816819
stream);
820+
#endif
817821
}
818822

819823
template <typename T, typename WeightQuantTraits>

custom_ops/gpu_ops/moe/moe_ffn_wint2.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
#include "moe/fast_hardamard_kernel.h"
2121
#include "moe/fused_moe_helper.h"
2222

23-
#define _GROUP_GEMM_ONLY 0
23+
#define _GROUP_GEMM_ONLY 1
2424

2525
template <typename DataT, typename NvType, typename WeightSavedT, cutlass::WintQuantMethod QuantMethod>
2626
void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,

0 commit comments

Comments
 (0)