
Commit fbd86c8 (1 parent: d5af789)

Change wint2 to ColumnMajor.

Change-Id: I6b44d02946a685f8fe24d9f2c7be258b51e16da2

9 files changed: 129 additions, 57 deletions.


custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h (2 additions, 2 deletions)

@@ -134,8 +134,8 @@ template <typename TypeA, typename Arch>
 struct LayoutDetailsB<TypeA, uint2b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
 {
     static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
-    using Layout = layout::RowMajor;
-    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<TypeA>::value;
+    using Layout = layout::ColumnMajor;
+    static constexpr int ElementsPerAccess = 8;  // at least 4-bytes
     using Operator = cutlass::arch::OpMultiplyAdd;
 };
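For context, a compile-time sketch of the arithmetic behind the two constants. This is illustrative only; it assumes TypeA is cutlass::half_t, which (along with bfloat16_t, also 16-bit) is what the DefaultMma specializations below instantiate:

    #include "cutlass/numeric_types.h"

    // half_t is 16 bits; the 2-bit weight type uint2b_t is 2 bits.
    static_assert(cutlass::sizeof_bits<cutlass::half_t>::value == 16, "");
    static_assert(cutlass::sizeof_bits<cutlass::uint2b_t>::value == 2, "");

    // ThreadblockK = 128 * 8 / sizeof_bits<TypeA> = 1024 / 16 = 64 for half_t.
    static_assert(128 * 8 / cutlass::sizeof_bits<cutlass::half_t>::value == 64, "");

    // The old ElementsPerAccess = 128 / sizeof_bits<TypeA> also evaluates to 8
    // for half_t; the new code pins the B-operand access width to 8 elements
    // directly instead of deriving it from TypeA's bit width.
    static_assert(128 / cutlass::sizeof_bits<cutlass::half_t>::value == 8, "");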

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h (6 additions, 6 deletions)

@@ -383,7 +383,7 @@ struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAli
                                                             : cutlass::arch::CacheOperation::Always;

     static cutlass::arch::CacheOperation::Kind const CacheOpB =
-        ((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
+        ((sizeof_bits<uint2b_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
                                                             : cutlass::arch::CacheOperation::Always;

     // Define the MmaCore components
@@ -401,9 +401,9 @@ struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAli

     // Define iterators over tiles from the B operand
     using ThreadMapB = typename MmaCore::IteratorThreadMapB;
-    using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
+    using AccessTypeB = cutlass::Array<uint2b_t, kAlignmentB>;
     using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
-        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
+        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, uint2b_t, LayoutB, 0, ThreadMapB,
         AccessTypeB>;

     // Define the threadblock-scoped multistage matrix multiply
@@ -446,7 +446,7 @@ struct DefaultMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
                                                             : cutlass::arch::CacheOperation::Always;

     static cutlass::arch::CacheOperation::Kind const CacheOpB =
-        ((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
+        ((sizeof_bits<uint2b_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
                                                             : cutlass::arch::CacheOperation::Always;

     // Define the MmaCore components
@@ -464,9 +464,9 @@ struct DefaultMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,

     // Define iterators over tiles from the B operand
     using ThreadMapB = typename MmaCore::IteratorThreadMapB;
-    using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
+    using AccessTypeB = cutlass::Array<uint2b_t, kAlignmentB>;
     using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
-        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
+        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, uint2b_t, LayoutB, 0, ThreadMapB,
         AccessTypeB>;

     // Define the threadblock-scoped multistage matrix multiply
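The CacheOpB change re-expresses the same 128-bit access test in terms of the actual 2-bit B element type, and retyping AccessTypeB and the PredicatedTileAccessIterator with uint2b_t makes the iterator address raw 2-bit weights instead of half-precision data. A hedged sketch of the alignment arithmetic (the kAlignmentB values are illustrative, not taken from the commit):

    #include "cutlass/numeric_types.h"

    // With B typed as half_t (16 bits), the 128-bit test passed at kAlignmentB == 8:
    static_assert(cutlass::sizeof_bits<cutlass::half_t>::value * 8 == 128, "");
    // With B typed as uint2b_t (2 bits), it passes at kAlignmentB == 64:
    static_assert(cutlass::sizeof_bits<cutlass::uint2b_t>::value * 64 == 128, "");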

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h (6 additions, 6 deletions)

@@ -384,7 +384,7 @@ struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB,
                                                             : cutlass::arch::CacheOperation::Always;

     static cutlass::arch::CacheOperation::Kind const CacheOpB =
-        ((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
+        ((sizeof_bits<uint2b_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
                                                             : cutlass::arch::CacheOperation::Always;

     // Define the MmaCore components
@@ -402,9 +402,9 @@ struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB,

     // Define iterators over tiles from the B operand
     using ThreadMapB = typename MmaCore::IteratorThreadMapB;
-    using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
+    using AccessTypeB = cutlass::Array<uint2b_t, kAlignmentB>;
     using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
-        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
+        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, uint2b_t, LayoutB, 0, ThreadMapB,
         AccessTypeB>;

     // Define the threadblock-scoped multistage matrix multiply
@@ -447,7 +447,7 @@ struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmen
                                                             : cutlass::arch::CacheOperation::Always;

     static cutlass::arch::CacheOperation::Kind const CacheOpB =
-        ((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
+        ((sizeof_bits<uint2b_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
                                                             : cutlass::arch::CacheOperation::Always;

     // Define the MmaCore components
@@ -465,9 +465,9 @@ struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmen

     // Define iterators over tiles from the B operand
     using ThreadMapB = typename MmaCore::IteratorThreadMapB;
-    using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
+    using AccessTypeB = cutlass::Array<uint2b_t, kAlignmentB>;
     using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
-        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
+        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, uint2b_t, LayoutB, 0, ThreadMapB,
         AccessTypeB>;

     // Define the threadblock-scoped multistage matrix multiply
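The bfloat16_t specializations receive the identical substitution; since sizeof_bits<bfloat16_t>::value is also 16, the alignment arithmetic sketched above for half_t carries over unchanged.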

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h (2 additions, 5 deletions)

@@ -104,8 +104,6 @@ class Wint2xMmaBase {
     using TensorRefB =
         TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;

-    // using TensorRefZippedB = TensorRef<uint8_t, typename Operator::LayoutB>;
-
     static_assert(kWarpGemmIterations > 1,
                   "The pipelined structure requires at least two warp-level "
                   "GEMM operations.");
@@ -130,12 +128,11 @@ class Wint2xMmaBase {
                                 Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;

     /// Shape of the B matrix operand in shared memory
-    using ShapeB = MatrixShape<Shape::kK + Policy::SmemPaddingB::kRow,
+    using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
                                Shape::kN + Policy::SmemPaddingB::kColumn>;

     // w uint8; local_scale uint8;
-    constexpr static int kZippedRowsPerStages =
-        Shape::kK / 4 + (Shape::kK + 127) / 128;
+    constexpr static int kZippedRowsPerStages = Shape::kK / 4 + (Shape::kK + 127) / 128;

     // code_scale float; code_zp float; super_scale ElementB
     constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) +
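The ShapeB change sizes the B operand's shared-memory region for all pipeline stages, mirroring how ShapeA is already multiplied by kStages. A hedged worked example (tile sizes are illustrative, not values from this commit):

    // Illustrative only: Shape::kK = 64, kStages = 3, padding ignored.
    constexpr int kK = 64;
    constexpr int kStages = 3;
    constexpr int old_rows_B = kK;            // 64: room for a single stage of B
    constexpr int new_rows_B = kK * kStages;  // 192: one B tile per pipeline stage
    // Zipped storage per stage: kK / 4 rows of packed weights (four 2-bit values
    // per uint8 byte) plus one local_scale row per 128 elements of K.
    constexpr int kZippedRowsPerStages = kK / 4 + (kK + 127) / 128;  // 16 + 1 = 17
    static_assert(new_rows_B == 192 && kZippedRowsPerStages == 17, "sketch arithmetic");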

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h (18 additions, 13 deletions)

@@ -90,7 +90,7 @@ template <
     SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
     /// Used for partial specialization
     typename Enable = bool>
-class Wint2xMmaMultistage : 
+class Wint2xMmaMultistage :
     public Wint2xMmaBase<Shape_, Policy_, Stages> {
 public:
     ///< Base class
@@ -282,20 +282,20 @@ class Wint2xMmaMultistage :
     {
         // Advance global iterators
         iterator_A.add_tile_offset({0, 1});
-        //iterator_B.add_tile_offset({1, 0});
-        tile_dequanter_B.AddTileOffset({1, 0});
+        iterator_B.add_tile_offset({1, 0});
+        //tile_dequanter_B.AddTileOffset({1, 0});

         // Advance shared iterators
         smem_iterator_A_.add_tile_offset({0, 1});
-        //smem_iterator_B_.add_tile_offset({1, 0});
+        smem_iterator_B_.add_tile_offset({1, 0});

         // Increment shared memory write stage index
         ++smem_write_stage_idx_;

         if (smem_write_stage_idx_ == Base::kStages) {
             // Wrap back around to the 'start' of the circular buffer in shared memory
             smem_iterator_A_.add_tile_offset({0, -Base::kStages});
-            //smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
+            smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
             smem_write_stage_idx_ = 0;
         }
     }
@@ -476,8 +476,11 @@ class Wint2xMmaMultistage :
         copy_tiles_and_advance_per_stage_A(iterator_A);

         // Async copy zipped B to shared memory.
-        tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
-                              column_wise_smem_ptr_B_, stage);
+        copy_tiles_and_advance_per_stage_B<false, true>(iterator_B);
+
+        // TODO: Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale.
+        //tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
+        //                      column_wise_smem_ptr_B_, stage);

         // Move to the next write stage
         advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
@@ -566,11 +569,11 @@ class Wint2xMmaMultistage :
         if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
             // Unpack and dequant the first stage of B.
             int unpack_stage = stage - Base::kStages + 2;
-            tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
-                                              column_wise_smem_ptr_B_, unpack_stage);
+            //tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
+            //                                  column_wise_smem_ptr_B_, unpack_stage);

             // Copy dequatized data to shared memory used by mma core.
-            copy_tiles_and_advance_per_stage_B<false, false>(iterator_B);
+            //copy_tiles_and_advance_per_stage_B<false, false>(iterator_B);
         }

         // Load the next warp-tile's B fragment from shared memory
@@ -672,10 +675,11 @@ class Wint2xMmaMultistage :
         IteratorB &iterator_B,
         TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand in global memory
     {
+#if 0
         PipeState pipe_state;

         // Unpack and dequant the first stage of B.
-        tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0);
+        //tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0);

         // Disable global fetching if done with global fetch iterations
         iterator_A.clear_mask(gemm_k_iterations == 0);
@@ -687,7 +691,7 @@ class Wint2xMmaMultistage :
         ++this->warp_tile_iterator_A_;

         // Copy dequatized data to shared memory used by mma core.
-        copy_tiles_and_advance_per_stage_B<false, true>(iterator_B);
+        //copy_tiles_and_advance_per_stage_B<false, true>(iterator_B);

         // Load first warp-tile's B fragment from shared memory
         this->warp_tile_iterator_B_.set_kgroup_index(0);
@@ -730,6 +734,7 @@ class Wint2xMmaMultistage :
         cutlass::arch::cp_async_fence();
         cutlass::arch::cp_async_wait<0>();
         __syncthreads();
+#endif
     }

     /// Prepares the class for another prologue.
@@ -794,7 +799,7 @@ class Wint2xMmaMultistage :
         accum = src_accum;

         // Perform the MAC-iterations
-        gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B);
+        //gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B);
     }
 };
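With the dequanter path commented out, B now follows the same circular-buffer discipline as A in advance_smem_write_stage. A minimal sketch of that bookkeeping, assuming an example stage count (note gemm_iters itself is compiled out with #if 0 in this commit, so only the stage-advance logic above is live):

    // Illustrative only; kStages is an example value, and the commented calls
    // mark where the real iterator updates happen in advance_smem_write_stage().
    void b_stage_walk_sketch() {
        constexpr int kStages = 3;
        int smem_write_stage_idx = 0;
        for (int stage = 0; stage < 2 * kStages; ++stage) {
            // smem_iterator_B_.add_tile_offset({1, 0});            // next K-tile (rows)
            if (++smem_write_stage_idx == kStages) {
                // smem_iterator_B_.add_tile_offset({-kStages, 0}); // wrap to buffer start
                smem_write_stage_idx = 0;
            }
        }
    }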

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h (3 additions, 0 deletions)

@@ -55,6 +55,9 @@ struct TileDequanter {
     bool need_preload{true};
     UnzipAndDequantFunctor unzip_functor;

+    CUTLASS_DEVICE
+    TileDequanter() {}
+
     CUTLASS_DEVICE
     TileDequanter(MmaElementT *out_smem_ptr, char *pointer, int64_t ldm,
                   const cutlass::MatrixCoord &extent,
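The added default constructor is what lets fused_moe_cutlass_kernel.h (below) declare a bare TileDequanterB tile_dequanter_B; now that the parameterized construction there is commented out.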

custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h (46 additions, 7 deletions)

@@ -52,6 +52,15 @@ namespace cutlass {
 namespace gemm {
 namespace kernel {

+template <typename Layout> std::string GetCutlassLayoutString() {
+    if (std::is_same<Layout, cutlass::layout::RowMajor>::value) {
+        return "cutlass::layout::RowMajor";
+    } else if (std::is_same<Layout, cutlass::layout::ColumnMajor>::value) {
+        return "cutlass::layout::ColumnMajor";
+    }
+    return "unknown";
+}
+
 /////////////////////////////////////////////////////////////////////////////////////////////////
 // This section exists to that we can use the same kernel code for regular gemm
 // and dequantizing gemms. It will dispatch to the dequantizing gemm if the Mma
@@ -282,6 +291,27 @@ struct MoeFCGemm {
                    platform::is_same<uint4b_t, ElementB>::value) {
             assert(weight_scales);
         }
+
+        CUTLASS_TRACE_HOST("[Arguments] problem_count: " << problem_count << ", threadblock_count: " << threadblock_count << ", gemm_n: " << gemm_n << ", gemm_k: " << gemm_k);
+        CUTLASS_TRACE_HOST("[Arguments] ptr_A: " << static_cast<void const*>(ptr_A));
+        CUTLASS_TRACE_HOST("[Arguments] ptr_B: " << static_cast<void const*>(ptr_B));
+        CUTLASS_TRACE_HOST("[Arguments] ptr_C: " << static_cast<void const*>(ptr_C));
+        CUTLASS_TRACE_HOST("[Arguments] ptr_D: " << static_cast<void*>(ptr_D));
+        CUTLASS_TRACE_HOST("[Arguments] weight_scales: " << static_cast<void const*>(weight_scales));
+        CUTLASS_TRACE_HOST("[Arguments] total_rows_before_expert: " << static_cast<void*>(total_rows_before_expert));
+        CUTLASS_TRACE_HOST("[Arguments] local_scale: " << static_cast<void const*>(local_scale));
+        CUTLASS_TRACE_HOST("[Arguments] code_scale: " << static_cast<void const*>(code_scale));
+        CUTLASS_TRACE_HOST("[Arguments] code_zp: " << static_cast<void const*>(code_zp));
+        CUTLASS_TRACE_HOST("[Arguments] quant_method: " << static_cast<int>(quant_method));
+        CUTLASS_TRACE_HOST("[Arguments] LayoutA: " << GetCutlassLayoutString<LayoutA>());
+        CUTLASS_TRACE_HOST("[Arguments] LayoutB: " << GetCutlassLayoutString<LayoutB>());
+        CUTLASS_TRACE_HOST("[Arguments] LayoutC: " << GetCutlassLayoutString<LayoutC>());
+        CUTLASS_TRACE_HOST("[Arguments] Mma::IteratorA::AccessType::kElements:" << Mma::IteratorA::AccessType::kElements);
+        CUTLASS_TRACE_HOST("[Arguments] Mma::IteratorB::AccessType::kElements:" << Mma::IteratorB::AccessType::kElements);
+        CUTLASS_TRACE_HOST("[Arguments] SharedStorage Information:");
+        CUTLASS_TRACE_HOST("  - ProblemVisitor::SharedStorage: " << sizeof(typename ProblemVisitor::SharedStorage) << " bytes");
+        CUTLASS_TRACE_HOST("  - Mma::SharedStorage: " << sizeof(typename Mma::SharedStorage) << " bytes");
+        CUTLASS_TRACE_HOST("  - Epilogue::SharedStorage: " << sizeof(typename Epilogue::SharedStorage) << " bytes");
     }
 };

@@ -835,6 +865,13 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
       int32_t problem_idx = problem_visitor.problem_index();
       int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());

+      CUTLASS_TRACE_DEVICE(" problem_idx: %d, cta_idx: %d, problem_size: {%d, %d, %d}",
+                           problem_idx, cta_idx, static_cast<int>(problem_size.m()), static_cast<int>(problem_size.n()), static_cast<int>(problem_size.k()));
+
+      if (problem_idx > 2) {
+        break;
+      }
+
       GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);

       // threadblock_offset of C
@@ -879,16 +916,16 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
          platform::is_same<layout::RowMajor, LayoutB>::value
              ? gemm_n
              : gemm_k * kInterleave;
-      typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns;
+      //typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns;

       // the begin threadblock_offset of B, which holds the same column id with C
       cutlass::MatrixCoord tb_offset_B{0,
                                        threadblock_offset.n() / kInterleave};

       cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave};
-      cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns};
+      //cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns};

-      MmaElementB* smem_unzip_B_ptr = nullptr;
+      /*MmaElementB* smem_unzip_B_ptr = nullptr;
       if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
         smem_unzip_B_ptr = shared_storage.main_loop.operand_unzip_B_ptr();
       }
@@ -901,7 +938,9 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
          weight_scale_ptr,
          tb_offset_scale,
          quant_args);
-      MmaElementB* ptr_B = tile_dequanter_B.GetOutPtr();
+      MmaElementB* ptr_B = tile_dequanter_B.GetOutPtr();*/
+      TileDequanterB tile_dequanter_B;
+      ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);

       // Compute position within threadblock
       int thread_idx = threadIdx.x;
@@ -914,11 +953,11 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
          tb_offset_A);

       typename Mma::IteratorB iterator_B(
-          LayoutB(TileDequanterB::kUseSharedMemory ? ldm_B_shared : ldm_B),
+          LayoutB(ldm_B),
          ptr_B,
-          TileDequanterB::kUseSharedMemory ? extent_B_shared : extent_B,
+          extent_B,
          thread_idx,
-          TileDequanterB::kUseSharedMemory ? cutlass::make_Coord(0, 0) : tb_offset_B);
+          tb_offset_B);

       typename Mma::FragmentC accumulators;

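For reference, a minimal host-side sketch of the new GetCutlassLayoutString helper. Illustrative only: it assumes this header and its CUTLASS dependencies are on the include path, and CUTLASS_TRACE_HOST itself only emits when CUTLASS's host debug tracing is compiled in:

    #include <iostream>
    // Assumes fused_moe_cutlass_kernel.h (and CUTLASS) is included, so the
    // helper, which lives in cutlass::gemm::kernel, is visible here.

    int main() {
        using cutlass::gemm::kernel::GetCutlassLayoutString;
        // The helper resolves the layout tag with std::is_same, exactly as above.
        std::cout << GetCutlassLayoutString<cutlass::layout::ColumnMajor>() << "\n";
        // prints: cutlass::layout::ColumnMajor
        std::cout << GetCutlassLayoutString<cutlass::layout::RowMajor>() << "\n";
        // prints: cutlass::layout::RowMajor
        return 0;
    }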