From fbd86c85581bbfc11ec4e843db3ef3b8362ef76c Mon Sep 17 00:00:00 2001 From: liuyiqun01 Date: Tue, 1 Jul 2025 14:57:32 +0800 Subject: [PATCH 01/11] Change wint2 to ColumnMajor. Change-Id: I6b44d02946a685f8fe24d9f2c7be258b51e16da2 --- .../gemm/kernel/mixed_gemm_B_layout.h | 4 +- .../gemm/threadblock/default_mma.h | 12 ++--- .../gemm/threadblock/default_mma_bf16.h | 12 ++--- .../gemm/threadblock/wint2x_mma_base.h | 7 +-- .../gemm/threadblock/wint2x_mma_multistage.h | 31 ++++++----- .../gemm/threadblock/wint2x_tile_dequanter.h | 3 ++ .../moe_gemm/fused_moe_cutlass_kernel.h | 53 ++++++++++++++++--- .../fused_moe_gemm_kernels_template.h | 36 +++++++------ custom_ops/gpu_ops/moe/moe_ffn_wint2.cu | 28 +++++++++- 9 files changed, 129 insertions(+), 57 deletions(-) diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h index 40f128b7a0..167bf18cfc 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -134,8 +134,8 @@ template struct LayoutDetailsB= 75>::type> { static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::RowMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 8; // at least 4-bytes using Operator = cutlass::arch::OpMultiplyAdd; }; diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h index bc395d04db..19a5e8fdaf 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h @@ -383,7 +383,7 @@ struct DefaultMma::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global + ((sizeof_bits::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always; // Define the MmaCore components @@ -401,9 +401,9 @@ struct DefaultMma; + using AccessTypeB = cutlass::Array; using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, + cutlass::MatrixShape, uint2b_t, LayoutB, 0, ThreadMapB, AccessTypeB>; // Define the threadblock-scoped multistage matrix multiply @@ -446,7 +446,7 @@ struct DefaultMma::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global + ((sizeof_bits::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always; // Define the MmaCore components @@ -464,9 +464,9 @@ struct DefaultMma; + using AccessTypeB = cutlass::Array; using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, + cutlass::MatrixShape, uint2b_t, LayoutB, 0, ThreadMapB, AccessTypeB>; // Define the threadblock-scoped multistage matrix multiply diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h index 5d2c311704..c853532059 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -384,7 +384,7 @@ struct DefaultMma::value * kAlignmentB) == 128) ? 
cutlass::arch::CacheOperation::Global + ((sizeof_bits::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always; // Define the MmaCore components @@ -402,9 +402,9 @@ struct DefaultMma; + using AccessTypeB = cutlass::Array; using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, + cutlass::MatrixShape, uint2b_t, LayoutB, 0, ThreadMapB, AccessTypeB>; // Define the threadblock-scoped multistage matrix multiply @@ -447,7 +447,7 @@ struct DefaultMma::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global + ((sizeof_bits::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always; // Define the MmaCore components @@ -465,9 +465,9 @@ struct DefaultMma; + using AccessTypeB = cutlass::Array; using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, + cutlass::MatrixShape, uint2b_t, LayoutB, 0, ThreadMapB, AccessTypeB>; // Define the threadblock-scoped multistage matrix multiply diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h index 6dd55b647a..7dec56be29 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h @@ -104,8 +104,6 @@ class Wint2xMmaBase { using TensorRefB = TensorRef; - // using TensorRefZippedB = TensorRef; - static_assert(kWarpGemmIterations > 1, "The pipelined structure requires at least two warp-level " "GEMM operations."); @@ -130,12 +128,11 @@ class Wint2xMmaBase { Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; /// Shape of the B matrix operand in shared memory - using ShapeB = MatrixShape; // w uint8; local_scale uint8; - constexpr static int kZippedRowsPerStages = - Shape::kK / 4 + (Shape::kK + 127) / 128; + constexpr static int kZippedRowsPerStages = Shape::kK / 4 + (Shape::kK + 127) / 128; // code_scale float; code_zp float; super_scale ElementB constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) + diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h index 38fdcf9fec..2a0f22048c 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h @@ -90,7 +90,7 @@ template < SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, /// Used for partial specialization typename Enable = bool> -class Wint2xMmaMultistage : +class Wint2xMmaMultistage : public Wint2xMmaBase { public: ///< Base class @@ -282,12 +282,12 @@ class Wint2xMmaMultistage : { // Advance global iterators iterator_A.add_tile_offset({0, 1}); - //iterator_B.add_tile_offset({1, 0}); - tile_dequanter_B.AddTileOffset({1, 0}); + iterator_B.add_tile_offset({1, 0}); + //tile_dequanter_B.AddTileOffset({1, 0}); // Advance shared iterators smem_iterator_A_.add_tile_offset({0, 1}); - //smem_iterator_B_.add_tile_offset({1, 0}); + smem_iterator_B_.add_tile_offset({1, 0}); // Increment shared memory write stage index ++smem_write_stage_idx_; @@ -295,7 +295,7 @@ class Wint2xMmaMultistage : if (smem_write_stage_idx_ == Base::kStages) { // Wrap back around to the 'start' of 
the circular buffer in shared memory smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - //smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); smem_write_stage_idx_ = 0; } } @@ -476,8 +476,11 @@ class Wint2xMmaMultistage : copy_tiles_and_advance_per_stage_A(iterator_A); // Async copy zipped B to shared memory. - tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, - column_wise_smem_ptr_B_, stage); + copy_tiles_and_advance_per_stage_B(iterator_B); + + // TODO: Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale. + //tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, + // column_wise_smem_ptr_B_, stage); // Move to the next write stage advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B); @@ -566,11 +569,11 @@ class Wint2xMmaMultistage : if (warp_mma_k + 1 == Base::kWarpGemmIterations) { // Unpack and dequant the first stage of B. int unpack_stage = stage - Base::kStages + 2; - tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, - column_wise_smem_ptr_B_, unpack_stage); + //tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, + // column_wise_smem_ptr_B_, unpack_stage); // Copy dequatized data to shared memory used by mma core. - copy_tiles_and_advance_per_stage_B(iterator_B); + //copy_tiles_and_advance_per_stage_B(iterator_B); } // Load the next warp-tile's B fragment from shared memory @@ -672,10 +675,11 @@ class Wint2xMmaMultistage : IteratorB &iterator_B, TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand in global memory { +#if 0 PipeState pipe_state; // Unpack and dequant the first stage of B. - tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0); + //tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0); // Disable global fetching if done with global fetch iterations iterator_A.clear_mask(gemm_k_iterations == 0); @@ -687,7 +691,7 @@ class Wint2xMmaMultistage : ++this->warp_tile_iterator_A_; // Copy dequatized data to shared memory used by mma core. - copy_tiles_and_advance_per_stage_B(iterator_B); + //copy_tiles_and_advance_per_stage_B(iterator_B); // Load first warp-tile's B fragment from shared memory this->warp_tile_iterator_B_.set_kgroup_index(0); @@ -730,6 +734,7 @@ class Wint2xMmaMultistage : cutlass::arch::cp_async_fence(); cutlass::arch::cp_async_wait<0>(); __syncthreads(); + #endif } /// Prepares the class for another prologue. 
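A minimal stand-alone sketch (not part of this patch) of the zipped-B shared-memory sizing declared in wint2x_mma_base.h above, assuming ThreadblockShape::kK == 64 as in the wint2 tile configs: the packed 2-bit weights take Shape::kK / 4 uint8 rows (four 2-bit values per byte) and the per-128-element local_scale adds (Shape::kK + 127) / 128 more, so each pipeline stage holds 16 + 1 = 17 zipped rows.

    // Illustration only; kThreadblockK = 64 is an assumed tile size.
    constexpr int kThreadblockK = 64;
    constexpr int kZippedRowsPerStages =
        kThreadblockK / 4 + (kThreadblockK + 127) / 128;  // 16 + 1
    static_assert(kZippedRowsPerStages == 17, "zipped B rows per pipeline stage");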
@@ -794,7 +799,7 @@ class Wint2xMmaMultistage : accum = src_accum; // Perform the MAC-iterations - gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B); + //gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B); } }; diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h index cec6bcea03..c44539fed1 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h @@ -55,6 +55,9 @@ struct TileDequanter { bool need_preload{true}; UnzipAndDequantFunctor unzip_functor; + CUTLASS_DEVICE + TileDequanter() {} + CUTLASS_DEVICE TileDequanter(MmaElementT *out_smem_ptr, char *pointer, int64_t ldm, const cutlass::MatrixCoord &extent, diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h index 356f305968..a328520322 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h @@ -52,6 +52,15 @@ namespace cutlass { namespace gemm { namespace kernel { +template std::string GetCutlassLayoutString() { + if (std::is_same::value) { + return "cutlass::layout::RowMajor"; + } else if (std::is_same::value) { + return "cutlass::layout::ColumnMajor"; + } + return "unknown"; +} + ///////////////////////////////////////////////////////////////////////////////////////////////// // This section exists to that we can use the same kernel code for regular gemm // and dequantizing gemms. It will dispatch to the dequantizing gemm if the Mma @@ -282,6 +291,27 @@ struct MoeFCGemm { platform::is_same::value) { assert(weight_scales); } + + CUTLASS_TRACE_HOST("[Arguments] problem_count: " << problem_count << ", threadblock_count: " << threadblock_count << ", gemm_n: " << gemm_n << ", gemm_k: " << gemm_k); + CUTLASS_TRACE_HOST("[Arguments] ptr_A: " << static_cast(ptr_A)); + CUTLASS_TRACE_HOST("[Arguments] ptr_B: " << static_cast(ptr_B)); + CUTLASS_TRACE_HOST("[Arguments] ptr_C: " << static_cast(ptr_C)); + CUTLASS_TRACE_HOST("[Arguments] ptr_D: " << static_cast(ptr_D)); + CUTLASS_TRACE_HOST("[Arguments] weight_scales: " << static_cast(weight_scales)); + CUTLASS_TRACE_HOST("[Arguments] total_rows_before_expert: " << static_cast(total_rows_before_expert)); + CUTLASS_TRACE_HOST("[Arguments] local_scale: " << static_cast(local_scale)); + CUTLASS_TRACE_HOST("[Arguments] code_scale: " << static_cast(code_scale)); + CUTLASS_TRACE_HOST("[Arguments] code_zp: " << static_cast(code_zp)); + CUTLASS_TRACE_HOST("[Arguments] quant_method: " << static_cast(quant_method)); + CUTLASS_TRACE_HOST("[Arguments] LayoutA: " << GetCutlassLayoutString()); + CUTLASS_TRACE_HOST("[Arguments] LayoutB: " << GetCutlassLayoutString()); + CUTLASS_TRACE_HOST("[Arguments] LayoutC: " << GetCutlassLayoutString()); + CUTLASS_TRACE_HOST("[Arguments] Mma::IteratorA::AccessType::kElements:" << Mma::IteratorA::AccessType::kElements); + CUTLASS_TRACE_HOST("[Arguments] Mma::IteratorB::AccessType::kElements:" << Mma::IteratorB::AccessType::kElements); + CUTLASS_TRACE_HOST("[Arguments] SharedStorage Information:"); + CUTLASS_TRACE_HOST(" - ProblemVisitor::SharedStorage: " << sizeof(typename ProblemVisitor::SharedStorage) << " bytes"); + CUTLASS_TRACE_HOST(" - Mma::SharedStorage: " << sizeof(typename 
Mma::SharedStorage) << " bytes"); + CUTLASS_TRACE_HOST(" - Epilogue::SharedStorage: " << sizeof(typename Epilogue::SharedStorage) << " bytes"); } }; @@ -835,6 +865,13 @@ struct Wint2xMoeFCGemm : public MoeFCGemm(problem_size.m()), static_cast(problem_size.n()), static_cast(problem_size.k())); + + if (problem_idx > 2) { + break; + } + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); // threadblock_offset of C @@ -879,16 +916,16 @@ struct Wint2xMoeFCGemm : public MoeFCGemm::value ? gemm_n : gemm_k * kInterleave; - typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns; + //typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns; // the begin threadblock_offset of B, which holds the same column id with C cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave}; - cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns}; + //cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns}; - MmaElementB* smem_unzip_B_ptr = nullptr; + /*MmaElementB* smem_unzip_B_ptr = nullptr; if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) { smem_unzip_B_ptr = shared_storage.main_loop.operand_unzip_B_ptr(); } @@ -901,7 +938,9 @@ struct Wint2xMoeFCGemm : public MoeFCGemm(byte_ptr_B); // Compute position within threadblock int thread_idx = threadIdx.x; @@ -914,11 +953,11 @@ struct Wint2xMoeFCGemm : public MoeFCGemm::Op; + CUTLASS_TRACE_HOST("Stages: " << Stages); + // Finally, set up the kernel. using BaseGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< ElementType, @@ -187,6 +189,7 @@ void generic_moe_gemm_kernelLauncher(const T* A, "GroupedGEMM kernel"); } const int threadblock_count = multi_processor_count * occupancy; + CUTLASS_TRACE_HOST("kernel_occupancy: " << kernel_occupancy << ", occupancy: " << occupancy << ", threadblock_count: " << threadblock_count << ", multi_processor_count: " << multi_processor_count); typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f), ElementAccumulator(0.f)); @@ -205,7 +208,7 @@ void generic_moe_gemm_kernelLauncher(const T* A, threadblock_count, epilogue_op, reinterpret_cast(A), - reinterpret_cast(B), + reinterpret_cast(B), reinterpret_cast(weight_scales), reinterpret_cast(biases), reinterpret_cast(C), @@ -443,6 +446,7 @@ void dispatch_gemm_config(const T* A, #define dispatch_gemm_config_macro(AA, BB, CC, DD, EE, FF) \ case CutlassTileConfig:: \ CtaShape##AA##x##BB##x##CC##_WarpShape##DD##x##EE##x##FF: \ + CUTLASS_TRACE_HOST("ThreadblockShape<" << AA << "," << BB << "," << CC << ">, WarpShape<" << DD << "," << EE << "," << FF << ">"); \ dispatch_gemm_config::value) { if constexpr (WeightQuantTraits::kQuantMethod != cutlass::WintQuantMethod::kWeightOnlyInt2) { switch (gemm_config.tile_config) { - dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64); - dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64); + //dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64); + //dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64); case CutlassTileConfig::Undefined: throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined."); break; @@ -563,16 +567,16 @@ void dispatch_moe_gemm_to_cutlass(const T* A, } else { switch (gemm_config.tile_config) { dispatch_gemm_config_macro(16, 128, 64, 16, 32, 64); - dispatch_gemm_config_macro(16, 256, 64, 16, 64, 64); - dispatch_gemm_config_macro(64, 64, 64, 32, 32, 64); - 
dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64); - dispatch_gemm_config_macro(128, 64, 64, 64, 32, 64); - dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64); - dispatch_gemm_config_macro(128, 128, 64, 64, 64, 64); - dispatch_gemm_config_macro(128, 128, 64, 128, 32, 64); - dispatch_gemm_config_macro(128, 256, 64, 64, 64, 64); - dispatch_gemm_config_macro(64, 128, 64, 64, 32, 64); - dispatch_gemm_config_macro(256, 128, 64, 64, 64, 64); + //dispatch_gemm_config_macro(16, 256, 64, 16, 64, 64); + //dispatch_gemm_config_macro(64, 64, 64, 32, 32, 64); + //dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64); + //dispatch_gemm_config_macro(128, 64, 64, 64, 32, 64); + //dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64); + //dispatch_gemm_config_macro(128, 128, 64, 64, 64, 64); + //dispatch_gemm_config_macro(128, 128, 64, 128, 32, 64); + //dispatch_gemm_config_macro(128, 256, 64, 64, 64, 64); + //dispatch_gemm_config_macro(64, 128, 64, 64, 32, 64); + //dispatch_gemm_config_macro(256, 128, 64, 64, 64, 64); case CutlassTileConfig::Undefined: throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined."); break; @@ -614,7 +618,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A, cudaStream_t stream, int* occupancy = nullptr) { switch (gemm_config.tile_config) { - dispatch_gemm_config_macro(128, 128, 8, 64, 64, 8); + //dispatch_gemm_config_macro(128, 128, 8, 64, 64, 8); case CutlassTileConfig::Undefined: throw std::runtime_error( "[dispatch_moe_gemm_to_cutlass][SIMT] gemm config " diff --git a/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu b/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu index fb9d2e69fe..55afa6d9fc 100644 --- a/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu +++ b/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu @@ -20,6 +20,8 @@ #include "moe/fast_hardamard_kernel.h" #include "moe/fused_moe_helper.h" +#define _GROUP_GEMM_ONLY 0 + template void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, const paddle::Tensor& tokens_expert_prefix_sum, @@ -65,7 +67,11 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, reinterpret_cast(ffn1_weight.data()), reinterpret_cast(ffn1_super_scale ? ffn1_super_scale->data() : nullptr), reinterpret_cast(ffn1_bias ? 
ffn1_bias->data() : nullptr), +#if _GROUP_GEMM_ONLY + reinterpret_cast(ffn_out.data()), +#else reinterpret_cast(fc1_out.data()), +#endif const_cast(tokens_expert_prefix_sum.data()), total_rows_in_ll_else_minus1, actual_total_rows, @@ -76,9 +82,12 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, "none", stream); +#if _GROUP_GEMM_ONLY + // do nothing +#else paddle::Tensor act_out; if (used_in_ep_low_latency) { - act_out = GroupSwigluWithMasked(fc1_out, tokens_expert_prefix_sum); + //act_out = GroupSwigluWithMasked(fc1_out, tokens_expert_prefix_sum); } else { act_out = paddle::experimental::swiglu(fc1_out, nullptr); } @@ -96,6 +105,7 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, num_experts, ffn2_quant_args, stream); +#endif } template @@ -198,7 +208,14 @@ paddle::Tensor MoeExpertFFNWint2Func( const bool used_in_ep_low_latency) { const auto dtype = permute_input.dtype(); +#if _GROUP_GEMM_ONLY + auto place = permute_input.place(); + int64_t expanded_active_expert_rows = permute_input.dims()[0]; + int64_t inter_size = ffn1_scale.get().dims()[1]; + auto ffn_out = GetEmptyTensor({expanded_active_expert_rows, inter_size}, dtype, place); +#else auto ffn_out = paddle::empty_like(permute_input, dtype); +#endif switch (dtype) { case paddle::DataType::BFLOAT16: @@ -289,7 +306,14 @@ std::vector> MoeExpertFFNWint2InferShape( const paddle::optional>& ffn2_code_zp_shape, const bool used_in_ep_low_latency) { +#if _GROUP_GEMM_ONLY + int64_t expanded_active_expert_rows = permute_input_shape[0]; + int64_t inter_size = ffn1_scale_shape.get()[1]; + std::cout << "expanded_active_expert_rows: " << expanded_active_expert_rows << ", inter_size: " << inter_size << std::endl; + return {std::vector{expanded_active_expert_rows, inter_size}}; +#else return {permute_input_shape}; +#endif } std::vector MoeExpertFFNWint2InferDtype( @@ -356,7 +380,7 @@ std::vector MoeExpertFFNWint2InferDtype( * Note: * - Low latency mode uses specialized grouped SwiGLU implementation */ -PD_BUILD_STATIC_OP(moe_expert_ffn_wint2) +PD_BUILD_OP(moe_expert_ffn_wint2) .Inputs({"permute_input", "tokens_expert_prefix_sum", "ffn1_weight", From 2efcfbbc61fd011d53220111707d71ea5613f5c9 Mon Sep 17 00:00:00 2001 From: liuyiqun01 Date: Tue, 1 Jul 2025 19:46:15 +0800 Subject: [PATCH 02/11] Unify default_wint2x_mma. 
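Move the duplicated fp16/bf16 DefaultMma specializations for uint2b_t weights into a
single DefaultWint2xMma helper, so that default_mma.h and default_mma_bf16.h only
re-export MmaCore, IteratorA, IteratorB and ThreadblockMma instead of repeating the
cache-op, thread-map and iterator definitions.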
Change-Id: I9e77b0e8e6cecab01fedc0b24b536ee0a1a89ff7 --- .../gemm/threadblock/default_mma.h | 76 +++------ .../gemm/threadblock/default_mma_bf16.h | 68 +++----- .../gemm/threadblock/default_wint2x_mma.h | 145 ++++++++++++++++++ 3 files changed, 186 insertions(+), 103 deletions(-) create mode 100644 custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h index 19a5e8fdaf..b50d66380e 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h @@ -18,14 +18,12 @@ #include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" #include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" +#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h" #include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h" -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { //////////////////////////////////////////////////////////////////////////////// @@ -378,38 +376,23 @@ template < struct DefaultMma { - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; +private: + using Mma = DefaultWint2xMma; +public: // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; + using MmaCore = typename Mma::MmaCore; // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, - AccessTypeA>; + using IteratorA = typename Mma::IteratorA; // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, uint2b_t, LayoutB, 0, ThreadMapB, - AccessTypeB>; + using IteratorB = typename Mma::IteratorB; // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; + using ThreadblockMma = typename Mma::ThreadblockMma; }; template < @@ -441,38 +424,23 @@ struct DefaultMma { - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) ? 
cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; +private: + using Mma = DefaultWint2xMma; +public: // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; + using MmaCore = typename Mma::MmaCore; // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, - AccessTypeA>; + using IteratorA = typename Mma::IteratorA; // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, uint2b_t, LayoutB, 0, ThreadMapB, - AccessTypeB>; + using IteratorB = typename Mma::IteratorB; // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; + using ThreadblockMma = typename Mma::ThreadblockMma; }; } // namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h index c853532059..300261c3f0 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -19,7 +19,7 @@ #include "cutlass/gemm/threadblock/default_mma.h" #include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" #include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" -#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h" +#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h" namespace cutlass { namespace gemm { @@ -379,38 +379,23 @@ template < struct DefaultMma { - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) ? 
cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; +private: + using Mma = DefaultWint2xMma; +public: // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; + using MmaCore = typename Mma::MmaCore; // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, - AccessTypeA>; + using IteratorA = typename Mma::IteratorA; // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, uint2b_t, LayoutB, 0, ThreadMapB, - AccessTypeB>; + using IteratorB = typename Mma::IteratorB; // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; + using ThreadblockMma = typename Mma::ThreadblockMma; }; template < @@ -442,38 +427,23 @@ struct DefaultMma { - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; +private: + using Mma = DefaultWint2xMma; +public: // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; + using MmaCore = typename Mma::MmaCore; // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, - AccessTypeA>; + using IteratorA = typename Mma::IteratorA; // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, uint2b_t, LayoutB, 0, ThreadMapB, - AccessTypeB>; + using IteratorB = typename Mma::IteratorB; // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; + using ThreadblockMma = typename Mma::ThreadblockMma; }; } // namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h new file mode 100644 index 0000000000..0ae02016ff --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h @@ -0,0 +1,145 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" +#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone> +struct DefaultWint2xMma; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// Operator performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +struct DefaultWint2xMma +{ + + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value, + "Element B must be uint2b_t"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, + AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementB, LayoutB, 0, ThreadMapB, + AccessTypeB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass From feac9566a03411bb3925567de8d02ee55d788cc5 Mon Sep 17 00:00:00 2001 From: liuyiqun01 Date: Wed, 2 Jul 2025 17:58:23 +0800 Subject: [PATCH 03/11] Change wint2 to ColumnMajorTileInterleave. Change-Id: I593cbe36f991c0c5044989d65f0014087587c624 --- .../gemm/kernel/mixed_gemm_B_layout.h | 16 ++++++-- .../gemm/threadblock/default_wint2x_mma.h | 39 +++++++++++++++---- .../gemm/threadblock/wint2x_mma_multistage.h | 36 +++++------------ .../gemm/warp/default_mma_tensor_op.h | 11 ++---- .../warp/mma_tensorop_compute_B_with_f16.h | 5 +++ .../fused_moe_gemm_kernels_template.h | 14 ++++--- custom_ops/gpu_ops/moe/moe_ffn_wint2.cu | 2 +- 7 files changed, 72 insertions(+), 51 deletions(-) diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h index 167bf18cfc..8f61c6d9c4 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -133,10 +133,18 @@ template template struct LayoutDetailsB= 75>::type> { - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 8; // at least 4-bytes - using Operator = cutlass::arch::OpMultiplyAdd; + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; // 64 + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; // 8 + +public: + // using Layout = layout::ColumnMajor; + // static constexpr int ElementsPerAccess = 16; // at least 4-bytes + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; // 64 + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; template diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h index 0ae02016ff..a67c8aa256 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h +++ 
b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h @@ -106,6 +106,9 @@ struct DefaultWint2xMma::value, "Element B must be uint2b_t"); + static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always; @@ -117,8 +120,8 @@ struct DefaultWint2xMma; + ElementA, LayoutA, ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + std::max(kStages, 3), Operator, false, CacheOpA, CacheOpB>; // Define iterators over tiles from the A operand using ThreadMapA = typename MmaCore::IteratorThreadMapA; @@ -127,17 +130,39 @@ struct DefaultWint2xMma, ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; - // Define iterators over tiles from the B operand +private: + static constexpr int kColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int kRowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % kColumnsInterleaved), "ThreadblockShape must be disivle by kColumnsInterleaved"); + static_assert(kRowsPerTile == MmaCore::Shape::kK, ""); + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using WarpArrangement = typename ThreadMapB::Detail::WarpThreadArrangement; + static_assert(!(WarpArrangement::kStrided % kColumnsInterleaved), ""); + + using IteratorShapeB = MatrixShape< + MmaCore::Shape::kK * kColumnsInterleaved, MmaCore::Shape::kN / kColumnsInterleaved>; + using InterleavedThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + ThreadMapB::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + +public: + // Define iterators over tiles from the B operand using AccessTypeB = cutlass::Array; using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, ElementB, LayoutB, 0, ThreadMapB, + IteratorShapeB, ElementB, layout::ColumnMajor, 0, InterleavedThreadMapB, AccessTypeB>; // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; + using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage< + typename MmaCore::Shape, + IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, + IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB, + ElementAccumulator, layout::RowMajor, + typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>; }; } // namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h index 2a0f22048c..ed103b716d 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h @@ -266,10 +266,9 @@ class Wint2xMmaMultistage : if (smem_read_stage_idx_ == Base::kStages) { // Wrap back around to the 'start' of the circular buffer in shared memory this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - // this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); smem_read_stage_idx_ = 0; } - 
this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); } /// Advance global memory read-iterators and shared memory write-iterators to the stage @@ -566,16 +565,6 @@ class Wint2xMmaMultistage : this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); ++this->warp_tile_iterator_A_; - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - // Unpack and dequant the first stage of B. - int unpack_stage = stage - Base::kStages + 2; - //tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, - // column_wise_smem_ptr_B_, unpack_stage); - - // Copy dequatized data to shared memory used by mma core. - //copy_tiles_and_advance_per_stage_B(iterator_B); - } - // Load the next warp-tile's B fragment from shared memory this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); @@ -617,13 +606,10 @@ class Wint2xMmaMultistage : // global->shared fragment copies if (warp_mma_k < Base::kWarpGemmIterations - 1) { int group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + int group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); - - if (warp_mma_k == 0) { - tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, - column_wise_smem_ptr_B_, stage); - } + copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); } // The second-to-last warp-tile also: @@ -632,8 +618,10 @@ class Wint2xMmaMultistage : if (warp_mma_k + 2 == Base::kWarpGemmIterations) { // Performs the last warp-tile's share of global->shared fragment copies int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + int group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); + copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); // Inserts a memory fence between stages of cp.async instructions. cutlass::arch::cp_async_fence(); @@ -648,7 +636,7 @@ class Wint2xMmaMultistage : // Disable global fetching when done with global fetch iterations --gemm_k_iterations; iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1)); + iterator_B.clear_mask(gemm_k_iterations == 0); } // The last warp-tile also converts the shared memory fragments used by @@ -675,12 +663,8 @@ class Wint2xMmaMultistage : IteratorB &iterator_B, TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand in global memory { -#if 0 PipeState pipe_state; - // Unpack and dequant the first stage of B. - //tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0); - // Disable global fetching if done with global fetch iterations iterator_A.clear_mask(gemm_k_iterations == 0); iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1)); @@ -690,9 +674,6 @@ class Wint2xMmaMultistage : this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]); ++this->warp_tile_iterator_A_; - // Copy dequatized data to shared memory used by mma core. 
- //copy_tiles_and_advance_per_stage_B(iterator_B); - // Load first warp-tile's B fragment from shared memory this->warp_tile_iterator_B_.set_kgroup_index(0); this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]); @@ -709,6 +690,7 @@ class Wint2xMmaMultistage : pipe_state.tmp_accum_.clear(); } +#if 0 int stage = Base::kStages - 1; // Mainloop @@ -723,6 +705,7 @@ class Wint2xMmaMultistage : gemm_k_iterations, stage); stage += 1; + break; } if (Detail::kStagedAccumulation) { @@ -766,8 +749,7 @@ class Wint2xMmaMultistage : else { this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)}); - //this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0}); - this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0}); + this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0}); } smem_read_stage_idx_ = smem_write_stage_idx_; } diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h index 350b247de2..af4298df5e 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h @@ -41,12 +41,9 @@ #include "cutlass_extensions/arch/mma.h" #include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" -namespace cutlass -{ -namespace gemm -{ -namespace warp -{ +namespace cutlass { +namespace gemm { +namespace warp { ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -81,7 +78,7 @@ struct DefaultMmaTensorOp::value; // Shape for loading the narrow data type from shared memory diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h index 7c5088894b..ad1ca710e2 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -295,6 +295,11 @@ class MmaTensorOpComputeBWithF16 assert(0); #endif } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, + FragmentA const &A, FragmentB const &B) const {} }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h index cbc9a2911d..b0314fa1e6 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h @@ -715,8 +715,8 @@ void MoeGemmRunner::run_gemm( std::vector candidate_configs = get_candidate_configs(sm_, -1, is_weight_only, only_simt_configs, true); - static constexpr int warm_time = 5; - static constexpr int test_time = 10; + static constexpr int warm_time = 0; + static constexpr int test_time = 1; auto& gemmConfigManager = GemmConfigManager::Instance(); constexpr GemmDataType dtype = getGemmDataType(); constexpr GemmDataType wdtype = getGemmDataType(); @@ -735,8 +735,10 @@ void MoeGemmRunner::run_gemm( std::min(gemmConfigManager.nextPowerOfTwo(actual_total_rows), gemmConfigManager.getMaxProfileM()); bool find_one = false; - size_t 
num_candidate_configs_size = candidate_configs.size(); - for (size_t ii = 0; ii < num_candidate_configs_size; ++ii) { + size_t num_candidate_configs_size = 2;//candidate_configs.size(); + // for (size_t ii = 0; ii < num_candidate_configs_size; ++ii) + { + size_t ii = 1; try { for (int i = 0; i < warm_time; i++) { dispatch_to_arch(A, @@ -780,7 +782,7 @@ void MoeGemmRunner::run_gemm( check_cuda_error(cudaEventElapsedTime(&elapsed, start, stop)); check_cuda_error(cudaEventDestroy(start)); check_cuda_error(cudaEventDestroy(stop)); - //std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << " ms" << std::endl; + std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << " ms" << std::endl; if (elapsed < best_time) { best_id = ii; best_time = elapsed; @@ -801,6 +803,7 @@ void MoeGemmRunner::run_gemm( } } +#if 0 dispatch_to_arch(A, B, weight_scales, @@ -814,6 +817,7 @@ void MoeGemmRunner::run_gemm( quant_args_B, chosen_config, stream); +#endif } template diff --git a/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu b/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu index 55afa6d9fc..5c8bbd6797 100644 --- a/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu +++ b/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu @@ -20,7 +20,7 @@ #include "moe/fast_hardamard_kernel.h" #include "moe/fused_moe_helper.h" -#define _GROUP_GEMM_ONLY 0 +#define _GROUP_GEMM_ONLY 1 template void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, From da648e8a4fbdd992f5877e0a521f3a048f18c345 Mon Sep 17 00:00:00 2001 From: liuyiqun01 Date: Wed, 2 Jul 2025 19:17:22 +0800 Subject: [PATCH 04/11] Enable async copy for B. Change-Id: Ia3ac37ad162a8cf3ccce4f268e81bd06c8ac3c46 --- .../gemm/threadblock/wint2x_mma_multistage.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h index ed103b716d..0156c7e8a5 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h @@ -475,7 +475,7 @@ class Wint2xMmaMultistage : copy_tiles_and_advance_per_stage_A(iterator_A); // Async copy zipped B to shared memory. - copy_tiles_and_advance_per_stage_B(iterator_B); + copy_tiles_and_advance_per_stage_B(iterator_B); // TODO: Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale. //tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, @@ -609,7 +609,7 @@ class Wint2xMmaMultistage : int group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); - copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); + copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); } // The second-to-last warp-tile also: @@ -621,7 +621,7 @@ class Wint2xMmaMultistage : int group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); - copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); + copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); // Inserts a memory fence between stages of cp.async instructions. 
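// (Note on the fence below: cutlass::arch::cp_async_fence() commits the cp.async
//  copies issued above into a group, and a later cp_async_wait<N>() stalls until at
//  most N committed groups remain in flight, so the math warps never consume a
//  shared-memory stage that is still being filled.)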
cutlass::arch::cp_async_fence(); From 8c6fa14a541b88539bb14616e6f6e5b811211010 Mon Sep 17 00:00:00 2001 From: baoqiwen Date: Tue, 8 Jul 2025 11:00:25 +0800 Subject: [PATCH 05/11] Add wint2x Dequantizer --- .../gemm/threadblock/wint2x_mma_multistage.h | 1379 +++++++++-------- .../warp/mma_tensorop_wint2x_dequantizer.h | 696 +++++++++ 2 files changed, 1438 insertions(+), 637 deletions(-) create mode 100644 custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h index 0156c7e8a5..a8c8dae4c8 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! 
\file @@ -36,8 +37,8 @@ #include "cutlass/aligned_buffer.h" #include "cutlass/arch/memory.h" -#include "cutlass/array.h" #include "cutlass/arch/memory_sm80.h" +#include "cutlass/array.h" #include "cutlass/cutlass.h" #include "cutlass/gemm/gemm.h" #include "cutlass/matrix_shape.h" @@ -47,6 +48,8 @@ #include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h" #include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h" + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -90,605 +93,702 @@ template < SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, /// Used for partial specialization typename Enable = bool> -class Wint2xMmaMultistage : - public Wint2xMmaBase { -public: - ///< Base class - using Base = Wint2xMmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; - - using ZippedShapeB = typename Base::SharedStorage::ZippedShapeB; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - // - // Dependent types - // - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - /// Internal structure exposed for introspection. - struct Detail { - - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = - IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = - IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA = - (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB = - (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical - // accuracy, where each mainloop iteration first accumulates into a temporary - // set of freshly-cleared accumulators, which are subsequently added to the - // final accumulator set. 
- static bool const kStagedAccumulation = arch::detail::UseStagedAccumulation::value; - }; - - private: - - // Structure encapsulating pipeline state live from one iteration to the next - struct PipeState { - - using WarpLoadedFragmentA = typename Operator::FragmentA; - using WarpLoadedFragmentB = typename Operator::FragmentB; - using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; +class Wint2xMmaMultistage : public Wint2xMmaBase { + public: + ///< Base class + using Base = Wint2xMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using ZippedShapeB = typename Base::SharedStorage::ZippedShapeB; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using LayoutScale = cutlass::layout::ColumnMajor; using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + using ElementB = typename WarpTransformedFragmentB::Element; + using Dequantizer = + warp::MmaTensorOpWin2xDequantizer; + + static_assert(sizeof(Dequantizer) > 0, + "Dequantizer template instantiation failed"); + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + // Optional staged-accumulation (e.g., tf32x3 kernels) for improved + // numerical accuracy, where each mainloop iteration first accumulates + // into a temporary set of freshly-cleared accumulators, which are + // subsequently added to the final accumulator set. 
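The staged-accumulation path described above keeps a temporary accumulator that is cleared, filled by the warp-level MMAs, and then folded into the final accumulators once per mainloop iteration. The host-side sketch below is a minimal illustration of that pattern only; the array types and sizes are placeholders, not the kernel's fragment types.

#include <array>
#include <cstdio>

int main() {
    std::array<float, 4> accum{};      // final accumulator tile
    std::array<float, 4> tmp_accum{};  // temporary set, cleared each iteration

    constexpr int kMainloopIterations = 3;
    for (int iter = 0; iter < kMainloopIterations; ++iter) {
        tmp_accum.fill(0.0f);              // freshly-cleared accumulators
        for (int i = 0; i < 4; ++i) {
            tmp_accum[i] += 1.0f;          // stand-in for the warp-level MMA
        }
        for (int i = 0; i < 4; ++i) {
            accum[i] += tmp_accum[i];      // fold into the final accumulator set
        }
    }
    std::printf("accum[0] = %.1f\n", accum[0]);  // prints 3.0
    return 0;
}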
+ static bool const kStagedAccumulation = + arch::detail::UseStagedAccumulation::value; + }; + + private: + // Structure encapsulating pipeline state live from one iteration to the + // next + struct PipeState { + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = + typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = + typename Operator::TransformedFragmentB; + + /// Temporary accumulator to facilitate staged-accumulation + FragmentC tmp_accum_; + + /// Pair of A fragments used to overlap shared memory loads and math + /// instructions + WarpLoadedFragmentA warp_loaded_frag_A_[2]; + WarpTransformedFragmentA warp_transformed_frag_A_[2]; + + /// Pair of B fragments used to overlap shared memory loads and math + /// instructions + WarpLoadedFragmentB warp_loaded_frag_B_[2]; + WarpTransformedFragmentB warp_transformed_frag_B_[2]; + }; + + private: + // + // Data members + // + + /// Warp-level MMA operator + Operator warp_mma_; + + // Wint2 unzip operator + Dequantizer warp_dequantizer_; + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Shared memory write stage index + int smem_write_stage_idx_; + + /// Shared memory read stage index + int smem_read_stage_idx_; + + uint8_t *column_wise_smem_ptr_B_; + + uint8_t *smem_zipped_ptr_B_; + int smem_zipped_bytes_per_stage_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + Wint2xMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_write_stage_idx_(0), + smem_read_stage_idx_(0) { + // Compute warp location within threadblock tile by mapping the warp_id + // to three coordinates: + // _m: the warp's position within the threadblock along the M + // dimension _n: the warp's position within the threadblock along the + // N dimension _k: the warp's position within the threadblock along + // the K dimension + + int warp_idx_mn = + warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + + column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr(); + + smem_zipped_ptr_B_ = + column_wise_smem_ptr_B_ + + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn; + smem_zipped_bytes_per_stage_B_ = + Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn; + } - /// Temporary accumulator to facilitate staged-accumulation - FragmentC tmp_accum_; - - /// Pair of A fragments used to overlap shared memory loads and math instructions - WarpLoadedFragmentA 
warp_loaded_frag_A_[2]; - WarpTransformedFragmentA warp_transformed_frag_A_[2]; - - /// Pair of B fragments used to overlap shared memory loads and math instructions - WarpLoadedFragmentB warp_loaded_frag_B_[2]; - WarpTransformedFragmentB warp_transformed_frag_B_[2]; - }; - - - private: - - // - // Data members - // - - /// Warp-level MMA operator - Operator warp_mma_; - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - /// Shared memory write stage index - int smem_write_stage_idx_; - - /// Shared memory read stage index - int smem_read_stage_idx_; - - uint8_t* column_wise_smem_ptr_B_; - - uint8_t* smem_zipped_ptr_B_; - int smem_zipped_bytes_per_stage_B_; - -public: - - /// Construct from tensor references - CUTLASS_DEVICE - Wint2xMmaMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage &shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), - smem_write_stage_idx_(0), - smem_read_stage_idx_(0) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset( - {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset( - {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - - column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr(); - - smem_zipped_ptr_B_ = column_wise_smem_ptr_B_ + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn; - smem_zipped_bytes_per_stage_B_ = Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn; - } - - /// Advance shared memory read-iterators to the next stage - CUTLASS_DEVICE - void advance_smem_read_stage() - { - ++smem_read_stage_idx_; - - if (smem_read_stage_idx_ == Base::kStages) { - // Wrap back around to the 'start' of the circular buffer in shared memory - this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); - smem_read_stage_idx_ = 0; + /// Advance shared memory read-iterators to the next stage + CUTLASS_DEVICE + void advance_smem_read_stage() { + ++smem_read_stage_idx_; + + if (smem_read_stage_idx_ == Base::kStages) { + // Wrap back around to the 'start' of the circular buffer in shared + // memory + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + 
this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx_ = 0; + } } - } - - /// Advance global memory read-iterators and shared memory write-iterators to the stage - template - CUTLASS_DEVICE - void advance_smem_write_stage( - IteratorA &iterator_A, - IteratorB &iterator_B, - TileDequanterB &tile_dequanter_B) - { - // Advance global iterators - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - //tile_dequanter_B.AddTileOffset({1, 0}); - - // Advance shared iterators - smem_iterator_A_.add_tile_offset({0, 1}); - smem_iterator_B_.add_tile_offset({1, 0}); - - // Increment shared memory write stage index - ++smem_write_stage_idx_; - - if (smem_write_stage_idx_ == Base::kStages) { - // Wrap back around to the 'start' of the circular buffer in shared memory - smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx_ = 0; + + /// Advance global memory read-iterators and shared memory write-iterators + /// to the stage + template + CUTLASS_DEVICE void advance_smem_write_stage( + IteratorA &iterator_A, + IteratorB &iterator_B, + TileDequanterB &tile_dequanter_B) { + // Advance global iterators + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + // tile_dequanter_B.AddTileOffset({1, 0}); + + // Advance shared iterators + smem_iterator_A_.add_tile_offset({0, 1}); + smem_iterator_B_.add_tile_offset({1, 0}); + + // Increment shared memory write stage index + ++smem_write_stage_idx_; + + if (smem_write_stage_idx_ == Base::kStages) { + // Wrap back around to the 'start' of the circular buffer in shared + // memory + smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx_ = 0; + } } - } - - CUTLASS_DEVICE - void copy_tiles_and_advance_A(IteratorA &iterator_A, int group_start_A = 0) { - iterator_A.set_iteration_index(group_start_A * - IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; + CUTLASS_DEVICE + void copy_tiles_and_advance_A(IteratorA &iterator_A, + int group_start_A = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_A.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, iterator_A.valid()); - } else { - cutlass::arch::cp_async( - dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - - ++iterator_A; + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + 
CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } } - - ++this->smem_iterator_A_; - } } - } - - template - CUTLASS_DEVICE - void copy_tiles_and_advance_B(IteratorB &iterator_B, int group_start_B = 0) { - iterator_B.set_iteration_index(group_start_B * - IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; + template + CUTLASS_DEVICE void copy_tiles_and_advance_B(IteratorB &iterator_B, + int group_start_B = 0) { + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_B.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::copy_zfill( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - } else { - cutlass::arch::copy( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - - ++iterator_B; + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch:: + copy_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch:: + copy( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } } - - ++this->smem_iterator_B_; - } + __syncthreads(); } - __syncthreads(); - } - CUTLASS_DEVICE - void copy_tiles_and_advance_per_stage_A(IteratorA &iterator_A) { - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); + CUTLASS_DEVICE + void copy_tiles_and_advance_per_stage_A(IteratorA &iterator_A) { + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_A.get(); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < 
IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); - int const kSrcBytes = - sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; - int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); - ++iterator_A; - } + ++iterator_A; + } - ++this->smem_iterator_A_; - } - } - - template - CUTLASS_DEVICE - void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B) { - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_B.get(); - - int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - if (InitStage) { - cutlass::arch::copy_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - } else { - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::copy_zfill( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - } else { - cutlass::arch::copy( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - } + ++this->smem_iterator_A_; } - - ++iterator_B; - } - - ++this->smem_iterator_B_; - } - __syncthreads(); - } - - /// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching - /// the global fragments needed by the first kStages-1 threadblock mainloop iterations - template - CUTLASS_DEVICE - void prologue( - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - TileDequanterB &tile_dequanter_B, - int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining - { - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { - - // Disable global fetching if done with global fetch iterations - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - // Async copy zipped B to shared memory. - copy_tiles_and_advance_per_stage_A(iterator_A); - - // Async copy zipped B to shared memory. - copy_tiles_and_advance_per_stage_B(iterator_B); - - // TODO: Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale. - //tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, - // column_wise_smem_ptr_B_, stage); - - // Move to the next write stage - advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B); - - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); } - // Optionally clear the remaining stages of SMEM. This is a functional requirement for - // some kernels so that all accumulator elements outside the GEMM footprint are zero. 
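As an aside on the prologue being reworked here: both the removed and the new version issue the first kStages - 1 global->shared stages up front, close each stage with cp_async_fence(), and wrap the shared-memory write stage index when it reaches kStages. The host-side model below only mirrors that bookkeeping; kStages and gemm_k_iterations are assumed values, and the cp.async traffic is represented by prints.

#include <cstdio>

int main() {
    constexpr int kStages = 4;     // assumed pipeline depth
    int gemm_k_iterations = 10;    // threadblock mainloop iterations remaining
    int smem_write_stage_idx = 0;  // circular shared-memory write stage

    for (int stage = 0; stage < kStages - 1; ++stage, --gemm_k_iterations) {
        bool fetch_enabled = (gemm_k_iterations != 0);  // clear_mask analogue
        std::printf("prologue stage %d: issue A/B cp.async (%s), then fence\n",
                    stage, fetch_enabled ? "enabled" : "masked off");

        // advance_smem_write_stage analogue: wrap the circular buffer.
        if (++smem_write_stage_idx == kStages) {
            smem_write_stage_idx = 0;
        }
    }

    // gmem_wait analogue: block until at least one issued stage has committed
    // (the kernel states this as #uncommitted = kStages - 1 - #committed).
    std::printf("wait until >= 1 of the %d issued stages has landed in SMEM\n",
                kStages - 1);
    return 0;
}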
- if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - typename IteratorA::AccessType zero_A; - - zero_A.clear(); - last_smem_iterator_A.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + template + CUTLASS_DEVICE void copy_tiles_and_advance_per_stage_B( + IteratorB &iterator_B) { + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); - typename IteratorA::AccessType *dst_ptr = - reinterpret_cast( - last_smem_iterator_A.get()); + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + if (InitStage) { + cutlass::arch:: + copy_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + } else { + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch:: + copy_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch:: + copy( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + } + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + __syncthreads(); + } - *dst_ptr = zero_A; + /// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching + /// the global fragments needed by the first kStages-1 threadblock mainloop + /// iterations + template + CUTLASS_DEVICE void prologue( + IteratorA + &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB + &iterator_B, ///< [in|out] iterator over B operand in global memory + TileDequanterB &tile_dequanter_B, + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop + ///< iterations remaining + { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + // Disable global fetching if done with global fetch iterations + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); - ++last_smem_iterator_A; - } + // Async copy zipped B to shared memory. + copy_tiles_and_advance_per_stage_A(iterator_A); - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); - typename IteratorB::AccessType zero_B; + // Async copy zipped B to shared memory. + copy_tiles_and_advance_per_stage_B(iterator_B); - zero_B.clear(); - last_smem_iterator_B.set_iteration_index(0); + // TODO: Async copy other quantized params to shared memory, + // local_scale, code_scale, code_zp, super_scale. 
+ // tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % + // Base::kStages) * smem_zipped_bytes_per_stage_B_, + // column_wise_smem_ptr_B_, stage); - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + // Move to the next write stage + advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B); - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( - last_smem_iterator_B.get()); + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } - *dst_ptr = zero_B; + // Optionally clear the remaining stages of SMEM. This is a functional + // requirement for some kernels so that all accumulator elements outside + // the GEMM footprint are zero. + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + typename IteratorA::AccessType zero_A; + + zero_A.clear(); + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + } - ++last_smem_iterator_B; - } + /// Wait until we have at least one completed global fetch stage + CUTLASS_DEVICE + void gmem_wait() { + // Wait until we have at least one committed global fetch stage. + // (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); } - } - - /// Wait until we have at least one completed global fetch stage - CUTLASS_DEVICE - void gmem_wait() - { - // Wait until we have at least one committed global fetch stage. 
(#uncommitted = Base::kStages - 1 - #committed) - cutlass::arch::cp_async_wait(); - __syncthreads(); - } - - /// Perform a threadblock mainloop iteration of matrix multiply-accumulate - template - CUTLASS_DEVICE - void mac_loop_iter( - PipeState &pipe_state, ///< [in|out] loop-carried pipeline state - FragmentC &accum, ///< [in|out] destination accumulator tile - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - TileDequanterB &tile_dequanter_B, ///< [in|out] tile dequantizer for B operand - int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining - int stage) - { - // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { - // CUTLASS_TRACE_DEVICE(" [MMa] stage=%d, warp_mma_k=%d", stage, warp_mma_k); - - // Load the next warp-tile's A fragment from shared memory - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - // Load the next warp-tile's B fragment from shared memory - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_B_; - - // Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary - if (warp_mma_k > 0) { - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], - pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]); - } - - // Execute the current warp-tile of MMA operations - if (Detail::kStagedAccumulation) { - warp_mma_( - pipe_state.tmp_accum_, - pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - pipe_state.tmp_accum_ - ); - - if (warp_mma_k == 0) { - plus plus_accum; - accum = plus_accum(accum, pipe_state.tmp_accum_); - pipe_state.tmp_accum_.clear(); - } - } else { - warp_mma_( - accum, - pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - accum - ); - } - - // Except for the last warp-tile, all warp-tiles issue their share of - // global->shared fragment copies - if (warp_mma_k < Base::kWarpGemmIterations - 1) { - int group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - int group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); - copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); - } - - // The second-to-last warp-tile also: - // - performs the last warp-tile's share of global->shared fragment copies - // - moves to the next global fetch stage - if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - // Performs the last warp-tile's share of global->shared fragment copies - int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - int group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); - copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); - - // Inserts a memory fence between stages of 
cp.async instructions. - cutlass::arch::cp_async_fence(); - // Wait until we have at least one completed global fetch stage - gmem_wait(); + /// Perform a threadblock mainloop iteration of matrix multiply-accumulate + template + CUTLASS_DEVICE void mac_loop_iter( + PipeState &pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC &accum, ///< [in|out] destination accumulator tile + IteratorA + &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB + &iterator_B, ///< [in|out] iterator over B operand in global memory + TileDequanterB + &tile_dequanter_B, ///< [in|out] tile dequantizer for B operand + int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop + ///< iterations remaining + int stage) { + // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // CUTLASS_TRACE_DEVICE(" [MMa] stage=%d, warp_mma_k=%d", stage, + // warp_mma_k); + + // Load the next warp-tile's A fragment from shared memory + this->warp_tile_iterator_A_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load( + pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + // Load the next warp-tile's B fragment from shared memory + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.load( + pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_B_; + + // Except for the first warp-tile, all warp-tiles convert their + // incoming shared memory fragments as necessary + if (warp_mma_k > 0) { + warp_mma_.transform( + pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]); + } + + // Execute the current warp-tile of MMA operations + if (Detail::kStagedAccumulation) { + warp_mma_(pipe_state.tmp_accum_, + pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_); + + if (warp_mma_k == 0) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + pipe_state.tmp_accum_.clear(); + } + } else { + warp_mma_(accum, + pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + accum); + } + + // Except for the last warp-tile, all warp-tiles issue their share + // of global->shared fragment copies + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A = + warp_mma_k * Detail::kAccessesPerGroupA; + int group_start_iteration_B = + warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); + copy_tiles_and_advance_B(iterator_B, + group_start_iteration_B); + } + + // The second-to-last warp-tile also: + // - performs the last warp-tile's share of global->shared + // fragment copies + // - moves to the next global fetch stage + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + // Performs the last warp-tile's share of global->shared + // fragment copies + int group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + int group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); + copy_tiles_and_advance_B(iterator_B, + 
group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async + // instructions. + cutlass::arch::cp_async_fence(); + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + + // Move to the next global fetch stage + advance_smem_write_stage( + iterator_A, iterator_B, tile_dequanter_B); + advance_smem_read_stage(); + + // Disable global fetching when done with global fetch + // iterations + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + + // The last warp-tile also converts the shared memory fragments used + // by the first warp-tile of the next iteration, if necessary (so we + // can immediately start issuing MMA instructions at the top of the + // loop ) + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + warp_mma_.transform( + pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2], + pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], + pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2], + pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); + } + } + } - // Move to the next global fetch stage - advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B); - advance_smem_read_stage(); + /// Perform the specified number of threadblock mainloop iterations of + /// matrix multiply-accumulate. Assumes prologue has been initiated. + template + CUTLASS_DEVICE void gemm_iters( + int gemm_k_iterations, ///< number of threadblock mainloop iterations + FragmentC &accum, ///< [in|out] accumulator tile + IteratorA + &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, + TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand + ///< in global memory + { + PipeState pipe_state; - // Disable global fetching when done with global fetch iterations - --gemm_k_iterations; + // Disable global fetching if done with global fetch iterations iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - } - - // The last warp-tile also converts the shared memory fragments used by - // the first warp-tile of the next iteration, if necessary (so we can - // immediately start issuing MMA instructions at the top of the loop ) - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2], - pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], - pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2], - pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); - } - } - } - - /// Perform the specified number of threadblock mainloop iterations of matrix - /// multiply-accumulate. Assumes prologue has been initiated. 
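A host-side sketch of the schedule that mac_loop_iter() above implements: warp fragments ping-pong between two slots indexed by warp_mma_k % 2, the second-to-last warp-tile drains this stage's cp.async copies and advances the circular read/write stages, and the last warp-tile pre-transforms the fragments consumed by the first warp-tile of the next iteration. kWarpGemmIterations below is an assumed value.

#include <cstdio>

int main() {
    constexpr int kWarpGemmIterations = 4;

    for (int warp_mma_k = 0; warp_mma_k < kWarpGemmIterations; ++warp_mma_k) {
        int compute_slot = warp_mma_k % 2;     // fragments fed to the warp MMA
        int load_slot = (warp_mma_k + 1) % 2;  // fragments being loaded from SMEM

        std::printf("warp_mma_k=%d: compute from slot %d, load next tile into slot %d\n",
                    warp_mma_k, compute_slot, load_slot);

        if (warp_mma_k + 2 == kWarpGemmIterations) {
            // Mirrors the kernel: finish the stage's copies, cp_async_fence(),
            // gmem_wait(), then advance_smem_write_stage()/advance_smem_read_stage().
            std::printf("  -> fence, wait for one completed stage, advance stages\n");
        }
        if (warp_mma_k + 1 == kWarpGemmIterations) {
            // Mirrors the kernel: transform the fragments for the next
            // iteration's first warp-tile so MMAs can start immediately.
            std::printf("  -> pre-transform next iteration's first fragments\n");
        }
    }
    return 0;
}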
- template - CUTLASS_DEVICE - void gemm_iters( - int gemm_k_iterations, ///< number of threadblock mainloop iterations - FragmentC &accum, ///< [in|out] accumulator tile - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, - TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand in global memory - { - PipeState pipe_state; - - // Disable global fetching if done with global fetch iterations - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1)); - - // Load first warp-tile's A fragment from shared memory - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]); - ++this->warp_tile_iterator_A_; - - // Load first warp-tile's B fragment from shared memory - this->warp_tile_iterator_B_.set_kgroup_index(0); - this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]); - ++this->warp_tile_iterator_B_; - - // Transform, if necessary, the first warp-tile's shared memory fragments - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[0], - pipe_state.warp_transformed_frag_B_[0], - pipe_state.warp_loaded_frag_A_[0], - pipe_state.warp_loaded_frag_B_[0]); - - if (Detail::kStagedAccumulation) { - pipe_state.tmp_accum_.clear(); - } + iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1)); + + // Load first warp-tile's A fragment from shared memory + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]); + ++this->warp_tile_iterator_A_; + + // Load first warp-tile's B fragment from shared memory + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]); + ++this->warp_tile_iterator_B_; + + // // Transform, if necessary, the first warp-tile's shared memory + // fragments warp_mma_.transform( + // pipe_state.warp_transformed_frag_A_[0], + // pipe_state.warp_transformed_frag_B_[0], + // pipe_state.warp_loaded_frag_A_[0], + // pipe_state.warp_loaded_frag_B_[0]); + + __syncthreads(); // 确保所有线程执行到此处 + if (threadIdx.x == 0) { // 仅让一个线程打印,避免重复输出 + // printf("DEBUG: warp_loaded_frag_A_[0] values:\n"); + for (int i = 0; i < pipe_state.warp_loaded_frag_A_[0].size(); ++i) { + // 读取 fragment 中的元素 + auto val = pipe_state.warp_loaded_frag_A_[0][i]; + + // 以 16-bit 形式 reinterpret 为 uint16_t 查看原始位模式 + uint16_t bits = reinterpret_cast(&val)[0]; + + CUTLASS_TRACE_DEVICE( + " warp_loaded_frag_A_[%d] = 0x%04x", i, bits); + } + } + __syncthreads(); + + typename Dequantizer::FragmentLocalScale warp_frag_local_scale; + typename Dequantizer::FragmentCodeScale warp_frag_code_scale; + typename Dequantizer::FragmentCodeZp warp_frag_code_zp; + typename Dequantizer::FragmentSuperScale warp_frag_super_scale; + + typename Dequantizer::FragmentOutOperand warp_frag_out; + + CUTLASS_TRACE_DEVICE(" warp_dequantizer_ - start load"); + warp_dequantizer_.load(warp_frag_local_scale, + warp_frag_code_scale, + warp_frag_code_zp, + warp_frag_super_scale); + __syncthreads(); + + CUTLASS_TRACE_DEVICE("warp_dequantizer_ - start dequant"); + warp_dequantizer_.dequantize(warp_frag_out, + pipe_state.warp_loaded_frag_B_[0], + warp_frag_local_scale, + warp_frag_code_scale, + warp_frag_code_zp, + warp_frag_super_scale); + + if (Detail::kStagedAccumulation) { + pipe_state.tmp_accum_.clear(); + } #if 0 int stage = Base::kStages - 1; @@ -717,72 +817,77 @@ class Wint2xMmaMultistage : cutlass::arch::cp_async_fence(); 
cutlass::arch::cp_async_wait<0>(); __syncthreads(); - #endif - } - - /// Prepares the class for another prologue. - CUTLASS_DEVICE - void wind_down() - { - // Catch-up the smem-read iterator to the smem-write iterator (so this class can be reused for another tile's prologue) - - // First, increment remaining warp tiles to get to the next full stage. (Ideally we would - // just decrement one tile, but not all iterators implement --() decrement.) - #pragma unroll - for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { - this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k); - this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; +#endif } - smem_read_stage_idx_++; - // Then wrap back two full stages (one for the tile advancing we just did, and one to catch the write iterators) - static const int kStageIters = Policy::kPartitionsK * Base::kWarpGemmIterations; - if (smem_read_stage_idx_ > 1) - { - this->warp_tile_iterator_A_.add_tile_offset({0, (-2 * kStageIters)}); - this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0}); + /// Prepares the class for another prologue. + CUTLASS_DEVICE + void wind_down() { +// Catch-up the smem-read iterator to the smem-write iterator (so this class can +// be reused for another tile's prologue) + +// First, increment remaining warp tiles to get to the next full stage. (Ideally +// we would just decrement one tile, but not all iterators implement --() +// decrement.) +#pragma unroll + for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k); + this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + } + smem_read_stage_idx_++; + + // Then wrap back two full stages (one for the tile advancing we just + // did, and one to catch the write iterators) + static const int kStageIters = + Policy::kPartitionsK * Base::kWarpGemmIterations; + if (smem_read_stage_idx_ > 1) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, (-2 * kStageIters)}); + this->warp_tile_iterator_B_.add_tile_offset( + {(-2 * kStageIters), 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, ((Base::kStages - 2) * kStageIters)}); + this->warp_tile_iterator_B_.add_tile_offset( + {((Base::kStages - 2) * kStageIters), 0}); + } + smem_read_stage_idx_ = smem_write_stage_idx_; } - else - { - this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)}); - this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0}); + + /// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to + /// shared memory. 
+ template + CUTLASS_DEVICE void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< pre-load and dequantize B to shared memory + TileDequanterB tile_dequanter_B, + ///< initial value of accumulator + FragmentC const &src_accum) { + // Prologue (start fetching iterations of global fragments into shared + // memory) + prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations); + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + + // Initialize destination accumulators with source accumulators + accum = src_accum; + + // Perform the MAC-iterations + gemm_iters( + gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B); } - smem_read_stage_idx_ = smem_write_stage_idx_; - } - - /// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory. - template - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC &accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< pre-load and dequantize B to shared memory - TileDequanterB tile_dequanter_B, - ///< initial value of accumulator - FragmentC const &src_accum) { - - // Prologue (start fetching iterations of global fragments into shared memory) - prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations); - - // Wait until we have at least one completed global fetch stage - gmem_wait(); - - // Initialize destination accumulators with source accumulators - accum = src_accum; - - // Perform the MAC-iterations - //gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B); - } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h new file mode 100644 index 0000000000..4d05740d53 --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h @@ -0,0 +1,696 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations + targeting Tensor Cores. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" + +#include "cutlass/functional.h" +#include "cutlass/platform/platform.h" + +#include +#include "cutlass_extensions/weight_only_quant_op.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Matrix multiply operator + typename MmaOperator_, + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + Operand Operand, + /// Data type of Scale elements + typename Element_, + /// Layout of operand + typename Layout_, + /// Number of threads participating in one matrix operation + int Threads, + /// + WeightOnlyQuantOp QuantOp_ = WeightOnlyQuantOp::UNDEFINED, + /// + typename Enable = void> +class MmaTensorOpWin2xDequantizer; + +//////////////////////////////////////////////////////////////////////////////// +// Bfloat specialization for Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> + +class MmaTensorOpWin2xDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + bfloat16_t, + layout::ColumnMajor, + 32, + QuantOp_, + typename platform::enable_if= + 70>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. 
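The expansion factor defined just below is the number of mma compute instructions along K that a single B load instruction covers. With made-up instruction shapes the arithmetic is simply a ratio; the real operands come from the MmaOperator configuration.

#include <cstdio>

int main() {
    constexpr int kIteratorBInstructionKRows = 64;  // assumed IteratorB::InstructionShape::kRow
    constexpr int kMmaInstructionK = 16;            // assumed ArchMmaOperator::Shape::kK
    constexpr int kExpansionFactor = kIteratorBInstructionKRows / kMmaInstructionK;
    // One loaded B fragment then spans kExpansionFactor compute instructions along K.
    std::printf("kExpansionFactor = %d\n", kExpansionFactor);  // prints 4 for these shapes
    return 0;
}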
+ static constexpr int kExpansionFactor = + MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementWeight = uint2b_t; + + /// Type of the scales + using ElementUnzipWeight = uint8_t; + + /// Type of the scales + using ElementScale = bfloat16_t; + + /// Type of the scales + using ScaleComputeT = float; + + static constexpr int unzip_len = 4; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = + Array; + using FragmentWeightOperand = + Array; + using FragmentOutOperand = + Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentLocalScale = Array; + using FragmentCodeScale = Array; + using FragmentCodeZp = Array; + using FragmentSuperScale = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::ColumnMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = cutlass::TensorRef; + using TensorCodeRef = cutlass::TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpWin2xDequantizer(TensorRef smem_local_scale, + TensorCodeRef smem_code_scale, + TensorCodeRef smem_code_zp, + TensorRef smem_super_scale, + int const warp_idx_n, + int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_local_scale_ = smem_local_scale.data() + thread_offset; + pointer_code_scale_ = smem_code_scale.data() + thread_offset; + pointer_code_zp_ = smem_code_zp.data() + thread_offset; + if constexpr (hasZero(QuantOp)) { + pointer_super_scale_ = smem_super_scale.data() + thread_offset; + } + } + + // CUTLASS_DEVICE + // MmaTensorOpWin2xDequantizer() { + // pointer_local_scale_ = nullptr; + // pointer_code_scale_ = nullptr; + // pointer_code_zp_ = nullptr; + // if constexpr (hasZero(QuantOp)) { + // pointer_super_scale_ = nullptr; + // } + // } + + CUTLASS_DEVICE + MmaTensorOpWin2xDequantizer() { + // Create fake pointer using a shared dummy buffer + CUTLASS_TRACE_DEVICE(" warp dequant aaa"); + + extern __shared__ char cutlass_fake_dequant_smem[]; + + // Memory layout (manual alignment): + // ElementScale (half or bf16): 2 bytes + // ScaleComputeT (float): 4 bytes + + pointer_local_scale_ = + reinterpret_cast(cutlass_fake_dequant_smem); + pointer_code_scale_ = + reinterpret_cast(cutlass_fake_dequant_smem + 64); + pointer_code_zp_ = + reinterpret_cast(cutlass_fake_dequant_smem + 128); + + if constexpr (hasZero(QuantOp)) { + pointer_super_scale_ = reinterpret_cast( + cutlass_fake_dequant_smem + 192); + } + } + + CUTLASS_DEVICE + void load(FragmentLocalScale& local_scale_frag, + FragmentCodeScale& code_scale_frag, + FragmentCodeZp& code_zp_frag, + FragmentSuperScale& super_scale_frag) { + CUTLASS_TRACE_DEVICE(" warp dequant load"); + // CUTLASS_PRAGMA_UNROLL + // for (int mma_n_iter = 0; mma_n_iter < + // MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + // { + // local_scale_frag[mma_n_iter] = pointer_local_scale_[mma_n_iter * + // InstructionShape::kN]; code_scale_frag[mma_n_iter] = + // pointer_code_scale_[mma_n_iter * InstructionShape::kN]; + // code_zp_frag[mma_n_iter] = pointer_code_zp_[mma_n_iter * + // InstructionShape::kN]; if constexpr (hasZero(QuantOp)) + // { + // super_scale_frag[mma_n_iter] = + // pointer_super_scale_[mma_n_iter * 
InstructionShape::kN]; + // } + // } + } + + CUTLASS_DEVICE + void dequantize(FragmentOutOperand& out_frag, + FragmentDequantizedOperand& operand_frag, + FragmentLocalScale& local_scale_frag, + FragmentCodeScale& code_scale_frag, + FragmentCodeZp& code_zp_frag, + FragmentSuperScale& super_scale_frag) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) + CUTLASS_TRACE_DEVICE(" dequantize if def"); + + static constexpr int32_t kGroupSize = 64; + static constexpr int32_t kPackNum = 4; + static constexpr int32_t kWeightMask = 0x3F; + static constexpr int32_t kLocalScaleMask = 0xF; + static constexpr int32_t kBZP = 32; + + // using _MmaOperandB = typename ArchMmaOperator::FragmentB; + // using ExpandedMmaOperandB = Array; + // static_assert(ExpandedMmaOperandB::kElements * + // MmaOperator::MmaIterations::kColumn + // == FragmentDequantizedOperand::kElements, + // ""); + + // CUTLASS_TRACE_DEVICE(" MmaIterations krow = %d, kcol = %d", + // MmaOperator::IteratorB::InstructionShape::kRow, + // MmaOperator::MmaIterations::kColumn); + + // CUTLASS_TRACE_DEVICE(" kExpansionFactor = %d / %d", + // MmaOperator::IteratorB::InstructionShape::kRow, + // InstructionShape::kK); CUTLASS_TRACE_DEVICE(" + // FragmentDequantizedOperand::kElements = %d ", + // FragmentDequantizedOperand::kElements); CUTLASS_TRACE_DEVICE(" + // _MmaOperandB::kElements = %d ", _MmaOperandB::kElements); + + // FragmentWeightOperand + CUTLASS_TRACE_DEVICE(" FragmentWeightOperand elem = %d ", + FragmentWeightOperand::kElements); + // CUTLASS_TRACE_DEVICE(" ElementUnzipWeight size = %d ", + // sizeof(ElementUnzipWeight)); CUTLASS_TRACE_DEVICE(" ElementWeight + // size = %d ", sizeof(ElementWeight)); + static_assert(std::is_same::value, + "B 是 uint8 量化类型"); + FragmentWeightOperand* weight_ptr = + reinterpret_cast(&operand_frag); + FragmentLocalScale* local_scale_ptr = + reinterpret_cast(&local_scale_frag); + FragmentCodeScale* code_scale_ptr = + reinterpret_cast(&code_scale_frag); + FragmentCodeZp* code_zp_ptr = + reinterpret_cast(&code_zp_frag); + FragmentSuperScale* super_scale_ptr = + reinterpret_cast(&super_scale_frag); + + ScaleComputeT code_scale = + static_cast(code_scale_ptr[0][0]); + ScaleComputeT code_zp = static_cast(code_zp_ptr[0][0]); + ScaleComputeT super_scale = + static_cast(super_scale_ptr[0][0]); + int32_t local_scale = static_cast(local_scale_ptr[0][0]); + int32_t const shift_bits[4] = {9, 6, 3, 0}; + + ScaleComputeT zipped_value[16]; +#pragma unroll + for (int i = 0; i < 16; ++i) { + zipped_value[i] = static_cast(weight_ptr[0][i]); + } + + int local_scale_shift = 4; + int32_t shifted_local_scale = + (local_scale >> local_scale_shift) & kLocalScaleMask; + ScaleComputeT scale = + static_cast(shifted_local_scale) * super_scale; + +#pragma unroll + for (int i = 0; i < 16; ++i) { + int32_t decode_value = static_cast( + floor(zipped_value[i] * code_scale + code_zp + + static_cast(0.5))); + + int col = i * 4; + +#pragma unroll + for (int shift_bit_id = 0; shift_bit_id < 4; ++shift_bit_id) { + int32_t shift_bit = shift_bits[shift_bit_id]; + int32_t shifted_value = + (decode_value >> shift_bit) & kWeightMask; + + ScaleComputeT value = + static_cast(shifted_value - kBZP); + out_frag[col + shift_bit_id] = + static_cast(scale * value); + } + } + + CUTLASS_TRACE_DEVICE(" kColsPerMmaPerThread = %d ", + kColsPerMmaPerThread); + CUTLASS_TRACE_DEVICE(" MmaOperator::MmaIterations::kColumn = %d ", + MmaOperator::MmaIterations::kColumn); + + // // __nv_bfloat16 const* scale_ptr = 
reinterpret_cast<__nv_bfloat16 + // const*>(&scale_frag); ExpandedMmaOperandB* operand_frag_ptr = + // reinterpret_cast(&operand_frag); + + // printf("threadidx.x = %d\n", threadIdx.x); + // CUTLASS_PRAGMA_UNROLL + // for (int mma_n_iter = 0; mma_n_iter < + // MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + // { + // static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + // __nv_bfloat162 scalex2 = + // __bfloat162bfloat162(scale_ptr[mma_n_iter]); + // __nv_bfloat162* operand_bf16x2_ptr = + // reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + // CUTLASS_PRAGMA_UNROLL + // for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) + // { + // operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], + // scalex2); + // } + // } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on + // older arch, scale conversion should happen before scales are stored + // to shared memory and we should use the fp16 dequantizer. This will + // avoid numerous conversion instructions in GEMM main loop. + CUTLASS_TRACE_DEVICE(" dequantize else def"); + // arch::device_breakpoint(); +#endif + } + + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_local_scale_ += offset; + pointer_code_scale_ += offset; + pointer_code_zp_ += offset; + pointer_super_scale_ += offset; + } + + private: + ElementScale const* pointer_local_scale_; + ScaleComputeT const* pointer_code_scale_; + ScaleComputeT const* pointer_code_zp_; + ElementScale const* pointer_super_scale_; + + ElementScale const* pointer_out_; +}; + +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpWin2xDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + half_t, + layout::ColumnMajor, + 32, + QuantOp_, + typename platform::enable_if= + 70>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. 
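The dequantize() body above (and the identical half_t variant that follows) decodes each zipped B byte with an affine code_scale/code_zp transform, extracts four 6-bit fields at bit offsets 9/6/3/0, removes the zero point kBZP, and applies the combined local-scale/super-scale factor. The standalone host sketch below reproduces that arithmetic with float stand-ins for bf16/half; every input value is made up purely for illustration.

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    constexpr int32_t kWeightMask = 0x3F;     // 6-bit field mask
    constexpr int32_t kLocalScaleMask = 0xF;  // 4-bit local scale
    constexpr int32_t kBZP = 32;              // zero point of each decoded field
    const int32_t shift_bits[4] = {9, 6, 3, 0};

    // Hypothetical per-column parameters; in the kernel they come from the
    // code_scale/code_zp/local_scale/super_scale fragments.
    float code_scale = 1.25f;
    float code_zp = 3.0f;
    int32_t local_scale = 0xA7;  // packed 4-bit local scales, high nibble used here
    float super_scale = 0.5f;
    uint8_t zipped_value = 201;  // one zipped B byte

    // scale = ((local_scale >> 4) & 0xF) * super_scale, as in the kernel.
    int local_scale_shift = 4;
    float scale = float((local_scale >> local_scale_shift) & kLocalScaleMask) * super_scale;

    // Affine decode of the byte, then extract the four fields.
    int32_t decode_value = static_cast<int32_t>(
        std::floor(zipped_value * code_scale + code_zp + 0.5f));

    for (int j = 0; j < 4; ++j) {
        int32_t field = (decode_value >> shift_bits[j]) & kWeightMask;
        float w = scale * float(field - kBZP);  // dequantized weight
        std::printf("out[%d] = %.1f\n", j, w);
    }
    return 0;
}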
+ static constexpr int kExpansionFactor = + MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementWeight = uint2b_t; + + /// Type of the scales + using ElementUnzipWeight = uint8_t; + + /// Type of the scales + using ElementScale = half_t; + + /// Type of the scales + using ScaleComputeT = float; + + static constexpr int unzip_len = 4; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = + Array; + using FragmentWeightOperand = + Array; + using FragmentOutOperand = + Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentLocalScale = Array; + using FragmentCodeScale = Array; + using FragmentCodeZp = Array; + using FragmentSuperScale = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::ColumnMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = cutlass::TensorRef; + using TensorCodeRef = cutlass::TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpWin2xDequantizer(TensorRef smem_local_scale, + TensorCodeRef smem_code_scale, + TensorCodeRef smem_code_zp, + TensorRef smem_super_scale, + int const warp_idx_n, + int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_local_scale_ = smem_local_scale.data() + thread_offset; + pointer_code_scale_ = smem_code_scale.data() + thread_offset; + pointer_code_zp_ = smem_code_zp.data() + thread_offset; + if constexpr (hasZero(QuantOp)) { + pointer_super_scale_ = smem_super_scale.data() + thread_offset; + } + } + + // CUTLASS_DEVICE + // MmaTensorOpWin2xDequantizer() { + // pointer_local_scale_ = nullptr; + // pointer_code_scale_ = nullptr; + // pointer_code_zp_ = nullptr; + // if constexpr (hasZero(QuantOp)) { + // pointer_super_scale_ = nullptr; + // } + // } + + CUTLASS_DEVICE + MmaTensorOpWin2xDequantizer() { + // Create fake pointer using a shared dummy buffer + CUTLASS_TRACE_DEVICE(" warp dequant aaa"); + + extern __shared__ char cutlass_fake_dequant_smem[]; + + // Memory layout (manual alignment): + // ElementScale (half or bf16): 2 bytes + // ScaleComputeT (float): 4 bytes + + pointer_local_scale_ = + reinterpret_cast(cutlass_fake_dequant_smem); + pointer_code_scale_ = + reinterpret_cast(cutlass_fake_dequant_smem + 64); + pointer_code_zp_ = + reinterpret_cast(cutlass_fake_dequant_smem + 128); + + if constexpr (hasZero(QuantOp)) { + pointer_super_scale_ = reinterpret_cast( + cutlass_fake_dequant_smem + 192); + } + } + + CUTLASS_DEVICE + void load(FragmentLocalScale& local_scale_frag, + FragmentCodeScale& code_scale_frag, + FragmentCodeZp& code_zp_frag, + FragmentSuperScale& super_scale_frag) { + CUTLASS_TRACE_DEVICE(" warp dequant load"); + // CUTLASS_PRAGMA_UNROLL + // for (int mma_n_iter = 0; mma_n_iter < + // MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + // { + // local_scale_frag[mma_n_iter] = pointer_local_scale_[mma_n_iter * + // InstructionShape::kN]; code_scale_frag[mma_n_iter] = + // pointer_code_scale_[mma_n_iter * InstructionShape::kN]; + // code_zp_frag[mma_n_iter] = pointer_code_zp_[mma_n_iter * + // InstructionShape::kN]; if constexpr (hasZero(QuantOp)) + // { + // super_scale_frag[mma_n_iter] = + // pointer_super_scale_[mma_n_iter * 
InstructionShape::kN]; + // } + // } + } + + CUTLASS_DEVICE + void dequantize(FragmentOutOperand& out_frag, + FragmentDequantizedOperand& operand_frag, + FragmentLocalScale& local_scale_frag, + FragmentCodeScale& code_scale_frag, + FragmentCodeZp& code_zp_frag, + FragmentSuperScale& super_scale_frag) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) + CUTLASS_TRACE_DEVICE(" dequantize if def"); + + static constexpr int32_t kGroupSize = 64; + static constexpr int32_t kPackNum = 4; + static constexpr int32_t kWeightMask = 0x3F; + static constexpr int32_t kLocalScaleMask = 0xF; + static constexpr int32_t kBZP = 32; + + // using _MmaOperandB = typename ArchMmaOperator::FragmentB; + // using ExpandedMmaOperandB = Array; + // static_assert(ExpandedMmaOperandB::kElements * + // MmaOperator::MmaIterations::kColumn + // == FragmentDequantizedOperand::kElements, + // ""); + + // CUTLASS_TRACE_DEVICE(" MmaIterations krow = %d, kcol = %d", + // MmaOperator::IteratorB::InstructionShape::kRow, + // MmaOperator::MmaIterations::kColumn); + + // CUTLASS_TRACE_DEVICE(" kExpansionFactor = %d / %d", + // MmaOperator::IteratorB::InstructionShape::kRow, + // InstructionShape::kK); CUTLASS_TRACE_DEVICE(" + // FragmentDequantizedOperand::kElements = %d ", + // FragmentDequantizedOperand::kElements); CUTLASS_TRACE_DEVICE(" + // _MmaOperandB::kElements = %d ", _MmaOperandB::kElements); + + // FragmentWeightOperand + CUTLASS_TRACE_DEVICE(" FragmentWeightOperand elem = %d ", + FragmentWeightOperand::kElements); + // CUTLASS_TRACE_DEVICE(" ElementUnzipWeight size = %d ", + // sizeof(ElementUnzipWeight)); CUTLASS_TRACE_DEVICE(" ElementWeight + // size = %d ", sizeof(ElementWeight)); + static_assert(std::is_same::value, + "B 是 uint8 量化类型"); + FragmentWeightOperand* weight_ptr = + reinterpret_cast(&operand_frag); + FragmentLocalScale* local_scale_ptr = + reinterpret_cast(&local_scale_frag); + FragmentCodeScale* code_scale_ptr = + reinterpret_cast(&code_scale_frag); + FragmentCodeZp* code_zp_ptr = + reinterpret_cast(&code_zp_frag); + FragmentSuperScale* super_scale_ptr = + reinterpret_cast(&super_scale_frag); + + ScaleComputeT code_scale = + static_cast(code_scale_ptr[0][0]); + ScaleComputeT code_zp = static_cast(code_zp_ptr[0][0]); + ScaleComputeT super_scale = + static_cast(super_scale_ptr[0][0]); + int32_t local_scale = static_cast(local_scale_ptr[0][0]); + int32_t const shift_bits[4] = {9, 6, 3, 0}; + + ScaleComputeT zipped_value[16]; +#pragma unroll + for (int i = 0; i < 16; ++i) { + zipped_value[i] = static_cast(weight_ptr[0][i]); + } + + int local_scale_shift = 4; + int32_t shifted_local_scale = + (local_scale >> local_scale_shift) & kLocalScaleMask; + ScaleComputeT scale = + static_cast(shifted_local_scale) * super_scale; + +#pragma unroll + for (int i = 0; i < 16; ++i) { + int32_t decode_value = static_cast( + floor(zipped_value[i] * code_scale + code_zp + + static_cast(0.5))); + + int col = i * 4; + +#pragma unroll + for (int shift_bit_id = 0; shift_bit_id < 4; ++shift_bit_id) { + int32_t shift_bit = shift_bits[shift_bit_id]; + int32_t shifted_value = + (decode_value >> shift_bit) & kWeightMask; + + ScaleComputeT value = + static_cast(shifted_value - kBZP); + out_frag[col + shift_bit_id] = + static_cast(scale * value); + } + } + + CUTLASS_TRACE_DEVICE(" kColsPerMmaPerThread = %d ", + kColsPerMmaPerThread); + CUTLASS_TRACE_DEVICE(" MmaOperator::MmaIterations::kColumn = %d ", + MmaOperator::MmaIterations::kColumn); + + // // __nv_bfloat16 const* scale_ptr = 
reinterpret_cast<__nv_bfloat16 + // const*>(&scale_frag); ExpandedMmaOperandB* operand_frag_ptr = + // reinterpret_cast(&operand_frag); + + // printf("threadidx.x = %d\n", threadIdx.x); + // CUTLASS_PRAGMA_UNROLL + // for (int mma_n_iter = 0; mma_n_iter < + // MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + // { + // static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + // __nv_bfloat162 scalex2 = + // __bfloat162bfloat162(scale_ptr[mma_n_iter]); + // __nv_bfloat162* operand_bf16x2_ptr = + // reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + // CUTLASS_PRAGMA_UNROLL + // for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) + // { + // operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], + // scalex2); + // } + // } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on + // older arch, scale conversion should happen before scales are stored + // to shared memory and we should use the fp16 dequantizer. This will + // avoid numerous conversion instructions in GEMM main loop. + CUTLASS_TRACE_DEVICE(" dequantize else def"); + // arch::device_breakpoint(); +#endif + } + + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_local_scale_ += offset; + pointer_code_scale_ += offset; + pointer_code_zp_ += offset; + pointer_super_scale_ += offset; + } + + private: + ElementScale const* pointer_local_scale_; + ScaleComputeT const* pointer_code_scale_; + ScaleComputeT const* pointer_code_zp_; + ElementScale const* pointer_super_scale_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// From 5ce342431ad3140feb7a83ac8e47fdb71dd0aafd Mon Sep 17 00:00:00 2001 From: liuyiqun01 Date: Tue, 8 Jul 2025 13:39:22 +0800 Subject: [PATCH 06/11] Remove TileDequanterB related codes. 
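With the TileDequanterB plumbing gone, the threadblock mainloop returns to the
plain two-iterator multistage shape. A condensed sketch of the resulting entry
points, collected from the hunks below (bodies elided):

    CUTLASS_DEVICE
    void prologue(IteratorA &iterator_A, IteratorB &iterator_B,
                  int &gemm_k_iterations);

    CUTLASS_DEVICE
    void gemm_iters(int gemm_k_iterations, FragmentC &accum,
                    IteratorA &iterator_A, IteratorB &iterator_B);

    CUTLASS_DEVICE
    void operator()(int gemm_k_iterations, FragmentC &accum,
                    IteratorA iterator_A, IteratorB iterator_B,
                    FragmentC const &src_accum);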
Change-Id: Id8e65703b72a8984d367f584ff41b7726017fbb8 --- .../gemm/threadblock/wint2x_mma_multistage.h | 30 +--- .../gemm/threadblock/wint2x_tile_dequanter.h | 133 ------------------ .../moe_gemm/fused_moe_cutlass_kernel.h | 22 --- 3 files changed, 6 insertions(+), 179 deletions(-) delete mode 100644 custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h index 0156c7e8a5..546e3ca9f4 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h @@ -45,7 +45,6 @@ #include "cutlass_extensions/arch/memory_copy_sm80.h" #include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h" -#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -272,17 +271,12 @@ class Wint2xMmaMultistage : } /// Advance global memory read-iterators and shared memory write-iterators to the stage - template CUTLASS_DEVICE - void advance_smem_write_stage( - IteratorA &iterator_A, - IteratorB &iterator_B, - TileDequanterB &tile_dequanter_B) + void advance_smem_write_stage(IteratorA &iterator_A, IteratorB &iterator_B) { // Advance global iterators iterator_A.add_tile_offset({0, 1}); iterator_B.add_tile_offset({1, 0}); - //tile_dequanter_B.AddTileOffset({1, 0}); // Advance shared iterators smem_iterator_A_.add_tile_offset({0, 1}); @@ -455,12 +449,10 @@ class Wint2xMmaMultistage : /// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching /// the global fragments needed by the first kStages-1 threadblock mainloop iterations - template CUTLASS_DEVICE void prologue( IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - TileDequanterB &tile_dequanter_B, int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining { // Issue several complete stages @@ -478,11 +470,9 @@ class Wint2xMmaMultistage : copy_tiles_and_advance_per_stage_B(iterator_B); // TODO: Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale. - //tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, - // column_wise_smem_ptr_B_, stage); // Move to the next write stage - advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B); + advance_smem_write_stage(iterator_A, iterator_B); // Defines the boundary of a stage of cp.async. 
cutlass::arch::cp_async_fence(); @@ -544,14 +534,12 @@ class Wint2xMmaMultistage : } /// Perform a threadblock mainloop iteration of matrix multiply-accumulate - template CUTLASS_DEVICE void mac_loop_iter( PipeState &pipe_state, ///< [in|out] loop-carried pipeline state FragmentC &accum, ///< [in|out] destination accumulator tile IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - TileDequanterB &tile_dequanter_B, ///< [in|out] tile dequantizer for B operand int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining int stage) { @@ -630,7 +618,7 @@ class Wint2xMmaMultistage : gmem_wait(); // Move to the next global fetch stage - advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B); + advance_smem_write_stage(iterator_A, iterator_B); advance_smem_read_stage(); // Disable global fetching when done with global fetch iterations @@ -654,14 +642,12 @@ class Wint2xMmaMultistage : /// Perform the specified number of threadblock mainloop iterations of matrix /// multiply-accumulate. Assumes prologue has been initiated. - template CUTLASS_DEVICE void gemm_iters( int gemm_k_iterations, ///< number of threadblock mainloop iterations FragmentC &accum, ///< [in|out] accumulator tile IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, - TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand in global memory + IteratorB &iterator_B) { PipeState pipe_state; @@ -701,7 +687,6 @@ class Wint2xMmaMultistage : accum, iterator_A, iterator_B, - tile_dequanter_B, gemm_k_iterations, stage); stage += 1; @@ -755,7 +740,6 @@ class Wint2xMmaMultistage : } /// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory. - template CUTLASS_DEVICE void operator()( ///< problem size of GEMM @@ -766,13 +750,11 @@ class Wint2xMmaMultistage : IteratorA iterator_A, ///< iterator over B operand in global memory IteratorB iterator_B, - ///< pre-load and dequantize B to shared memory - TileDequanterB tile_dequanter_B, ///< initial value of accumulator FragmentC const &src_accum) { // Prologue (start fetching iterations of global fragments into shared memory) - prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations); + prologue(iterator_A, iterator_B, gemm_k_iterations); // Wait until we have at least one completed global fetch stage gmem_wait(); @@ -781,7 +763,7 @@ class Wint2xMmaMultistage : accum = src_accum; // Perform the MAC-iterations - //gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B); + //gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B); } }; diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h deleted file mode 100644 index c44539fed1..0000000000 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "cutlass/gemm_coord.h" -#include "cutlass/trace.h" - -#include "cutlass_extensions/gemm/threadblock/wint2x_unzip.h" - -namespace cutlass { -namespace gemm { -namespace threadblock { - -template -struct TileDequanter { - using WeightQuantTraits = WintQuantTraits; - using MmaElementT = typename WeightQuantTraits::MmaWeightType; - using QuantArguments = typename WeightQuantTraits::Arguments; - - using UnzipAndDequantFunctor = - UnzipAndDequantFunctor; - - static constexpr bool kUseSharedMemory = true; - - static constexpr int kRows = Rows; - static constexpr int kColumns = Columns; - static constexpr int kStages = Stages; - - MmaElementT *out_smem_ptr{nullptr}; - - char *pointer{nullptr}; - int64_t ldm{0}; - cutlass::MatrixCoord tb_offset; - cutlass::MatrixCoord extent; - - ScaleElementT *super_scale_ptr{nullptr}; - cutlass::MatrixCoord tb_offset_scale; - - QuantArguments quant_args; - - int64_t block_start_rows[kStages]; - bool need_preload{true}; - UnzipAndDequantFunctor unzip_functor; - - CUTLASS_DEVICE - TileDequanter() {} - - CUTLASS_DEVICE - TileDequanter(MmaElementT *out_smem_ptr, char *pointer, int64_t ldm, - const cutlass::MatrixCoord &extent, - const cutlass::MatrixCoord &tb_offset, - ScaleElementT *super_scale_ptr, - const cutlass::MatrixCoord &tb_offset_scale, - const QuantArguments &quant_args) - : out_smem_ptr(out_smem_ptr), pointer(pointer), ldm(ldm), extent(extent), - tb_offset(tb_offset), super_scale_ptr(super_scale_ptr), - tb_offset_scale(tb_offset_scale), quant_args(quant_args) {} - - CUTLASS_DEVICE - MmaElementT *GetOutPtr() { return out_smem_ptr; } - - CUTLASS_DEVICE - void AddTileOffset(const cutlass::MatrixCoord &tile_offset) { - tb_offset.row() += tile_offset.row() * kRows; - tb_offset.column() += tile_offset.column() * kColumns; - tb_offset_scale.column() += tile_offset.column() * kColumns; - } - - CUTLASS_DEVICE - void Load(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) { - int zipped_row = WeightQuantTraits::CaclPackedDim(tb_offset.row()); - if (tb_offset.row() >= extent.row() || - tb_offset.column() >= extent.column()) { - return; - } - - block_start_rows[stage % kStages] = tb_offset.row(); - - using ZippedT = typename WeightQuantTraits::WeightType; - ZippedT *in_ptr = reinterpret_cast(pointer) + zipped_row * ldm + - tb_offset.column(); - ScaleElementT *scale_ptr = super_scale_ptr + tb_offset_scale.column(); - - if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) { - const uint8_t *local_scale_ptr = quant_args.local_scale_ptr + - (tb_offset.row() / 128) * ldm + - tb_offset_scale.column(); - const float *code_scale_ptr = - quant_args.code_scale_ptr + tb_offset_scale.column(); - const float *code_zp_ptr = - quant_args.code_zp_ptr + tb_offset_scale.column(); - - typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr); - unzip_functor.LoadAsync(in_ptr, local_scale_ptr, code_scale_ptr, code_zp_ptr, - scale_ptr, &args, ldm, need_preload); - need_preload = false; - } else { - // CUTLASS_TRACE_DEVICE("Not Supported!"); - } - } - - CUTLASS_DEVICE - void 
UnpackAndDequant(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) { - int64_t block_start_row = block_start_rows[stage % kStages]; - if (block_start_row >= extent.row()) { - return; - } - - if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) { - typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr); - unzip_functor.ComputeVectorized(args, out_smem_ptr, block_start_row); - } else { - // CUTLASS_TRACE_DEVICE("Not Supported!"); - } - } -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h index a328520322..ce20bcaaec 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h @@ -43,7 +43,6 @@ #include "cutlass/trace.h" #include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" -#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h" #include "cutlass_extensions/tile_interleaved_layout.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -844,9 +843,6 @@ struct Wint2xMoeFCGemm : public MoeFCGemm= 1, "B must be row major/col major OR col major interleaved."); - // LayoutB should be RowMajor - using TileDequanterB = cutlass::gemm::threadblock::TileDequanter; - // // Problem visitor. // @@ -916,30 +912,13 @@ struct Wint2xMoeFCGemm : public MoeFCGemm::value ? gemm_n : gemm_k * kInterleave; - //typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns; // the begin threadblock_offset of B, which holds the same column id with C cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave}; - //cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns}; - /*MmaElementB* smem_unzip_B_ptr = nullptr; - if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) { - smem_unzip_B_ptr = shared_storage.main_loop.operand_unzip_B_ptr(); - } - QuantArguments quant_args = get_quant_args(params, problem_idx, gemm_k, gemm_n); - TileDequanterB tile_dequanter_B(smem_unzip_B_ptr, - byte_ptr_B, - ldm_B, - extent_B, - tb_offset_B, - weight_scale_ptr, - tb_offset_scale, - quant_args); - MmaElementB* ptr_B = tile_dequanter_B.GetOutPtr();*/ - TileDequanterB tile_dequanter_B; ElementB* ptr_B = reinterpret_cast(byte_ptr_B); // Compute position within threadblock @@ -989,7 +968,6 @@ struct Wint2xMoeFCGemm : public MoeFCGemm Date: Tue, 8 Jul 2025 16:39:47 +0800 Subject: [PATCH 07/11] Implement FastInterleavedAndBiasedNumericArrayConverter for wint2. 
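Each quantized byte of B expands to four weight values: the byte is pushed
through an affine code transform (code_scale and code_zp are fixed to 1 and 0
inside this in-register specialization; the per-column values are applied by
the warp dequantizer), and four overlapping 6-bit fields at bit offsets 9, 6,
3 and 0 are then extracted and re-biased by 32. A minimal host-side sketch of
that per-byte decode, assuming only the constants visible in the converter
below:

    #include <cmath>
    #include <cstdint>

    // Decode one packed byte of B into four weight values, mirroring the
    // arithmetic of FastInterleavedAndBiasedNumericArrayConverter.
    inline void decode_wint2_byte(uint8_t b, float code_scale, float code_zp,
                                  float out[4]) {
      constexpr int32_t kWeightMask = 0x3F;  // 6-bit field per weight
      constexpr int32_t kBZP = 32;           // bias added at quantization time
      int32_t decoded = static_cast<int32_t>(
          std::floor(static_cast<float>(b) * code_scale + code_zp + 0.5f));
      const int shift_bits[4] = {9, 6, 3, 0};
      for (int j = 0; j < 4; ++j) {
        out[j] = static_cast<float>(((decoded >> shift_bits[j]) & kWeightMask) - kBZP);
      }
    }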
Change-Id: I438f2b18ab964a04ae1cdb09d9e7d9f7b95eafca --- .../gemm/threadblock/default_wint2x_mma.h | 6 +- .../gemm/threadblock/wint2x_mma_base.h | 38 ++-- .../gemm/threadblock/wint2x_mma_multistage.h | 169 ++++++++++++++++-- .../interleaved_numeric_conversion.h | 108 ++++++++++- .../moe_gemm/fused_moe_cutlass_kernel.h | 9 +- 5 files changed, 285 insertions(+), 45 deletions(-) diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h index a67c8aa256..72c22a175f 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h @@ -19,6 +19,7 @@ #include "cutlass_extensions/arch/mma.h" #include "cutlass_extensions/interleaved_numeric_conversion.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" #include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h" namespace cutlass { @@ -156,13 +157,16 @@ struct DefaultWint2xMma; + using TransformBAfterLDS = FastInterleavedAndBiasedNumericArrayConverter< + ElementA, ElementB, MmaCore::MmaPolicy::Operator::FragmentB::kElements>; + // Define the threadblock-scoped multistage matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage< typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>; + typename MmaCore::MmaPolicy, kStages, TransformBAfterLDS, SharedMemoryClear>; }; } // namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h index 7dec56be29..cdb465c60c 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h @@ -93,6 +93,15 @@ class Wint2xMmaBase { static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + /// Number of warp-level GEMM oeprations per load for B + static constexpr int kWarpGemmIterationsPerLoadForB = + Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; + static_assert(!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB), ""); + + static constexpr int kWarpLoadIterationsForB = + kWarpGemmIterations / kWarpGemmIterationsPerLoadForB; + + /// Number of stages static int const kStages = Stages; @@ -131,16 +140,16 @@ class Wint2xMmaBase { using ShapeB = MatrixShape; - // w uint8; local_scale uint8; - constexpr static int kZippedRowsPerStages = Shape::kK / 4 + (Shape::kK + 127) / 128; + // local_scale uint4 + constexpr static int kGroupWiseParamRows = Shape::kK / 64; + + using GroupWiseParamShapeB = MatrixShape; // code_scale float; code_zp float; super_scale ElementB - constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) + + constexpr static int kColumnWiseParamRows = 2 * sizeof(float) + sizeof_bits::value / 8; - using ZippedShapeB = MatrixShape; - - using NopaddingShapeB = MatrixShape; + using ColumnWiseParamShapeB = MatrixShape; public: // @@ -153,12 +162,11 @@ class Wint2xMmaBase { /// Buffer for B operand AlignedBuffer operand_B; - /// Buffer for quanted B operand - AlignedBuffer operand_zipped_B; + /// Buffer for local_scale of B operand + AlignedBuffer operand_local_scale_B; - 
/// Buffer for unzip B operand - AlignedBuffer - operand_unzip_B; + /// Buffer for column-wise params of B operand + AlignedBuffer operand_column_wise_B; public: // @@ -188,14 +196,6 @@ class Wint2xMmaBase { TensorRefB operand_B_ref() { return TensorRefB{operand_B.data(), LayoutB()}; } - - CUTLASS_HOST_DEVICE - uint8_t *operand_zipped_B_ptr() { return operand_zipped_B.data(); } - - CUTLASS_HOST_DEVICE - typename Operator::ElementB *operand_unzip_B_ptr() { - return operand_unzip_B.data(); - } }; protected: diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h index 1038f4220b..ca63f8c1d6 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h @@ -86,10 +86,10 @@ template < typename Policy_, /// Number of stages, int Stages, + /// Transform for input B applied in register after the LDS + typename TransformBAfterLDS_, /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, - /// Used for partial specialization - typename Enable = bool> + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone> class Wint2xMmaMultistage : public Wint2xMmaBase { public: @@ -107,8 +107,10 @@ class Wint2xMmaMultistage : using LayoutC = LayoutC_; ///< Policy describing tuning details using Policy = Policy_; + /// Transform for input B applied in register after the LDS + using TransformBAfterLDS = TransformBAfterLDS_; - using ZippedShapeB = typename Base::SharedStorage::ZippedShapeB; + static constexpr int kInterleave = IteratorB::Shape::kRow / Shape::kK; using SmemIteratorA = SmemIteratorA_; using SmemIteratorB = SmemIteratorB_; @@ -131,12 +133,11 @@ class Wint2xMmaMultistage : using LayoutScale = cutlass::layout::ColumnMajor; using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; - using ElementB = typename WarpTransformedFragmentB::Element; using Dequantizer = warp::MmaTensorOpWin2xDequantizer; @@ -199,6 +200,14 @@ class Wint2xMmaMultistage : WarpTransformedFragmentB warp_transformed_frag_B_[2]; }; + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool IsTileInterleaveLayout = + layout::IsColumnMajorTileInterleave::value; + static_assert(!IsTileInterleaveLayout || (IsTileInterleaveLayout && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); private: @@ -224,10 +233,11 @@ class Wint2xMmaMultistage : /// Shared memory read stage index int smem_read_stage_idx_; - uint8_t* column_wise_smem_ptr_B_; + /// Transform for B in register + TransformBAfterLDS transform_B_; - uint8_t* smem_zipped_ptr_B_; - int smem_zipped_bytes_per_stage_B_; + uint8_t* smem_ptr_B_; + uint8_t* ptr_B_; public: @@ -261,16 +271,31 @@ class Wint2xMmaMultistage : int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + CUTLASS_TRACE_DEVICE(" Shape: {%d, %d, %d}, IteratorB::Shape: {%d, %d}, kInterleave: %d", + Shape::kM, Shape::kN, Shape::kK, IteratorB::Shape::kRow, IteratorB::Shape::kColumn, kInterleave); + CUTLASS_TRACE_DEVICE(" kPartitionsK=%d, kWarpGemmIterations=%d, WarpCount={%d, %d}, warp_idx_m=%d, warp_idx_n=%d, warp_idx_k=%d", + Policy::kPartitionsK, Base::kWarpGemmIterations, + 
Base::WarpCount::kM, Base::WarpCount::kN, warp_idx_m, warp_idx_n, warp_idx_k); + // Add per-warp offsets in units of warp-level tiles this->warp_tile_iterator_A_.add_tile_offset( {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); this->warp_tile_iterator_B_.add_tile_offset( {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr(); - - smem_zipped_ptr_B_ = column_wise_smem_ptr_B_ + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn; - smem_zipped_bytes_per_stage_B_ = Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn; + CUTLASS_TRACE_DEVICE(" Policy::SmemPaddingA: {%d, %d}; Policy::SmemPaddingB: {%d, %d}", + Policy::SmemPaddingA::kRow, Policy::SmemPaddingA::kColumn, Policy::SmemPaddingB::kRow, Policy::SmemPaddingB::kColumn); + CUTLASS_TRACE_DEVICE(" operand_A_ptr=%p, kRow=%d, kColumn=%d", + shared_storage.operand_A.data(), static_cast(Base::SharedStorage::ShapeA::kRow), + static_cast(Base::SharedStorage::ShapeA::kColumn)); + CUTLASS_TRACE_DEVICE(" operand_B_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageB=%d, kAccessesPerVector=%d", + shared_storage.operand_B.data(), + static_cast(Base::SharedStorage::ShapeB::kRow), static_cast(Base::SharedStorage::ShapeB::kColumn), + static_cast(sizeof(shared_storage.operand_B)), + static_cast(IteratorB::ThreadMap::kElementsPerAccess), static_cast(sizeof(typename IteratorB::AccessType)), + static_cast(Detail::AsyncCopyIterationsPerStageB), static_cast(IteratorB::kAccessesPerVector)); + + smem_ptr_B_ = reinterpret_cast(shared_storage.operand_B.data()); } /// Advance shared memory read-iterators to the next stage @@ -371,6 +396,13 @@ class Wint2xMmaMultistage : for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { auto gmem_ptr = iterator_B.get(); + if (group_start_B == 0 && j == 0 && v == 0) { + CUTLASS_TRACE_DEVICE(" dst_ptr=%p, iterator_B.get()=%p, kAccessesPerGroupB=%d, kAccessesPerVector=%d, sizeof(AccessType)=%d", + reinterpret_cast(dst_ptr), reinterpret_cast(gmem_ptr), + static_cast(Detail::kAccessesPerGroupB), static_cast(IteratorB::kAccessesPerVector), + static_cast(sizeof(typename IteratorB::Element))); + } + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { cutlass::arch::copy_zfill( dst_ptr + v, gmem_ptr, iterator_B.valid()); @@ -423,7 +455,7 @@ class Wint2xMmaMultistage : template CUTLASS_DEVICE - void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B) { + void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B, int stage) { iterator_B.set_iteration_index(0); this->smem_iterator_B_.set_iteration_index(0); @@ -443,6 +475,31 @@ class Wint2xMmaMultistage : IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + if (v == 0) { + int gmem_offset = reinterpret_cast(gmem_ptr) - reinterpret_cast(ptr_B_); + int gmem_k = 8192 * kInterleave / 4; + int gmem_n = 1792 / kInterleave; + int gmem_row = gmem_offset / gmem_k; + int gmem_col = gmem_offset % gmem_k; + + int smem_offset = reinterpret_cast(dst_ptr) - reinterpret_cast(smem_ptr_B_); + int smem_k = Shape::kK * kInterleave / 4; + int smem_n = Shape::kN / kInterleave; + int smem_row = smem_offset / smem_k; + int smem_col = smem_offset % smem_k; + + uint8_t* gmem_uint8_ptr = reinterpret_cast(gmem_ptr); + + CUTLASS_TRACE_DEVICE(" [stage=%d] gmem_ptr=%p, smem_ptr=%p, bytes=%d; gmem: %dx%d, {%d, %d}, [%d, %d, %d, %d, %d, %d, %d, %d]; smem: {%d, %d};", + stage, reinterpret_cast(gmem_ptr), 
reinterpret_cast(dst_ptr), kSrcBytes, + gmem_n, gmem_k, gmem_row, gmem_col, + static_cast(gmem_uint8_ptr[0]), static_cast(gmem_uint8_ptr[1]), + static_cast(gmem_uint8_ptr[2]), static_cast(gmem_uint8_ptr[3]), + static_cast(gmem_uint8_ptr[4]), static_cast(gmem_uint8_ptr[5]), + static_cast(gmem_uint8_ptr[6]), static_cast(gmem_uint8_ptr[7]), + smem_row, smem_col); + } + if (InitStage) { cutlass::arch::copy_zfill( dst_ptr + v, iterator_B.get(), iterator_B.valid()); @@ -484,7 +541,7 @@ class Wint2xMmaMultistage : copy_tiles_and_advance_per_stage_A(iterator_A); // Async copy zipped B to shared memory. - copy_tiles_and_advance_per_stage_B(iterator_B); + copy_tiles_and_advance_per_stage_B(iterator_B, stage); // TODO: Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale. @@ -666,6 +723,18 @@ class Wint2xMmaMultistage : IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory IteratorB &iterator_B) { +#if 0 + int smem_k = Shape::kK * kInterleave / 4; + int smem_n = Shape::kN / kInterleave; + for (int i = 0; i < 3 * smem_n; ++i) { + for (int j = 0; j < smem_k; ++j) { + if (i % 3 == 0) { + CUTLASS_TRACE_DEVICE(" [i=%d, j=%d, %dx%d] %d", i, j, smem_n, smem_k, static_cast(smem_ptr_B_[i * smem_k + j])); + } + } + } +#endif + PipeState pipe_state; // Disable global fetching if done with global fetch iterations @@ -682,6 +751,70 @@ class Wint2xMmaMultistage : this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]); ++this->warp_tile_iterator_B_; + if (PipeState::WarpLoadedFragmentA::kElements == 8) { + ElementA* warp_frag_A_ptr = reinterpret_cast(pipe_state.warp_loaded_frag_A_[0].data()); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_A_=[%f, %f, %f, %f, %f, %f, %f, %f], %d bytes", + static_cast(warp_frag_A_ptr[0]), static_cast(warp_frag_A_ptr[1]), + static_cast(warp_frag_A_ptr[2]), static_cast(warp_frag_A_ptr[3]), + static_cast(warp_frag_A_ptr[4]), static_cast(warp_frag_A_ptr[5]), + static_cast(warp_frag_A_ptr[6]), static_cast(warp_frag_A_ptr[7]), + sizeof_bits::value / 8); + } + if (PipeState::WarpLoadedFragmentB::kElements == 64) { + uint8_t* reg_uint8_ptr = reinterpret_cast(pipe_state.warp_loaded_frag_B_[0].data()); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d], %d bytes", + static_cast(reg_uint8_ptr[0]), static_cast(reg_uint8_ptr[1]), + static_cast(reg_uint8_ptr[2]), static_cast(reg_uint8_ptr[3]), + static_cast(reg_uint8_ptr[4]), static_cast(reg_uint8_ptr[5]), + static_cast(reg_uint8_ptr[6]), static_cast(reg_uint8_ptr[7]), + static_cast(reg_uint8_ptr[8]), static_cast(reg_uint8_ptr[9]), + static_cast(reg_uint8_ptr[10]), static_cast(reg_uint8_ptr[11]), + static_cast(reg_uint8_ptr[12]), static_cast(reg_uint8_ptr[13]), + static_cast(reg_uint8_ptr[14]), static_cast(reg_uint8_ptr[15]), + sizeof_bits::value / 8); + } + + typename TransformBAfterLDS::result_type unpacked_frag_B = transform_B_(pipe_state.warp_loaded_frag_B_[0]); + if (TransformBAfterLDS::result_type::kElements == 64) { + CUTLASS_TRACE_DEVICE(" TransformBAfterLDS::result_type::kElements: 64, %d bytes", sizeof_bits::value / 8); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(unpacked_frag_B[0]), static_cast(unpacked_frag_B[1]), + static_cast(unpacked_frag_B[2]), static_cast(unpacked_frag_B[3]), + static_cast(unpacked_frag_B[4]), static_cast(unpacked_frag_B[5]), + static_cast(unpacked_frag_B[6]), static_cast(unpacked_frag_B[7]), + 
static_cast(unpacked_frag_B[8]), static_cast(unpacked_frag_B[9]), + static_cast(unpacked_frag_B[10]), static_cast(unpacked_frag_B[11]), + static_cast(unpacked_frag_B[12]), static_cast(unpacked_frag_B[13]), + static_cast(unpacked_frag_B[14]), static_cast(unpacked_frag_B[15])); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[16:31]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(unpacked_frag_B[16]), static_cast(unpacked_frag_B[17]), + static_cast(unpacked_frag_B[18]), static_cast(unpacked_frag_B[19]), + static_cast(unpacked_frag_B[20]), static_cast(unpacked_frag_B[21]), + static_cast(unpacked_frag_B[22]), static_cast(unpacked_frag_B[23]), + static_cast(unpacked_frag_B[24]), static_cast(unpacked_frag_B[25]), + static_cast(unpacked_frag_B[26]), static_cast(unpacked_frag_B[27]), + static_cast(unpacked_frag_B[28]), static_cast(unpacked_frag_B[29]), + static_cast(unpacked_frag_B[30]), static_cast(unpacked_frag_B[31])); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[32:47]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(unpacked_frag_B[32]), static_cast(unpacked_frag_B[33]), + static_cast(unpacked_frag_B[34]), static_cast(unpacked_frag_B[35]), + static_cast(unpacked_frag_B[36]), static_cast(unpacked_frag_B[37]), + static_cast(unpacked_frag_B[38]), static_cast(unpacked_frag_B[39]), + static_cast(unpacked_frag_B[40]), static_cast(unpacked_frag_B[41]), + static_cast(unpacked_frag_B[42]), static_cast(unpacked_frag_B[43]), + static_cast(unpacked_frag_B[44]), static_cast(unpacked_frag_B[45]), + static_cast(unpacked_frag_B[46]), static_cast(unpacked_frag_B[47])); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[48:63]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(unpacked_frag_B[48]), static_cast(unpacked_frag_B[49]), + static_cast(unpacked_frag_B[50]), static_cast(unpacked_frag_B[51]), + static_cast(unpacked_frag_B[52]), static_cast(unpacked_frag_B[53]), + static_cast(unpacked_frag_B[54]), static_cast(unpacked_frag_B[55]), + static_cast(unpacked_frag_B[56]), static_cast(unpacked_frag_B[57]), + static_cast(unpacked_frag_B[58]), static_cast(unpacked_frag_B[59]), + static_cast(unpacked_frag_B[60]), static_cast(unpacked_frag_B[61]), + static_cast(unpacked_frag_B[62]), static_cast(unpacked_frag_B[63])); + } + typename Dequantizer::FragmentLocalScale warp_frag_local_scale; typename Dequantizer::FragmentCodeScale warp_frag_code_scale; typename Dequantizer::FragmentCodeZp warp_frag_code_zp; @@ -702,6 +835,7 @@ class Wint2xMmaMultistage : warp_frag_code_zp, warp_frag_super_scale); +#if 0 // Transform, if necessary, the first warp-tile's shared memory fragments warp_mma_.transform( pipe_state.warp_transformed_frag_A_[0], @@ -713,7 +847,6 @@ class Wint2xMmaMultistage : pipe_state.tmp_accum_.clear(); } -#if 0 int stage = Base::kStages - 1; // Mainloop @@ -790,6 +923,8 @@ class Wint2xMmaMultistage : ///< initial value of accumulator FragmentC const &src_accum) { + ptr_B_ = reinterpret_cast(iterator_B.get_origin_pointer()); + // Prologue (start fetching iterations of global fragments into shared memory) prologue(iterator_A, iterator_B, gemm_k_iterations); @@ -800,7 +935,7 @@ class Wint2xMmaMultistage : accum = src_accum; // Perform the MAC-iterations - //gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B); + gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B); } }; diff --git a/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h 
b/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h index 44ba79680e..1f5584862b 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h +++ b/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h @@ -39,18 +39,16 @@ #include "cutlass/array.h" #include "cutlass/half.h" #include "cutlass/numeric_types.h" +#include "cutlass/trace.h" -namespace cutlass -{ +namespace cutlass { // This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low // bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally // signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. // This converter will uninterleave the data and subtract the bias while converting to the result type. template -struct FastInterleavedAndBiasedNumericArrayConverter -{ -}; +struct FastInterleavedAndBiasedNumericArrayConverter; template <> struct FastInterleavedAndBiasedNumericArrayConverter @@ -440,6 +438,106 @@ struct FastInterleavedAndBiasedNumericArrayConverter } }; +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static_assert(platform::is_same::value || platform::is_same::value, + "T must be fp16 or bf16"); + + using result_type = Array; + using source_type = Array; + + using ScaleComputeT = T; + + static constexpr int32_t kWeightMask = 0x3F; + static constexpr int32_t kBZP = 32; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; + uint8_t const* in_ptr = reinterpret_cast(&source); + + ScaleComputeT code_scale = static_cast(1); + ScaleComputeT code_zp = static_cast(0); + ScaleComputeT floor_offset = static_cast(0.5); + + CUTLASS_TRACE_DEVICE(" source: [%d, %d, %d, %d]", + static_cast(in_ptr[0]), static_cast(in_ptr[1]), + static_cast(in_ptr[2]), static_cast(in_ptr[3])); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + int32_t decode_value = + static_cast(floor(static_cast(in_ptr[i]) * code_scale + code_zp + floor_offset)); + + ScaleComputeT value_3 = static_cast((decode_value & kWeightMask) - kBZP); + decode_value >>= 3; + ScaleComputeT value_2 = static_cast((decode_value & kWeightMask) - kBZP); + decode_value >>= 3; + ScaleComputeT value_1 = static_cast((decode_value & kWeightMask) - kBZP); + decode_value >>= 3; + ScaleComputeT value_0 = static_cast((decode_value & kWeightMask) - kBZP); + + result[0] = static_cast(value_0); + result[1] = static_cast(value_1); + result[2] = static_cast(value_2); + result[3] = static_cast(value_3); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static_assert(platform::is_same::value || platform::is_same::value, + "T must be fp16 or bf16"); + + static constexpr int kVecWidth = 16; + static_assert(!(N % kVecWidth), "N must be multiple of 16."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = 
reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / kVecWidth; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h index ce20bcaaec..0d9aa62b3f 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h @@ -861,13 +861,13 @@ struct Wint2xMoeFCGemm : public MoeFCGemm(problem_size.m()), static_cast(problem_size.n()), static_cast(problem_size.k())); - if (problem_idx > 2) { break; } + CUTLASS_TRACE_DEVICE(" problem_idx: %d, cta_idx: %d, problem_size: {%d, %d, %d}", + problem_idx, cta_idx, static_cast(problem_size.m()), static_cast(problem_size.n()), static_cast(problem_size.k())); + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); // threadblock_offset of C @@ -919,6 +919,9 @@ struct Wint2xMoeFCGemm : public MoeFCGemm(ldm_B), tb_offset_B.row(), tb_offset_B.column(), extent_B.row(), extent_B.column()); + ElementB* ptr_B = reinterpret_cast(byte_ptr_B); // Compute position within threadblock From e86c13de135ed9df6af6b79275504f55e1268762 Mon Sep 17 00:00:00 2001 From: liuyiqun01 Date: Thu, 10 Jul 2025 10:30:00 +0800 Subject: [PATCH 08/11] Implement Wint2ParamsAccessor to load extra quant params from global memory. Change-Id: Ic3750cd9b767df8893501820880c3342a4b47233 --- .../gemm/threadblock/default_wint2x_mma.h | 87 +- .../gemm/threadblock/wint2x_mma_base.h | 24 +- .../gemm/threadblock/wint2x_mma_multistage.h | 350 ++++---- .../gemm/threadblock/wint2x_params_accessor.h | 326 ++++++++ .../warp/mma_tensorop_wint2x_dequantizer.h | 747 +++++------------- .../interleaved_numeric_conversion.h | 2 + .../cutlass_extensions/wint_type_traits.h | 9 +- .../moe_gemm/fused_moe_cutlass_kernel.h | 83 +- custom_ops/gpu_ops/moe/moe_ffn_wint2.cu | 13 +- 9 files changed, 895 insertions(+), 746 deletions(-) create mode 100644 custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h index 72c22a175f..8f4b1efefb 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h @@ -18,9 +18,9 @@ #pragma once #include "cutlass_extensions/arch/mma.h" -#include "cutlass_extensions/interleaved_numeric_conversion.h" #include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" #include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h" +#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h" namespace cutlass { namespace gemm { @@ -28,6 +28,55 @@ namespace threadblock { //////////////////////////////////////////////////////////////////////////////// +template +struct DefaultQuantParamsIterators { +private: + static constexpr int kAlignment = 128 / sizeof_bits::value; + static_assert((ThreadblockShape::kN % kAlignment) == 0, ""); + + static constexpr int kRows = + (GroupSize == -1) ? 
1 : (ThreadblockShape::kK + GroupSize - 1) / GroupSize; + static constexpr int kColumns = ThreadblockShape::kN; + + using IteratorThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kColumns / kAlignment, kAlignment>; + +public: + using Iterator = cutlass::transform::threadblock::PredicatedTileIterator< + MatrixShape, ElementT, layout::RowMajor, 0, + IteratorThreadMap, kAlignment>; + using SmemIterator = Iterator; + + //using AccessType = cutlass::Array; + //using Iterator = cutlass::transform::threadblock::PredicatedTileAccessIterator< + // MatrixShape, ElementT, layout::RowMajor, + // 0, IteratorThreadMap, AccessType>; +}; + +template +struct DefaultQuantParamsIterators { +private: + static constexpr int kAlignment = 128 / sizeof_bits::value; + static_assert((ThreadblockShape::kN % kAlignment) == 0, ""); + + static constexpr int kRows = + (GroupSize == -1) ? 1 : (ThreadblockShape::kK + 2 * GroupSize - 1) / (2 * GroupSize); + static constexpr int kColumns = + (GroupSize == -1) ? ThreadblockShape::kN : ThreadblockShape::kN * 2; + + using IteratorThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kColumns / kAlignment, kAlignment>; + +public: + using Iterator = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, uint4b_t, + layout::RowMajor, 0, IteratorThreadMap, kAlignment>; + using SmemIterator = Iterator; +}; + template < /// Element type for A matrix operand typename ElementA_, @@ -100,7 +149,7 @@ struct DefaultWint2xMma { - +public: static_assert(platform::is_same::value || platform::is_same::value, "Element A must be fp16 or bf16"); @@ -110,6 +159,12 @@ struct DefaultWint2xMma::value, "Mma multistage must dequantize after ldsm"); + using ElementSuperScale = ElementA; + using ElementLocalScale = uint4b_t; + using ElementCodeScaleZp = float; + + static constexpr int kGroupSize = 64; + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) ? 
cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always; @@ -157,16 +212,36 @@ struct DefaultWint2xMma; - using TransformBAfterLDS = FastInterleavedAndBiasedNumericArrayConverter< - ElementA, ElementB, MmaCore::MmaPolicy::Operator::FragmentB::kElements>; +private: + // Define iterators over tiles from extra quant params for B operand + using IteratorSuperScale = typename DefaultQuantParamsIterators< + ThreadblockShape, ElementSuperScale, -1>::Iterator; + using SmemIteratorSuperScale = typename DefaultQuantParamsIterators< + ThreadblockShape, ElementSuperScale, -1>::SmemIterator; + + using IteratorLocalScale = typename DefaultQuantParamsIterators< + ThreadblockShape, ElementLocalScale, kGroupSize>::Iterator; + using SmemIteratorLocalScale = typename DefaultQuantParamsIterators< + ThreadblockShape, ElementLocalScale, kGroupSize>::SmemIterator; + + using IteratorCodeScaleZp = typename DefaultQuantParamsIterators< + ThreadblockShape, ElementCodeScaleZp, -1>::Iterator; + using SmemIteratorCodeScaleZp = typename DefaultQuantParamsIterators< + ThreadblockShape, ElementCodeScaleZp, -1>::Iterator; + +public: + using QuantParamsAccessor = Wint2ParamsAccessor< + ElementA, ThreadblockShape, IteratorSuperScale, SmemIteratorSuperScale, + IteratorLocalScale, SmemIteratorLocalScale, + IteratorCodeScaleZp, SmemIteratorCodeScaleZp, kStages, kGroupSize>; // Define the threadblock-scoped multistage matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage< typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB, - ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, kStages, TransformBAfterLDS, SharedMemoryClear>; + ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, + kStages, QuantParamsAccessor, SharedMemoryClear>; }; } // namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h index cdb465c60c..4b7d3ac06e 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h @@ -63,8 +63,8 @@ template < typename Policy_, /// Number of stages, int Stages, - /// Used for partial specialization - typename Enable = bool> + /// Size of extra quantized params + typename QuantParamsShape> class Wint2xMmaBase { public: ///< Size of the Gemm problem - concept: gemm::GemmShape<> @@ -101,7 +101,6 @@ class Wint2xMmaBase { static constexpr int kWarpLoadIterationsForB = kWarpGemmIterations / kWarpGemmIterationsPerLoadForB; - /// Number of stages static int const kStages = Stages; @@ -140,16 +139,8 @@ class Wint2xMmaBase { using ShapeB = MatrixShape; - // local_scale uint4 - constexpr static int kGroupWiseParamRows = Shape::kK / 64; - - using GroupWiseParamShapeB = MatrixShape; - - // code_scale float; code_zp float; super_scale ElementB - constexpr static int kColumnWiseParamRows = 2 * sizeof(float) + - sizeof_bits::value / 8; - - using ColumnWiseParamShapeB = MatrixShape; + /// Shape of all quant params in shared memory + using QuantParamsShapeB = QuantParamsShape; public: // @@ -162,11 +153,8 @@ class Wint2xMmaBase { /// Buffer for B operand AlignedBuffer operand_B; - /// Buffer for local_scale of B operand - AlignedBuffer operand_local_scale_B; - - /// Buffer for column-wise params of B operand - AlignedBuffer 
operand_column_wise_B; + /// Buffer for extra quant params of B operand + AlignedBuffer operand_quant_params_B; public: // diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h index ca63f8c1d6..a006174811 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h @@ -46,6 +46,7 @@ #include "cutlass_extensions/arch/memory_copy_sm80.h" #include "cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h" #include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h" +#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -86,15 +87,15 @@ template < typename Policy_, /// Number of stages, int Stages, - /// Transform for input B applied in register after the LDS - typename TransformBAfterLDS_, + /// Accessor for extra quantized params + typename QuantParamsAccessor_, /// Use zfill or predicate for out-of-bound cp.async SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone> class Wint2xMmaMultistage : - public Wint2xMmaBase { + public Wint2xMmaBase { public: ///< Base class - using Base = Wint2xMmaBase; + using Base = Wint2xMmaBase; ///< Size of the Gemm problem - concept: gemm::GemmShape<> using Shape = Shape_; ///< Iterates over tiles of A operand in global memory @@ -107,8 +108,9 @@ class Wint2xMmaMultistage : using LayoutC = LayoutC_; ///< Policy describing tuning details using Policy = Policy_; - /// Transform for input B applied in register after the LDS - using TransformBAfterLDS = TransformBAfterLDS_; + /// Accessor for extra quantized params + using QuantParamsAccessor = QuantParamsAccessor_; + using QuantArguments = typename QuantParamsAccessor::Arguments; static constexpr int kInterleave = IteratorB::Shape::kRow / Shape::kK; @@ -131,17 +133,17 @@ class Wint2xMmaMultistage : /// Minimum architecture is Sm80 to support cp.async using ArchTag = arch::Sm80; - using LayoutScale = cutlass::layout::ColumnMajor; + //using LayoutScale = typename QuantParamsAccessor::IteratorSuperScale::Layout; + using LayoutScale = layout::RowMajor; using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; - using Dequantizer = + using WarpDequantizer = warp::MmaTensorOpWin2xDequantizer; - static_assert(sizeof(Dequantizer) > 0, "Dequantizer template instantiation failed"); + LayoutScale, + 32>; + static_assert(sizeof(WarpDequantizer) > 0, "WarpDequantizer template instantiation failed"); /// Complex transform on A operand static ComplexTransform const kTransformA = Operator::kTransformA; @@ -188,16 +190,27 @@ class Wint2xMmaMultistage : using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + using FragmentSuperScale = typename WarpDequantizer::FragmentSuperScale; + using FragmentCodeScaleZp = typename WarpDequantizer::FragmentCodeScaleZp; + using FragmentLocalScale = typename WarpDequantizer::FragmentLocalScale; + + /// channel-wise quant params + FragmentCodeScaleZp warp_frag_code_scale_; + FragmentCodeScaleZp warp_frag_code_zp_; + FragmentSuperScale warp_frag_super_scale_; + + /// group-wise quant params + FragmentLocalScale warp_frag_local_scale_; + /// Temporary accumulator to facilitate staged-accumulation FragmentC 
tmp_accum_; /// Pair of A fragments used to overlap shared memory loads and math instructions - WarpLoadedFragmentA warp_loaded_frag_A_[2]; - WarpTransformedFragmentA warp_transformed_frag_A_[2]; + WarpTransformedFragmentA warp_frag_A_[2]; /// Pair of B fragments used to overlap shared memory loads and math instructions - WarpLoadedFragmentB warp_loaded_frag_B_[2]; - WarpTransformedFragmentB warp_transformed_frag_B_[2]; + WarpLoadedFragmentB warp_loaded_frag_B_; + WarpTransformedFragmentB warp_frag_B_; }; using ElementA = typename IteratorA::Element; @@ -218,23 +231,26 @@ class Wint2xMmaMultistage : /// Warp-level MMA operator Operator warp_mma_; - // Wint2 unzip operator - Dequantizer warp_dequantizer_; - /// Iterator to write threadblock-scoped tile of A operand to shared memory SmemIteratorA smem_iterator_A_; /// Iterator to write threadblock-scoped tile of B operand to shared memory SmemIteratorB smem_iterator_B_; + /// Accessor for extra quant params for B + QuantParamsAccessor quant_params_accessor_B_; + + // Wint2 unzip operator + WarpDequantizer warp_dequantizer_; + /// Shared memory write stage index int smem_write_stage_idx_; /// Shared memory read stage index int smem_read_stage_idx_; - /// Transform for B in register - TransformBAfterLDS transform_B_; + ElementA* smem_ptr_A_; + ElementA* ptr_A_; uint8_t* smem_ptr_B_; uint8_t* ptr_B_; @@ -252,10 +268,15 @@ class Wint2xMmaMultistage : int warp_idx, ///< ID of each thread within a warp int lane_idx - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), + ) : Base(shared_storage, thread_idx, warp_idx, lane_idx), smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + quant_params_accessor_B_(shared_storage.operand_quant_params_B.data(), thread_idx, warp_idx, lane_idx), + warp_dequantizer_(quant_params_accessor_B_.super_scale_ref(), + quant_params_accessor_B_.local_scale_ref(), + quant_params_accessor_B_.code_scale_ref(), + quant_params_accessor_B_.code_zp_ref(), + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx), smem_write_stage_idx_(0), smem_read_stage_idx_(0) { @@ -285,16 +306,20 @@ class Wint2xMmaMultistage : CUTLASS_TRACE_DEVICE(" Policy::SmemPaddingA: {%d, %d}; Policy::SmemPaddingB: {%d, %d}", Policy::SmemPaddingA::kRow, Policy::SmemPaddingA::kColumn, Policy::SmemPaddingB::kRow, Policy::SmemPaddingB::kColumn); - CUTLASS_TRACE_DEVICE(" operand_A_ptr=%p, kRow=%d, kColumn=%d", - shared_storage.operand_A.data(), static_cast(Base::SharedStorage::ShapeA::kRow), - static_cast(Base::SharedStorage::ShapeA::kColumn)); - CUTLASS_TRACE_DEVICE(" operand_B_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageB=%d, kAccessesPerVector=%d", + CUTLASS_TRACE_DEVICE(" operand_A_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageA=%d, kAccessesPerVectorA=%d", + shared_storage.operand_A.data(), + static_cast(Base::SharedStorage::ShapeA::kRow), static_cast(Base::SharedStorage::ShapeA::kColumn), + static_cast(sizeof(shared_storage.operand_A)), + static_cast(IteratorA::ThreadMap::kElementsPerAccess), static_cast(sizeof(typename IteratorA::AccessType)), + static_cast(Detail::AsyncCopyIterationsPerStageA), static_cast(IteratorA::kAccessesPerVector)); + CUTLASS_TRACE_DEVICE(" operand_B_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageB=%d, 
kAccessesPerVectorA=%d", shared_storage.operand_B.data(), static_cast(Base::SharedStorage::ShapeB::kRow), static_cast(Base::SharedStorage::ShapeB::kColumn), static_cast(sizeof(shared_storage.operand_B)), static_cast(IteratorB::ThreadMap::kElementsPerAccess), static_cast(sizeof(typename IteratorB::AccessType)), static_cast(Detail::AsyncCopyIterationsPerStageB), static_cast(IteratorB::kAccessesPerVector)); + smem_ptr_A_ = reinterpret_cast(shared_storage.operand_A.data()); smem_ptr_B_ = reinterpret_cast(shared_storage.operand_B.data()); } @@ -373,7 +398,6 @@ class Wint2xMmaMultistage : } } - template CUTLASS_DEVICE void copy_tiles_and_advance_B(IteratorB &iterator_B, int group_start_B = 0) { iterator_B.set_iteration_index(group_start_B * @@ -396,18 +420,20 @@ class Wint2xMmaMultistage : for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { auto gmem_ptr = iterator_B.get(); +#if 0 if (group_start_B == 0 && j == 0 && v == 0) { CUTLASS_TRACE_DEVICE(" dst_ptr=%p, iterator_B.get()=%p, kAccessesPerGroupB=%d, kAccessesPerVector=%d, sizeof(AccessType)=%d", reinterpret_cast(dst_ptr), reinterpret_cast(gmem_ptr), static_cast(Detail::kAccessesPerGroupB), static_cast(IteratorB::kAccessesPerVector), static_cast(sizeof(typename IteratorB::Element))); } +#endif if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::copy_zfill( + cutlass::arch::cp_async_zfill( dst_ptr + v, gmem_ptr, iterator_B.valid()); } else { - cutlass::arch::copy( + cutlass::arch::cp_async( dst_ptr + v, gmem_ptr, iterator_B.valid()); } @@ -421,7 +447,7 @@ class Wint2xMmaMultistage : } CUTLASS_DEVICE - void copy_tiles_and_advance_per_stage_A(IteratorA &iterator_A) { + void copy_tiles_and_advance_per_stage_A(IteratorA &iterator_A, int stage) { iterator_A.set_iteration_index(0); this->smem_iterator_A_.set_iteration_index(0); @@ -443,6 +469,32 @@ class Wint2xMmaMultistage : int src_bytes = (iterator_A.valid() ? 
kSrcBytes : 0); +#if 0 + if (v == 0) { + int gmem_offset = reinterpret_cast(gmem_ptr) - reinterpret_cast(ptr_A_); + int gmem_k = 8192; + int gmem_m = 16; + int gmem_row = gmem_offset / gmem_k; + int gmem_col = gmem_offset % gmem_k; + + int smem_offset = reinterpret_cast(dst_ptr) - reinterpret_cast(smem_ptr_A_); + int smem_k = Shape::kK; + int smem_m = Shape::kM; + int smem_row = smem_offset / smem_k; + int smem_col = smem_offset % smem_k; + + ElementA* gmem_element_A_ptr = reinterpret_cast(gmem_ptr); + CUTLASS_TRACE_DEVICE(" [stage=%d] gmem_ptr=%p, smem_ptr=%p, bytes=%d; gmem: %dx%d, {%d, %d}, [%f, %f, %f, %f, %f, %f, %f, %f]; smem: {%d, %d};", + stage, reinterpret_cast(gmem_ptr), reinterpret_cast(dst_ptr), kSrcBytes, + gmem_m, gmem_k, gmem_row, gmem_col, + static_cast(gmem_element_A_ptr[0]), static_cast(gmem_element_A_ptr[1]), + static_cast(gmem_element_A_ptr[2]), static_cast(gmem_element_A_ptr[3]), + static_cast(gmem_element_A_ptr[4]), static_cast(gmem_element_A_ptr[5]), + static_cast(gmem_element_A_ptr[6]), static_cast(gmem_element_A_ptr[7]), + smem_row, smem_col); + } +#endif + cutlass::arch::cp_async_zfill( dst_ptr + v, iterator_A.get(), iterator_A.valid()); @@ -453,7 +505,7 @@ class Wint2xMmaMultistage : } } - template + template CUTLASS_DEVICE void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B, int stage) { iterator_B.set_iteration_index(0); @@ -475,6 +527,7 @@ class Wint2xMmaMultistage : IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; +#if 0 if (v == 0) { int gmem_offset = reinterpret_cast(gmem_ptr) - reinterpret_cast(ptr_B_); int gmem_k = 8192 * kInterleave / 4; @@ -499,16 +552,17 @@ class Wint2xMmaMultistage : static_cast(gmem_uint8_ptr[6]), static_cast(gmem_uint8_ptr[7]), smem_row, smem_col); } +#endif if (InitStage) { - cutlass::arch::copy_zfill( + cutlass::arch::cp_async_zfill( dst_ptr + v, iterator_B.get(), iterator_B.valid()); } else { if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::copy_zfill( + cutlass::arch::cp_async_zfill( dst_ptr + v, gmem_ptr, iterator_B.valid()); } else { - cutlass::arch::copy( + cutlass::arch::cp_async( dst_ptr + v, gmem_ptr, iterator_B.valid()); } } @@ -527,6 +581,7 @@ class Wint2xMmaMultistage : void prologue( IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + QuantArguments &mma_quant_args, ///< iterators for extra quant params for B int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining { // Issue several complete stages @@ -538,15 +593,21 @@ class Wint2xMmaMultistage : iterator_B.clear_mask(gemm_k_iterations == 0); // Async copy zipped B to shared memory. - copy_tiles_and_advance_per_stage_A(iterator_A); + copy_tiles_and_advance_per_stage_A(iterator_A, 0); // Async copy zipped B to shared memory. - copy_tiles_and_advance_per_stage_B(iterator_B, stage); + copy_tiles_and_advance_per_stage_B(iterator_B, stage); - // TODO: Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale. + // Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale. 
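      // For reference, the staging protocol behind these per-stage copies, written as a
      // compressed sketch (the stage count and the fixed 16-byte copy size are
      // illustrative assumptions; the wrappers named below are the CUTLASS helpers this
      // mainloop already uses, and the PTX in the trailing comments is what they emit):
      //
      //   for (int stage = 0; stage < Base::kStages - 1; ++stage) {
      //     // predicated 16-byte global->shared copy, zero-filled when the access is
      //     // invalid:  cp.async.ca.shared.global [smem], [gmem], 16, valid ? 16 : 0;
      //     cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Always>(
      //         smem_ptr, gmem_ptr, valid);
      //     cutlass::arch::cp_async_fence();                 // cp.async.commit_group;
      //   }
      //   cutlass::arch::cp_async_wait<Base::kStages - 2>(); // cp.async.wait_group N;
      //   __syncthreads();  // the oldest committed stage is now resident in shared memory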
+ if (stage == 0) { + quant_params_accessor_B_.copy_tiles_and_advance_per_stage(mma_quant_args, stage); + } else { + quant_params_accessor_B_.copy_tiles_and_advance_per_stage(mma_quant_args, stage); + } // Move to the next write stage advance_smem_write_stage(iterator_A, iterator_B); + quant_params_accessor_B_.advance_smem_write_stage(mma_quant_args); // Defines the boundary of a stage of cp.async. cutlass::arch::cp_async_fence(); @@ -614,40 +675,39 @@ class Wint2xMmaMultistage : FragmentC &accum, ///< [in|out] destination accumulator tile IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + QuantArguments &mma_quant_args, ///< iterators for extra quant params for B int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining int stage) { // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration CUTLASS_PRAGMA_UNROLL for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { - // CUTLASS_TRACE_DEVICE(" [MMa] stage=%d, warp_mma_k=%d", stage, warp_mma_k); - // Load the next warp-tile's A fragment from shared memory this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[(warp_mma_k + 1) % 2]); ++this->warp_tile_iterator_A_; - // Load the next warp-tile's B fragment from shared memory - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_B_; + int warp_k_compute_offset_B = warp_mma_k % Base::kWarpGemmIterationsPerLoadForB; + int warp_mma_k_for_B = warp_mma_k / Base::kWarpGemmIterationsPerLoadForB; + + if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) { + // Load the next warp-tile's B fragment from shared memory + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k_for_B + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; - // Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary - if (warp_mma_k > 0) { - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], - pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]); + warp_dequantizer_.load(pipe_state.warp_frag_local_scale_); } // Execute the current warp-tile of MMA operations if (Detail::kStagedAccumulation) { + //CUTLASS_TRACE_DEVICE(" [MMa-kStagedAccumulation][stage=%d] warp_mma_k=%d, warp_k_compute_offset_B=%d", stage, warp_mma_k, warp_k_compute_offset_B); warp_mma_( pipe_state.tmp_accum_, - pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - pipe_state.tmp_accum_ + pipe_state.warp_frag_A_[warp_mma_k % 2], + pipe_state.warp_frag_B_, + pipe_state.tmp_accum_, + warp_k_compute_offset_B ); if (warp_mma_k == 0) { @@ -656,12 +716,27 @@ class Wint2xMmaMultistage : pipe_state.tmp_accum_.clear(); } } else { + //CUTLASS_TRACE_DEVICE(" [MMa][stage=%d] warp_mma_k=%d, warp_k_compute_offset_B=%d", stage, warp_mma_k, warp_k_compute_offset_B); warp_mma_( accum, - pipe_state.warp_transformed_frag_A_[warp_mma_k % 
2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - accum + pipe_state.warp_frag_A_[warp_mma_k % 2], + pipe_state.warp_frag_B_, + accum, + warp_k_compute_offset_B ); +#if 0 + if (FragmentC::kElements == 16) { + CUTLASS_TRACE_DEVICE(" tile_C[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(accum[0]), static_cast(accum[1]), + static_cast(accum[2]), static_cast(accum[3]), + static_cast(accum[4]), static_cast(accum[5]), + static_cast(accum[6]), static_cast(accum[7]), + static_cast(accum[8]), static_cast(accum[9]), + static_cast(accum[10]), static_cast(accum[11]), + static_cast(accum[12]), static_cast(accum[13]), + static_cast(accum[14]), static_cast(accum[15])); + } +#endif } // Except for the last warp-tile, all warp-tiles issue their share of @@ -671,7 +746,10 @@ class Wint2xMmaMultistage : int group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); - copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); + copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); + if (warp_mma_k == 0) { + quant_params_accessor_B_.copy_tiles_and_advance_per_stage(mma_quant_args, stage); + } } // The second-to-last warp-tile also: @@ -683,7 +761,7 @@ class Wint2xMmaMultistage : int group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); - copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); + copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); // Inserts a memory fence between stages of cp.async instructions. cutlass::arch::cp_async_fence(); @@ -693,23 +771,27 @@ class Wint2xMmaMultistage : // Move to the next global fetch stage advance_smem_write_stage(iterator_A, iterator_B); + quant_params_accessor_B_.advance_smem_write_stage(mma_quant_args); + advance_smem_read_stage(); + int byte_offset = quant_params_accessor_B_.advance_smem_read_stage(); + warp_dequantizer_.add_pointer_offset(byte_offset); // Disable global fetching when done with global fetch iterations --gemm_k_iterations; iterator_A.clear_mask(gemm_k_iterations == 0); iterator_B.clear_mask(gemm_k_iterations == 0); + quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0); } - // The last warp-tile also converts the shared memory fragments used by - // the first warp-tile of the next iteration, if necessary (so we can - // immediately start issuing MMA instructions at the top of the loop ) - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2], - pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], - pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2], - pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); + if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) { + warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_, + pipe_state.warp_frag_code_scale_, + pipe_state.warp_frag_code_zp_, + pipe_state.warp_frag_super_scale_, + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_frag_B_, + (stage - Base::kStages + 2) * Shape::kK); } } } @@ -721,18 +803,18 @@ class Wint2xMmaMultistage : int gemm_k_iterations, ///< number of threadblock mainloop iterations FragmentC &accum, ///< [in|out] accumulator tile IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B) + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + 
QuantArguments &mma_quant_args) { #if 0 - int smem_k = Shape::kK * kInterleave / 4; - int smem_n = Shape::kN / kInterleave; - for (int i = 0; i < 3 * smem_n; ++i) { - for (int j = 0; j < smem_k; ++j) { - if (i % 3 == 0) { - CUTLASS_TRACE_DEVICE(" [i=%d, j=%d, %dx%d] %d", i, j, smem_n, smem_k, static_cast(smem_ptr_B_[i * smem_k + j])); - } - } - } + CUTLASS_TRACE_DEVICE(" [PipeState] WarpLoadedFragmentA::kElements=%d, %d bytes", + PipeState::WarpLoadedFragmentA::kElements, static_cast(sizeof_bits::value / 8)); + CUTLASS_TRACE_DEVICE(" [PipeState] WarpLoadedFragmentB::kElements=%d, %d bytes", + PipeState::WarpLoadedFragmentB::kElements, static_cast(sizeof_bits::value / 8)); + CUTLASS_TRACE_DEVICE(" [PipeState] WarpTransformedFragmentA::kElements=%d, %d bytes", + PipeState::WarpTransformedFragmentA::kElements, static_cast(sizeof_bits::value / 8)); + CUTLASS_TRACE_DEVICE(" [PipeState] WarpTransformedFragmentB::kElements=%d, %d bytes", + PipeState::WarpTransformedFragmentB::kElements, static_cast(sizeof_bits::value / 8)); #endif PipeState pipe_state; @@ -743,25 +825,26 @@ class Wint2xMmaMultistage : // Load first warp-tile's A fragment from shared memory this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]); + this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[0]); ++this->warp_tile_iterator_A_; // Load first warp-tile's B fragment from shared memory this->warp_tile_iterator_B_.set_kgroup_index(0); - this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); ++this->warp_tile_iterator_B_; if (PipeState::WarpLoadedFragmentA::kElements == 8) { - ElementA* warp_frag_A_ptr = reinterpret_cast(pipe_state.warp_loaded_frag_A_[0].data()); - CUTLASS_TRACE_DEVICE(" warp_loaded_frag_A_=[%f, %f, %f, %f, %f, %f, %f, %f], %d bytes", + ElementA* warp_frag_A_ptr = reinterpret_cast(pipe_state.warp_frag_A_[0].data()); + CUTLASS_TRACE_DEVICE(" warp_frag_A_=[%f, %f, %f, %f, %f, %f, %f, %f], %d bytes", static_cast(warp_frag_A_ptr[0]), static_cast(warp_frag_A_ptr[1]), static_cast(warp_frag_A_ptr[2]), static_cast(warp_frag_A_ptr[3]), static_cast(warp_frag_A_ptr[4]), static_cast(warp_frag_A_ptr[5]), static_cast(warp_frag_A_ptr[6]), static_cast(warp_frag_A_ptr[7]), sizeof_bits::value / 8); } +#if 0 if (PipeState::WarpLoadedFragmentB::kElements == 64) { - uint8_t* reg_uint8_ptr = reinterpret_cast(pipe_state.warp_loaded_frag_B_[0].data()); + uint8_t* reg_uint8_ptr = reinterpret_cast(pipe_state.warp_loaded_frag_B_.data()); CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d], %d bytes", static_cast(reg_uint8_ptr[0]), static_cast(reg_uint8_ptr[1]), static_cast(reg_uint8_ptr[2]), static_cast(reg_uint8_ptr[3]), @@ -773,75 +856,31 @@ class Wint2xMmaMultistage : static_cast(reg_uint8_ptr[14]), static_cast(reg_uint8_ptr[15]), sizeof_bits::value / 8); } +#endif - typename TransformBAfterLDS::result_type unpacked_frag_B = transform_B_(pipe_state.warp_loaded_frag_B_[0]); - if (TransformBAfterLDS::result_type::kElements == 64) { - CUTLASS_TRACE_DEVICE(" TransformBAfterLDS::result_type::kElements: 64, %d bytes", sizeof_bits::value / 8); - CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", - static_cast(unpacked_frag_B[0]), static_cast(unpacked_frag_B[1]), - static_cast(unpacked_frag_B[2]), static_cast(unpacked_frag_B[3]), - static_cast(unpacked_frag_B[4]), 
static_cast(unpacked_frag_B[5]), - static_cast(unpacked_frag_B[6]), static_cast(unpacked_frag_B[7]), - static_cast(unpacked_frag_B[8]), static_cast(unpacked_frag_B[9]), - static_cast(unpacked_frag_B[10]), static_cast(unpacked_frag_B[11]), - static_cast(unpacked_frag_B[12]), static_cast(unpacked_frag_B[13]), - static_cast(unpacked_frag_B[14]), static_cast(unpacked_frag_B[15])); - CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[16:31]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", - static_cast(unpacked_frag_B[16]), static_cast(unpacked_frag_B[17]), - static_cast(unpacked_frag_B[18]), static_cast(unpacked_frag_B[19]), - static_cast(unpacked_frag_B[20]), static_cast(unpacked_frag_B[21]), - static_cast(unpacked_frag_B[22]), static_cast(unpacked_frag_B[23]), - static_cast(unpacked_frag_B[24]), static_cast(unpacked_frag_B[25]), - static_cast(unpacked_frag_B[26]), static_cast(unpacked_frag_B[27]), - static_cast(unpacked_frag_B[28]), static_cast(unpacked_frag_B[29]), - static_cast(unpacked_frag_B[30]), static_cast(unpacked_frag_B[31])); - CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[32:47]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", - static_cast(unpacked_frag_B[32]), static_cast(unpacked_frag_B[33]), - static_cast(unpacked_frag_B[34]), static_cast(unpacked_frag_B[35]), - static_cast(unpacked_frag_B[36]), static_cast(unpacked_frag_B[37]), - static_cast(unpacked_frag_B[38]), static_cast(unpacked_frag_B[39]), - static_cast(unpacked_frag_B[40]), static_cast(unpacked_frag_B[41]), - static_cast(unpacked_frag_B[42]), static_cast(unpacked_frag_B[43]), - static_cast(unpacked_frag_B[44]), static_cast(unpacked_frag_B[45]), - static_cast(unpacked_frag_B[46]), static_cast(unpacked_frag_B[47])); - CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[48:63]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", - static_cast(unpacked_frag_B[48]), static_cast(unpacked_frag_B[49]), - static_cast(unpacked_frag_B[50]), static_cast(unpacked_frag_B[51]), - static_cast(unpacked_frag_B[52]), static_cast(unpacked_frag_B[53]), - static_cast(unpacked_frag_B[54]), static_cast(unpacked_frag_B[55]), - static_cast(unpacked_frag_B[56]), static_cast(unpacked_frag_B[57]), - static_cast(unpacked_frag_B[58]), static_cast(unpacked_frag_B[59]), - static_cast(unpacked_frag_B[60]), static_cast(unpacked_frag_B[61]), - static_cast(unpacked_frag_B[62]), static_cast(unpacked_frag_B[63])); - } + warp_dequantizer_.load(pipe_state.warp_frag_code_scale_, + pipe_state.warp_frag_code_zp_, + pipe_state.warp_frag_super_scale_); - typename Dequantizer::FragmentLocalScale warp_frag_local_scale; - typename Dequantizer::FragmentCodeScale warp_frag_code_scale; - typename Dequantizer::FragmentCodeZp warp_frag_code_zp; - typename Dequantizer::FragmentSuperScale warp_frag_super_scale; - typename Dequantizer::FragmentOutOperand warp_frag_out; - - CUTLASS_TRACE_DEVICE(" warp_dequantizer_ - start load"); - warp_dequantizer_.load(warp_frag_local_scale, - warp_frag_code_scale, - warp_frag_code_zp, - warp_frag_super_scale); - - CUTLASS_TRACE_DEVICE("warp_dequantizer_ - start dequant"); - warp_dequantizer_.dequantize(warp_frag_out, - pipe_state.warp_loaded_frag_B_[0], - warp_frag_local_scale, - warp_frag_code_scale, - warp_frag_code_zp, - warp_frag_super_scale); + warp_dequantizer_.load(pipe_state.warp_frag_local_scale_); #if 0 - // Transform, if necessary, the first warp-tile's shared memory fragments - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[0], - pipe_state.warp_transformed_frag_B_[0], - 
pipe_state.warp_loaded_frag_A_[0], - pipe_state.warp_loaded_frag_B_[0]); + if (PipeState::FragmentLocalScale::kElements == 4) { + CUTLASS_TRACE_DEVICE(" FragmentLocalScale::kElements=%d, local_scale_frag[0:3]=[%d, %d, %d, %d], sizeof(FragmentLocalScale)=%d", + PipeState::FragmentLocalScale::kElements, + static_cast(pipe_state.warp_frag_local_scale_[0]), static_cast(pipe_state.warp_frag_local_scale_[1]), + static_cast(pipe_state.warp_frag_local_scale_[2]), static_cast(pipe_state.warp_frag_local_scale_[3]), + static_cast(sizeof(PipeState::FragmentLocalScale))); + } +#endif + + warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_, + pipe_state.warp_frag_code_scale_, + pipe_state.warp_frag_code_zp_, + pipe_state.warp_frag_super_scale_, + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_frag_B_, + 0); if (Detail::kStagedAccumulation) { pipe_state.tmp_accum_.clear(); @@ -852,15 +891,18 @@ class Wint2xMmaMultistage : // Mainloop CUTLASS_GEMM_LOOP for (; gemm_k_iterations > (-Base::kStages + 1);) { + if (stage > Base::kStages + 1) { + //break; + } mac_loop_iter( pipe_state, accum, iterator_A, iterator_B, + mma_quant_args, gemm_k_iterations, stage); stage += 1; - break; } if (Detail::kStagedAccumulation) { @@ -872,7 +914,6 @@ class Wint2xMmaMultistage : cutlass::arch::cp_async_fence(); cutlass::arch::cp_async_wait<0>(); __syncthreads(); - #endif } /// Prepares the class for another prologue. @@ -920,13 +961,16 @@ class Wint2xMmaMultistage : IteratorA iterator_A, ///< iterator over B operand in global memory IteratorB iterator_B, + ///< iterators for extra quant params for B + QuantArguments mma_quant_args, ///< initial value of accumulator FragmentC const &src_accum) { + ptr_A_ = reinterpret_cast(iterator_A.get_origin_pointer()); ptr_B_ = reinterpret_cast(iterator_B.get_origin_pointer()); // Prologue (start fetching iterations of global fragments into shared memory) - prologue(iterator_A, iterator_B, gemm_k_iterations); + prologue(iterator_A, iterator_B, mma_quant_args, gemm_k_iterations); // Wait until we have at least one completed global fetch stage gmem_wait(); @@ -935,7 +979,7 @@ class Wint2xMmaMultistage : accum = src_accum; // Perform the MAC-iterations - gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B); + gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, mma_quant_args); } }; diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h new file mode 100644 index 0000000000..194c06cf5d --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h @@ -0,0 +1,326 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
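// As a concrete illustration of the byte layout this accessor carves out per column of
// the threadblock tile, a minimal standalone sketch (kN = 128, a 2-byte super_scale
// element and 4-byte float code params are assumed values picked only to make the
// offsets concrete; the real ones come from the template parameters below):

namespace wint2_params_layout_sketch {

constexpr int kN = 128;               // assumed Shape::kN
constexpr int kSuperScaleBytes = 2;   // sizeof(half_t) or sizeof(bfloat16_t)
constexpr int kCodeParamBytes = 4;    // sizeof(float), used for code_scale and code_zp

// Channel-wise params live at the front of the buffer, one value per column:
constexpr int kSuperScaleOffset = 0;
constexpr int kCodeScaleOffset = kSuperScaleOffset + kN * kSuperScaleBytes;   // 256
constexpr int kCodeZpOffset = kCodeScaleOffset + kN * kCodeParamBytes;        // 768
// Group-wise local_scale (two 4-bit scales packed per byte, one slot per stage)
// follows the channel-wise block:
constexpr int kLocalScaleOffset = kCodeZpOffset + kN * kCodeParamBytes;       // 1280

static_assert(kLocalScaleOffset == 1280,
              "offsets shown for the assumed kN = 128 / half_t / float configuration");

}  // namespace wint2_params_layout_sketch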
+ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/trace.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template < + /// Original data type + typename T, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterators over super scales in global memory + typename IteratorSuperScale_, + /// Iterators over super scales in shared memory + typename SmemIteratorSuperScale_, + /// Iterators over local scales in global memory + typename IteratorLocalScale_, + /// Iterators over local scales in shared memory + typename SmemIteratorLocalScale_, + /// Iterators over code scales and zps in global memory + typename IteratorCodeScaleZp_, + /// Iterators over code scales and zps in shared memory + typename SmemIteratorCodeScaleZp_, + /// Number of stages, + int Stages_, + /// Group size for quantization + int GroupSize_> +class Wint2ParamsAccessor { +public: + static_assert(platform::is_same::value || platform::is_same::value, + "T must be fp16 or bf16"); + + using ElementType = T; + using Shape = Shape_; + + using IteratorSuperScale = IteratorSuperScale_; + using SmemIteratorSuperScale = SmemIteratorSuperScale_; + + using IteratorLocalScale = IteratorLocalScale_; + using SmemIteratorLocalScale = SmemIteratorLocalScale_; + + using IteratorCodeScaleZp = IteratorCodeScaleZp_; + using SmemIteratorCodeScaleZp = SmemIteratorCodeScaleZp_; + + constexpr static int kStages = Stages_; + constexpr static int kGroupSize = GroupSize_; + + using ElementSuperScale = typename IteratorSuperScale::Element; + using LayoutSuperScale = typename IteratorSuperScale::Layout; + + // local_scale uint4 and group-wise + using ElementLocalScale = typename IteratorLocalScale::Element; + using LayoutLocalScale = typename IteratorLocalScale::Layout; + static_assert(platform::is_same::value, + "local_scale's type must be uint4b_t."); + + using ElementCodeScaleZp = typename IteratorCodeScaleZp::Element; + using LayoutCodeScaleZp = typename IteratorCodeScaleZp::Layout; + + // 2 uint4b_t values are stored in a single uint8_t + constexpr static int kStagesPerLocalScaleLoad = 2 * kGroupSize / Shape::kK; + constexpr static int kLocalScaleRows = IteratorLocalScale::Shape::kRow; + + using SmemElement = uint8_t; + constexpr static int kSmemRows = + sizeof(ElementLocalScale) * kLocalScaleRows * kStages + + sizeof(ElementSuperScale) + sizeof(ElementCodeScaleZp) * 2; + constexpr static int kSmemColumns = Shape::kN; + + using QuantParamsShape = MatrixShape; + + constexpr static int kSuperScaleSmemOffset = 0; + constexpr static int kCodeScaleSmemOffset = kSmemColumns * sizeof(ElementSuperScale); + constexpr static int kCodeZpSmemOffset = kCodeScaleSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp); + constexpr static int kLocalScaleSmemOffset = kCodeZpSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp); + + /// TensorRef type for loading element from a tensor + using SuperTensorRef = cutlass::TensorRef; + using LocalTensorRef = cutlass::TensorRef; + using CodeTensorRef = cutlass::TensorRef; + + struct Arguments { + IteratorSuperScale iterator_super_scale; + IteratorLocalScale iterator_local_scale; + IteratorCodeScaleZp iterator_code_scale; + IteratorCodeScaleZp iterator_code_zp; + + int local_scale_pointer_offset; + + CUTLASS_DEVICE + Arguments(IteratorSuperScale iterator_super_scale, + IteratorLocalScale iterator_local_scale, + IteratorCodeScaleZp iterator_code_scale, + IteratorCodeScaleZp 
iterator_code_zp, + int local_scale_pointer_offset) + : iterator_super_scale(iterator_super_scale), + iterator_local_scale(iterator_local_scale), + iterator_code_scale(iterator_code_scale), + iterator_code_zp(iterator_code_zp), + local_scale_pointer_offset(local_scale_pointer_offset) {} + }; + +private: + // + // Data members + // + + /// Begin address of shared memory + uint8_t* smem_pointer_; + + /// Iterator to write threadblock-scoped tile of super scale operand to shared memory + SmemIteratorSuperScale smem_iterator_super_scale_; + /// Iterator to write threadblock-scoped tile of local scale operand to shared memory + SmemIteratorLocalScale smem_iterator_local_scale_; + /// Iterator to write threadblock-scoped tile of code scale operand to shared memory + SmemIteratorCodeScaleZp smem_iterator_code_scale_; + /// Iterator to write threadblock-scoped tile of code zp operand to shared memory + SmemIteratorCodeScaleZp smem_iterator_code_zp_; + + /// Shared memory write stage index + int smem_write_stage_idx_; + + /// Shared memory read stage index + int smem_read_stage_idx_; + + CUTLASS_DEVICE + ElementSuperScale* get_super_scale_smem_ptr() { + return reinterpret_cast(smem_pointer_ + kSuperScaleSmemOffset); + } + + CUTLASS_DEVICE + ElementLocalScale* get_local_scale_smem_ptr() { + return reinterpret_cast(smem_pointer_ + kLocalScaleSmemOffset); + } + + CUTLASS_DEVICE + ElementCodeScaleZp* get_code_scale_smem_ptr() { + return reinterpret_cast(smem_pointer_ + kCodeScaleSmemOffset); + } + + CUTLASS_DEVICE + ElementCodeScaleZp* get_code_zp_smem_ptr() { + return reinterpret_cast(smem_pointer_ + kCodeZpSmemOffset); + } + +public: + /// Construct from tensor references + CUTLASS_DEVICE + Wint2ParamsAccessor( + ///< prointer of shared memory + uint8_t* smem_pointer, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : smem_pointer_(smem_pointer), + smem_iterator_super_scale_(LayoutSuperScale(IteratorSuperScale::Shape::kColumn), + get_super_scale_smem_ptr(), {1, IteratorSuperScale::Shape::kColumn}, thread_idx), + smem_iterator_local_scale_(LayoutLocalScale(IteratorLocalScale::Shape::kColumn), + get_local_scale_smem_ptr(), {1, IteratorLocalScale::Shape::kColumn}, thread_idx), + smem_iterator_code_scale_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn), + get_code_scale_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx), + smem_iterator_code_zp_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn), + get_code_zp_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx), + smem_write_stage_idx_(0), + smem_read_stage_idx_(0) + { + CUTLASS_TRACE_DEVICE(" Shape: {%d, %d, %d}, kSmemRows=%d, kSmemColumns=%d, kLocalScaleRows=%d, kStagesPerLocalScaleLoad=%d", + Shape::kM, Shape::kN, Shape::kK, kSmemRows, kSmemColumns, kLocalScaleRows, kStagesPerLocalScaleLoad); + //CUTLASS_TRACE_DEVICE(" IteratorSuperScale::Shape: {%d, %d}, kSuperScaleSmemOffset=%d, smem_ptr=%p", + // IteratorSuperScale::Shape::kRow, IteratorSuperScale::Shape::kColumn, kSuperScaleSmemOffset, get_super_scale_smem_ptr()); + //CUTLASS_TRACE_DEVICE(" IteratorLocalScale::Shape: {%d, %d}, kLocalScaleSmemOffset=%d, smem_ptr=%p", + // IteratorLocalScale::Shape::kRow, IteratorLocalScale::Shape::kColumn, kLocalScaleSmemOffset, get_local_scale_smem_ptr()); + //CUTLASS_TRACE_DEVICE(" IteratorCodeScaleZp::Shape: {%d, %d}, kCodeScaleSmemOffset=%d, smem_ptr=%p", + // IteratorCodeScaleZp::Shape::kRow, IteratorCodeScaleZp::Shape::kColumn, 
kCodeScaleSmemOffset, get_code_scale_smem_ptr()); + //CUTLASS_TRACE_DEVICE(" IteratorCodeScaleZp::Shape: {%d, %d}, kCodeZpSmemOffset=%d, smem_ptr=%p", + // IteratorCodeScaleZp::Shape::kRow, IteratorCodeScaleZp::Shape::kColumn, kCodeZpSmemOffset, get_code_zp_smem_ptr()); + } + + CUTLASS_DEVICE + SuperTensorRef super_scale_ref() { + return {get_super_scale_smem_ptr(), LayoutSuperScale(IteratorSuperScale::Shape::kColumn)}; + } + + CUTLASS_DEVICE + LocalTensorRef local_scale_ref() { + return {get_local_scale_smem_ptr(), LayoutLocalScale(IteratorLocalScale::Shape::kColumn)}; + } + + CUTLASS_DEVICE + CodeTensorRef code_scale_ref() { + return {get_code_scale_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)}; + } + + CUTLASS_DEVICE + CodeTensorRef code_zp_ref() { + return {get_code_zp_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)}; + } + + template + CUTLASS_DEVICE + void copy_tiles_and_advance_per_stage(Arguments &quant_args, int stage) { + if constexpr (IsFirstStage) { + //CUTLASS_TRACE_DEVICE(" [stage=%d][SuperScale] Shape: {%d, %d}, Fragment::kElements=%d", + // stage, IteratorSuperScale::Shape::kRow, IteratorSuperScale::Shape::kColumn, IteratorSuperScale::Fragment::kElements); + //CUTLASS_TRACE_DEVICE(" [stage=%d][CodeScale] Shape: {%d, %d}, Fragment::kElements=%d", + // stage, IteratorCodeScaleZp::Shape::kRow, IteratorCodeScaleZp::Shape::kColumn, IteratorCodeScaleZp::Fragment::kElements); + + // Load channel-wise super_scale to shared memory, which only needs to be done once. + typename IteratorSuperScale::Fragment tb_frag_super_scale; + tb_frag_super_scale.clear(); + quant_args.iterator_super_scale.load(tb_frag_super_scale); + this->smem_iterator_super_scale_.store(tb_frag_super_scale); + + // Load channel-wise code_scale to shared memory, which only needs to be done once. + typename IteratorCodeScaleZp::Fragment tb_frag_code_scale; + tb_frag_code_scale.clear(); + quant_args.iterator_code_scale.load(tb_frag_code_scale); + this->smem_iterator_code_scale_.store(tb_frag_code_scale); + + // Load channel-wise code_zp to shared memory, which only needs to be done once. + typename IteratorCodeScaleZp::Fragment tb_frag_code_zp; + tb_frag_code_zp.clear(); + quant_args.iterator_code_zp.load(tb_frag_code_zp); + this->smem_iterator_code_zp_.store(tb_frag_code_zp); + } + + if ((stage % kStagesPerLocalScaleLoad) == 0) { + // Load group-wise local_scale to shared memory, which only needs to be done at each stage. + // Since 2 uint4b_t values of local_scale are saved in a single uint8_t, local_scale needs to be loaded once every two stages. 
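      // As a concrete sketch of the cadence (kGroupSize = 64 and Shape::kK = 64 are
      // assumed values here):
      //
      //   kStagesPerLocalScaleLoad = 2 * kGroupSize / Shape::kK;   // == 2
      //   stage 0: (stage % 2) == 0 -> fetch one uint8_t per column (two 4-bit scales)
      //   stage 1: no global fetch; the warp dequantizer selects the other 4-bit half via
      //            local_scale_shift = (((tb_offset_k / kGroupSize) + 1) & 1) * 4
      //   stage 2: a fresh fetch lands in the next shared-memory slot (the write iterator
      //            advances by IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn)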
+ typename IteratorLocalScale::Fragment tb_frag_local_scale; + tb_frag_local_scale.clear(); + quant_args.iterator_local_scale.load(tb_frag_local_scale); + this->smem_iterator_local_scale_.store(tb_frag_local_scale); + + //CUTLASS_TRACE_DEVICE(" [stage=%d][LocalScale] Shape: {%d, %d}", + // stage, IteratorLocalScale::Shape::kRow, IteratorLocalScale::Shape::kColumn); +#if 0 + __syncthreads(); + if (IteratorLocalScale::Fragment::kElements == 32) { + uint8_t* local_scale_ptr = reinterpret_cast(tb_frag_local_scale.data()); + CUTLASS_TRACE_DEVICE(" [stage=%d][LocalScale] tb_frag_local_scale[0:15]=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d]", + stage, + static_cast(local_scale_ptr[0]), static_cast(local_scale_ptr[1]), + static_cast(local_scale_ptr[2]), static_cast(local_scale_ptr[3]), + static_cast(local_scale_ptr[4]), static_cast(local_scale_ptr[5]), + static_cast(local_scale_ptr[6]), static_cast(local_scale_ptr[7]), + static_cast(local_scale_ptr[8]), static_cast(local_scale_ptr[9]), + static_cast(local_scale_ptr[10]), static_cast(local_scale_ptr[11]), + static_cast(local_scale_ptr[12]), static_cast(local_scale_ptr[13]), + static_cast(local_scale_ptr[14]), static_cast(local_scale_ptr[15])); + } +#endif + } + } + + CUTLASS_DEVICE + void advance_smem_write_stage(Arguments &quant_args) { + if (smem_write_stage_idx_ % kStagesPerLocalScaleLoad == 0) { + // Advance global iterators + quant_args.iterator_local_scale.add_pointer_offset(quant_args.local_scale_pointer_offset); + + // Advance shared iterators + int smem_pointer_offset = IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn; + smem_iterator_local_scale_.add_pointer_offset(smem_pointer_offset); + } + + // Increment shared memory write stage index + ++smem_write_stage_idx_; + + if (smem_write_stage_idx_ == kStagesPerLocalScaleLoad * kStages) { + // Wrap back around to the 'start' of the circular buffer in shared memory + int pointer_offset = - kStages * IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn; + smem_iterator_local_scale_.add_pointer_offset(pointer_offset); + smem_write_stage_idx_ = 0; + } + //CUTLASS_TRACE_DEVICE(" smem_write_stage_idx_=%d", smem_write_stage_idx_); + } + + CUTLASS_DEVICE + int advance_smem_read_stage() { + int byte_offset = 0; + if (smem_write_stage_idx_ % kStagesPerLocalScaleLoad == 0) { + byte_offset = kLocalScaleRows * kSmemColumns; + } + + ++smem_read_stage_idx_; + + if (smem_read_stage_idx_ == kStagesPerLocalScaleLoad * kStages) { + smem_read_stage_idx_ = 0; + byte_offset = - (kStages - 1) * kLocalScaleRows * kSmemColumns; + } + //CUTLASS_TRACE_DEVICE(" smem_read_stage_idx_=%d, byte_offset=%d", smem_read_stage_idx_, byte_offset); + return byte_offset; + } + + CUTLASS_DEVICE + int clear_mask(Arguments &quant_args, bool cond) { + quant_args.iterator_local_scale.clear_mask(cond); + } +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h index 4d05740d53..1f81d5802b 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h @@ -42,22 +42,16 @@ #include "cutlass/matrix_shape.h" #include "cutlass/numeric_types.h" #include "cutlass/tensor_ref.h" - #include "cutlass/arch/arch.h" #include "cutlass/arch/memory_sm75.h" #include 
"cutlass/gemm/gemm.h" - #include "cutlass/layout/matrix.h" #include "cutlass/layout/pitch_linear.h" #include "cutlass/layout/tensor.h" - #include "cutlass/functional.h" #include "cutlass/platform/platform.h" -#include -#include "cutlass_extensions/weight_only_quant_op.h" - -//////////////////////////////////////////////////////////////////////////////// +#include "cutlass_extensions/interleaved_numeric_conversion.h" namespace cutlass { namespace gemm { @@ -73,16 +67,16 @@ template < /// Operand identity Operand Operand, /// Data type of Scale elements - typename Element_, + typename ElementOperand_, /// Layout of operand typename Layout_, /// Number of threads participating in one matrix operation int Threads, /// - WeightOnlyQuantOp QuantOp_ = WeightOnlyQuantOp::UNDEFINED, - /// typename Enable = void> -class MmaTensorOpWin2xDequantizer; +class MmaTensorOpWin2xDequantizer { + //static_assert(false, "Not Supported!"); +}; //////////////////////////////////////////////////////////////////////////////// // Bfloat specialization for Ampere @@ -91,20 +85,22 @@ template < typename MmaOperator_, /// Shape of the warp level matrix multiply (concept: GemmShape) typename Shape_, - /// - WeightOnlyQuantOp QuantOp_> - + /// Data type of Scale elements + typename ElementOperand_> class MmaTensorOpWin2xDequantizer< MmaOperator_, Shape_, Operand::kB, - bfloat16_t, - layout::ColumnMajor, - 32, - QuantOp_, - typename platform::enable_if= - 70>::type> { - public: + ElementOperand_, + layout::RowMajor, + 32> + //typename platform::enable_if= 80 + // && platform::is_same::value>::type> +{ +public: + static_assert(platform::is_same::value || platform::is_same::value, + "T must be fp16 or bf16"); + /// Mma Operator using MmaOperator = MmaOperator_; @@ -114,577 +110,266 @@ class MmaTensorOpWin2xDequantizer< // Mma Instruction Shape using InstructionShape = typename ArchMmaOperator::Shape; - // This is the ratio of the load instruction vs the compute instruction. - static constexpr int kExpansionFactor = - MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + /// Type of mma operand + using ElementOperand = ElementOperand_; - /// Type of the scales - using ElementWeight = uint2b_t; + /// Type of input + using ElementB = typename MmaOperator::FragmentB::Element; + static_assert(platform::is_same::value, "ElementB must be uint2b_t"); - /// Type of the scales - using ElementUnzipWeight = uint8_t; + /// Type of internal compute + using ElementCompute = float; /// Type of the scales - using ElementScale = bfloat16_t; - - /// Type of the scales - using ScaleComputeT = float; - - static constexpr int unzip_len = 4; + using ElementLocalScale = uint4b_t; + using ElementSuperScale = ElementOperand; + using ElementCodeScaleZp = float; /// Fragment to hold B data before Mma - using FragmentDequantizedOperand = - Array; - using FragmentWeightOperand = - Array; - using FragmentOutOperand = - Array; + using FragmentInput = Array; + + /// Unpack 4 uint2b_t values compreseed in a uint8_t to floating points. 
+ using Uint2Converter = FastInterleavedAndBiasedNumericArrayConverter< + ElementOperand, ElementB, MmaOperator::FragmentB::kElements>; + using FragmentUnpack = typename Uint2Converter::result_type; // Fragment to hold scale data to apply to B before mma // We need 1 fp16 per matrix iteration in the N dimension static constexpr int kColsPerMmaPerThread = 1; - using FragmentLocalScale = Array; - using FragmentCodeScale = Array; - using FragmentCodeZp = Array; - using FragmentSuperScale = Array; + static constexpr int kElements = kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn; + + // 32 bits are loaded to register from shared memory by each thread + static constexpr int kMmaIterationsPerLoad = + 32 / (sizeof_bits::value * ArchMmaOperator::FragmentB::kElements); + + // use uint8_t to save 2 4-bits local scales + using FragmentLocalScale = Array; + using FragmentSuperScale = Array; + using FragmentCodeScaleZp = Array; + + /// Fragment to hold internal scales before Mma + using FragmentCompute = Array; + + /// Fragment of dequantized B + //using FragmentOutput = Array; + using FragmentOutput = Array; /// Warp mma shape using Shape = Shape_; /// Layout of the scales in shared memory - using Layout = layout::ColumnMajor; + using Layout = layout::RowMajor; /// TensorRef type for loading element from a tensor - using TensorRef = cutlass::TensorRef; - using TensorCodeRef = cutlass::TensorRef; + using SuperTensorRef = cutlass::TensorRef; + using LocalTensorRef = cutlass::TensorRef; + using CodeTensorRef = cutlass::TensorRef; - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; +private: + // + // Data members + // + uint8_t* pointer_local_scale_; + ElementCodeScaleZp* pointer_code_scale_; + ElementCodeScaleZp* pointer_code_zp_; + ElementSuperScale* pointer_super_scale_; + + FragmentUnpack unpacked_frag_; + +public: CUTLASS_DEVICE - MmaTensorOpWin2xDequantizer(TensorRef smem_local_scale, - TensorCodeRef smem_code_scale, - TensorCodeRef smem_code_zp, - TensorRef smem_super_scale, - int const warp_idx_n, - int const lane_idx) { - int const warp_offset = warp_idx_n * Shape::kN; - int const quad = lane_idx / 4; - int const thread_offset = warp_offset + quad; - pointer_local_scale_ = smem_local_scale.data() + thread_offset; + MmaTensorOpWin2xDequantizer(SuperTensorRef smem_super_scale, + LocalTensorRef smem_local_scale, + CodeTensorRef smem_code_scale, + CodeTensorRef smem_code_zp, + int warp_idx_n, + int lane_idx) { + int warp_offset = warp_idx_n * Shape::kN; + int quad = lane_idx / 4; + int thread_offset = warp_offset + quad; + pointer_super_scale_ = smem_super_scale.data() + thread_offset; pointer_code_scale_ = smem_code_scale.data() + thread_offset; pointer_code_zp_ = smem_code_zp.data() + thread_offset; - if constexpr (hasZero(QuantOp)) { - pointer_super_scale_ = smem_super_scale.data() + thread_offset; - } + pointer_local_scale_ = reinterpret_cast(smem_local_scale.data()) + thread_offset; } - // CUTLASS_DEVICE - // MmaTensorOpWin2xDequantizer() { - // pointer_local_scale_ = nullptr; - // pointer_code_scale_ = nullptr; - // pointer_code_zp_ = nullptr; - // if constexpr (hasZero(QuantOp)) { - // pointer_super_scale_ = nullptr; - // } - // } - + /// Channel-wise params, need to load just once CUTLASS_DEVICE - MmaTensorOpWin2xDequantizer() { - // Create fake pointer using a shared dummy buffer - CUTLASS_TRACE_DEVICE(" warp dequant aaa"); - - extern __shared__ char cutlass_fake_dequant_smem[]; - - // Memory layout (manual alignment): - // ElementScale (half or bf16): 2 bytes - // 
ScaleComputeT (float): 4 bytes - - pointer_local_scale_ = - reinterpret_cast(cutlass_fake_dequant_smem); - pointer_code_scale_ = - reinterpret_cast(cutlass_fake_dequant_smem + 64); - pointer_code_zp_ = - reinterpret_cast(cutlass_fake_dequant_smem + 128); - - if constexpr (hasZero(QuantOp)) { - pointer_super_scale_ = reinterpret_cast( - cutlass_fake_dequant_smem + 192); + void load(FragmentCodeScaleZp& code_scale_frag, + FragmentCodeScaleZp& code_zp_frag, + FragmentSuperScale& super_scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + super_scale_frag[mma_n_iter] = pointer_super_scale_[mma_n_iter * InstructionShape::kN]; // bank conflict + code_scale_frag[mma_n_iter] = pointer_code_scale_[mma_n_iter * InstructionShape::kN]; + code_zp_frag[mma_n_iter] = pointer_code_zp_[mma_n_iter * InstructionShape::kN]; } } + /// Group-wise params, need to load multiple times CUTLASS_DEVICE - void load(FragmentLocalScale& local_scale_frag, - FragmentCodeScale& code_scale_frag, - FragmentCodeZp& code_zp_frag, - FragmentSuperScale& super_scale_frag) { - CUTLASS_TRACE_DEVICE(" warp dequant load"); - // CUTLASS_PRAGMA_UNROLL - // for (int mma_n_iter = 0; mma_n_iter < - // MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - // { - // local_scale_frag[mma_n_iter] = pointer_local_scale_[mma_n_iter * - // InstructionShape::kN]; code_scale_frag[mma_n_iter] = - // pointer_code_scale_[mma_n_iter * InstructionShape::kN]; - // code_zp_frag[mma_n_iter] = pointer_code_zp_[mma_n_iter * - // InstructionShape::kN]; if constexpr (hasZero(QuantOp)) - // { - // super_scale_frag[mma_n_iter] = - // pointer_super_scale_[mma_n_iter * InstructionShape::kN]; - // } - // } + void load(FragmentLocalScale& local_scale_frag) { + //CUTLASS_TRACE_DEVICE(" pointer_local_scale_=%p", pointer_local_scale_); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + local_scale_frag[mma_n_iter] = pointer_local_scale_[mma_n_iter * InstructionShape::kN]; // bank conflict + } } CUTLASS_DEVICE - void dequantize(FragmentOutOperand& out_frag, - FragmentDequantizedOperand& operand_frag, - FragmentLocalScale& local_scale_frag, - FragmentCodeScale& code_scale_frag, - FragmentCodeZp& code_zp_frag, - FragmentSuperScale& super_scale_frag) { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) - CUTLASS_TRACE_DEVICE(" dequantize if def"); - - static constexpr int32_t kGroupSize = 64; - static constexpr int32_t kPackNum = 4; - static constexpr int32_t kWeightMask = 0x3F; - static constexpr int32_t kLocalScaleMask = 0xF; - static constexpr int32_t kBZP = 32; - - // using _MmaOperandB = typename ArchMmaOperator::FragmentB; - // using ExpandedMmaOperandB = Array; - // static_assert(ExpandedMmaOperandB::kElements * - // MmaOperator::MmaIterations::kColumn - // == FragmentDequantizedOperand::kElements, - // ""); - - // CUTLASS_TRACE_DEVICE(" MmaIterations krow = %d, kcol = %d", - // MmaOperator::IteratorB::InstructionShape::kRow, - // MmaOperator::MmaIterations::kColumn); - - // CUTLASS_TRACE_DEVICE(" kExpansionFactor = %d / %d", - // MmaOperator::IteratorB::InstructionShape::kRow, - // InstructionShape::kK); CUTLASS_TRACE_DEVICE(" - // FragmentDequantizedOperand::kElements = %d ", - // FragmentDequantizedOperand::kElements); CUTLASS_TRACE_DEVICE(" - // _MmaOperandB::kElements = %d ", _MmaOperandB::kElements); - - // FragmentWeightOperand - CUTLASS_TRACE_DEVICE(" FragmentWeightOperand elem = %d ", 
- FragmentWeightOperand::kElements); - // CUTLASS_TRACE_DEVICE(" ElementUnzipWeight size = %d ", - // sizeof(ElementUnzipWeight)); CUTLASS_TRACE_DEVICE(" ElementWeight - // size = %d ", sizeof(ElementWeight)); - static_assert(std::is_same::value, - "B 是 uint8 量化类型"); - FragmentWeightOperand* weight_ptr = - reinterpret_cast(&operand_frag); - FragmentLocalScale* local_scale_ptr = - reinterpret_cast(&local_scale_frag); - FragmentCodeScale* code_scale_ptr = - reinterpret_cast(&code_scale_frag); - FragmentCodeZp* code_zp_ptr = - reinterpret_cast(&code_zp_frag); - FragmentSuperScale* super_scale_ptr = - reinterpret_cast(&super_scale_frag); - - ScaleComputeT code_scale = - static_cast(code_scale_ptr[0][0]); - ScaleComputeT code_zp = static_cast(code_zp_ptr[0][0]); - ScaleComputeT super_scale = - static_cast(super_scale_ptr[0][0]); - int32_t local_scale = static_cast(local_scale_ptr[0][0]); - int32_t const shift_bits[4] = {9, 6, 3, 0}; - - ScaleComputeT zipped_value[16]; -#pragma unroll - for (int i = 0; i < 16; ++i) { - zipped_value[i] = static_cast(weight_ptr[0][i]); + void dequantize(const FragmentLocalScale& local_scale_frag, + const FragmentCodeScaleZp& code_scale_frag, + const FragmentCodeScaleZp& code_zp_frag, + const FragmentSuperScale& super_scale_frag, + const FragmentInput& input_frag, + FragmentOutput& output_frag, + int tb_offset_k) { + int stage = tb_offset_k / 64; + + //CUTLASS_TRACE_DEVICE(" FragmentInput::kElements=%d, %d bytes", + // FragmentInput::kElements, static_cast(sizeof_bits::value / 8)); + //CUTLASS_TRACE_DEVICE(" FragmentUnpack::kElements=%d, %d bytes", + // FragmentUnpack::kElements, static_cast(sizeof_bits::value / 8)); + //CUTLASS_TRACE_DEVICE(" FragmentOutput::kElements=%d, %d bytes", + // FragmentOutput::kElements, static_cast(sizeof_bits::value / 8)); + + //CUTLASS_TRACE_DEVICE(" MmaOperator::FragmentB::kElements=%d", MmaOperator::FragmentB::kElements); + //CUTLASS_TRACE_DEVICE(" MmaOperator::IteratorB::InstructionShape: %dx%d; InstructionShape: %dx%dx%d; ", + // MmaOperator::IteratorB::InstructionShape::kRow, MmaOperator::IteratorB::InstructionShape::kColumn, + // InstructionShape::kM, InstructionShape::kN, InstructionShape::kK); + //CUTLASS_TRACE_DEVICE(" MmaOperator::MmaIterations: kRow=%d, kColumn=%d", + // MmaOperator::MmaIterations::kRow, MmaOperator::MmaIterations::kColumn); + + unpacked_frag_ = Uint2Converter::convert(input_frag); + // DEBUG CODES + for (int i = 0; i < FragmentUnpack::kElements; ++i) { + unpacked_frag_[i] = static_cast(1); //static_cast((i / 16) * 8 + (threadIdx.x % 32) / 4); } - int local_scale_shift = 4; - int32_t shifted_local_scale = - (local_scale >> local_scale_shift) & kLocalScaleMask; - ScaleComputeT scale = - static_cast(shifted_local_scale) * super_scale; - -#pragma unroll - for (int i = 0; i < 16; ++i) { - int32_t decode_value = static_cast( - floor(zipped_value[i] * code_scale + code_zp + - static_cast(0.5))); - - int col = i * 4; - -#pragma unroll - for (int shift_bit_id = 0; shift_bit_id < 4; ++shift_bit_id) { - int32_t shift_bit = shift_bits[shift_bit_id]; - int32_t shifted_value = - (decode_value >> shift_bit) & kWeightMask; - - ScaleComputeT value = - static_cast(shifted_value - kBZP); - out_frag[col + shift_bit_id] = - static_cast(scale * value); - } +#if 0 + if (FragmentUnpack::kElements == 64) { + CUTLASS_TRACE_DEVICE(" unpacked_frag_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(unpacked_frag_[0]), static_cast(unpacked_frag_[1]), + static_cast(unpacked_frag_[2]), 
static_cast(unpacked_frag_[3]), + static_cast(unpacked_frag_[4]), static_cast(unpacked_frag_[5]), + static_cast(unpacked_frag_[6]), static_cast(unpacked_frag_[7]), + static_cast(unpacked_frag_[8]), static_cast(unpacked_frag_[9]), + static_cast(unpacked_frag_[10]), static_cast(unpacked_frag_[11]), + static_cast(unpacked_frag_[12]), static_cast(unpacked_frag_[13]), + static_cast(unpacked_frag_[14]), static_cast(unpacked_frag_[15])); + CUTLASS_TRACE_DEVICE(" unpacked_frag_[16:31]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(unpacked_frag_[16]), static_cast(unpacked_frag_[17]), + static_cast(unpacked_frag_[18]), static_cast(unpacked_frag_[19]), + static_cast(unpacked_frag_[20]), static_cast(unpacked_frag_[21]), + static_cast(unpacked_frag_[22]), static_cast(unpacked_frag_[23]), + static_cast(unpacked_frag_[24]), static_cast(unpacked_frag_[25]), + static_cast(unpacked_frag_[26]), static_cast(unpacked_frag_[27]), + static_cast(unpacked_frag_[28]), static_cast(unpacked_frag_[29]), + static_cast(unpacked_frag_[30]), static_cast(unpacked_frag_[31])); } - - CUTLASS_TRACE_DEVICE(" kColsPerMmaPerThread = %d ", - kColsPerMmaPerThread); - CUTLASS_TRACE_DEVICE(" MmaOperator::MmaIterations::kColumn = %d ", - MmaOperator::MmaIterations::kColumn); - - // // __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 - // const*>(&scale_frag); ExpandedMmaOperandB* operand_frag_ptr = - // reinterpret_cast(&operand_frag); - - // printf("threadidx.x = %d\n", threadIdx.x); - // CUTLASS_PRAGMA_UNROLL - // for (int mma_n_iter = 0; mma_n_iter < - // MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - // { - // static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); - - // __nv_bfloat162 scalex2 = - // __bfloat162bfloat162(scale_ptr[mma_n_iter]); - // __nv_bfloat162* operand_bf16x2_ptr = - // reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); - - // CUTLASS_PRAGMA_UNROLL - // for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) - // { - // operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], - // scalex2); - // } - // } -#else - // Slow path not implemented here on purpose. If we need to do HMMA on - // older arch, scale conversion should happen before scales are stored - // to shared memory and we should use the fp16 dequantizer. This will - // avoid numerous conversion instructions in GEMM main loop. - CUTLASS_TRACE_DEVICE(" dequantize else def"); - // arch::device_breakpoint(); #endif - } - - // Adds a pointer offset in units of elements. 
- CUTLASS_DEVICE - void add_pointer_offset(int64_t const& offset) { - static_assert(sizeof(ElementScale) > 1, ""); - pointer_local_scale_ += offset; - pointer_code_scale_ += offset; - pointer_code_zp_ += offset; - pointer_super_scale_ += offset; - } - - private: - ElementScale const* pointer_local_scale_; - ScaleComputeT const* pointer_code_scale_; - ScaleComputeT const* pointer_code_zp_; - ElementScale const* pointer_super_scale_; - - ElementScale const* pointer_out_; -}; - -template < - /// Underlying matrix multiply operator (concept: MmaTensorOp) - typename MmaOperator_, - /// Shape of the warp level matrix multiply (concept: GemmShape) - typename Shape_, - /// - WeightOnlyQuantOp QuantOp_> -class MmaTensorOpWin2xDequantizer< - MmaOperator_, - Shape_, - Operand::kB, - half_t, - layout::ColumnMajor, - 32, - QuantOp_, - typename platform::enable_if= - 70>::type> { - public: - /// Mma Operator - using MmaOperator = MmaOperator_; - - // The architecture specific mma ooperator being used - using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; - - // Mma Instruction Shape - using InstructionShape = typename ArchMmaOperator::Shape; - - // This is the ratio of the load instruction vs the compute instruction. - static constexpr int kExpansionFactor = - MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; - - /// Type of the scales - using ElementWeight = uint2b_t; - - /// Type of the scales - using ElementUnzipWeight = uint8_t; - - /// Type of the scales - using ElementScale = half_t; - - /// Type of the scales - using ScaleComputeT = float; - - static constexpr int unzip_len = 4; - - /// Fragment to hold B data before Mma - using FragmentDequantizedOperand = - Array; - using FragmentWeightOperand = - Array; - using FragmentOutOperand = - Array; - - // Fragment to hold scale data to apply to B before mma - // We need 1 fp16 per matrix iteration in the N dimension - static constexpr int kColsPerMmaPerThread = 1; - using FragmentLocalScale = Array; - using FragmentCodeScale = Array; - using FragmentCodeZp = Array; - using FragmentSuperScale = Array; - /// Warp mma shape - using Shape = Shape_; +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - /// Layout of the scales in shared memory - using Layout = layout::ColumnMajor; - - /// TensorRef type for loading element from a tensor - using TensorRef = cutlass::TensorRef; - using TensorCodeRef = cutlass::TensorRef; + static constexpr int32_t kGroupSize = 64; + static constexpr int32_t kLocalScaleMask = 0xF; - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + // special for TileRows = 64 + int local_scale_shift = (((tb_offset_k / kGroupSize) + 1) & 1) * 4; + FragmentCompute scale_frag; - CUTLASS_DEVICE - MmaTensorOpWin2xDequantizer(TensorRef smem_local_scale, - TensorCodeRef smem_code_scale, - TensorCodeRef smem_code_zp, - TensorRef smem_super_scale, - int const warp_idx_n, - int const lane_idx) { - int const warp_offset = warp_idx_n * Shape::kN; - int const quad = lane_idx / 4; - int const thread_offset = warp_offset + quad; - pointer_local_scale_ = smem_local_scale.data() + thread_offset; - pointer_code_scale_ = smem_code_scale.data() + thread_offset; - pointer_code_zp_ = smem_code_zp.data() + thread_offset; - if constexpr (hasZero(QuantOp)) { - pointer_super_scale_ = smem_super_scale.data() + thread_offset; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentLocalScale::kElements; ++i) { + int32_t shifted_local_scale = + (static_cast(local_scale_frag[i]) >> local_scale_shift) & kLocalScaleMask; + 
scale_frag[i] = + static_cast(shifted_local_scale) * static_cast(super_scale_frag[i]); } - } - - // CUTLASS_DEVICE - // MmaTensorOpWin2xDequantizer() { - // pointer_local_scale_ = nullptr; - // pointer_code_scale_ = nullptr; - // pointer_code_zp_ = nullptr; - // if constexpr (hasZero(QuantOp)) { - // pointer_super_scale_ = nullptr; - // } - // } - CUTLASS_DEVICE - MmaTensorOpWin2xDequantizer() { - // Create fake pointer using a shared dummy buffer - CUTLASS_TRACE_DEVICE(" warp dequant aaa"); - - extern __shared__ char cutlass_fake_dequant_smem[]; - - // Memory layout (manual alignment): - // ElementScale (half or bf16): 2 bytes - // ScaleComputeT (float): 4 bytes - - pointer_local_scale_ = - reinterpret_cast(cutlass_fake_dequant_smem); - pointer_code_scale_ = - reinterpret_cast(cutlass_fake_dequant_smem + 64); - pointer_code_zp_ = - reinterpret_cast(cutlass_fake_dequant_smem + 128); - - if constexpr (hasZero(QuantOp)) { - pointer_super_scale_ = reinterpret_cast( - cutlass_fake_dequant_smem + 192); +#if 1 + if (FragmentCompute::kElements == 4) { + CUTLASS_TRACE_DEVICE(" [stage=%d] tb_offset_k=%d, local_scale_shift=%d, scale_frag[0:3]=[%f, %f, %f, %f], sizeof(FragmentCompute)=%d bytes", + stage, tb_offset_k, local_scale_shift, + static_cast(scale_frag[0]), static_cast(scale_frag[1]), + static_cast(scale_frag[2]), static_cast(scale_frag[3]), + static_cast(sizeof(FragmentCompute))); } - } +#endif - CUTLASS_DEVICE - void load(FragmentLocalScale& local_scale_frag, - FragmentCodeScale& code_scale_frag, - FragmentCodeZp& code_zp_frag, - FragmentSuperScale& super_scale_frag) { - CUTLASS_TRACE_DEVICE(" warp dequant load"); - // CUTLASS_PRAGMA_UNROLL - // for (int mma_n_iter = 0; mma_n_iter < - // MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - // { - // local_scale_frag[mma_n_iter] = pointer_local_scale_[mma_n_iter * - // InstructionShape::kN]; code_scale_frag[mma_n_iter] = - // pointer_code_scale_[mma_n_iter * InstructionShape::kN]; - // code_zp_frag[mma_n_iter] = pointer_code_zp_[mma_n_iter * - // InstructionShape::kN]; if constexpr (hasZero(QuantOp)) - // { - // super_scale_frag[mma_n_iter] = - // pointer_super_scale_[mma_n_iter * InstructionShape::kN]; - // } - // } - } + //int offset = warp_mma_k * ArchMmaOperator::FragmentB::kElements; + int num_columns = 32 / sizeof_bits::value; - CUTLASS_DEVICE - void dequantize(FragmentOutOperand& out_frag, - FragmentDequantizedOperand& operand_frag, - FragmentLocalScale& local_scale_frag, - FragmentCodeScale& code_scale_frag, - FragmentCodeZp& code_zp_frag, - FragmentSuperScale& super_scale_frag) { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) - CUTLASS_TRACE_DEVICE(" dequantize if def"); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { - static constexpr int32_t kGroupSize = 64; - static constexpr int32_t kPackNum = 4; - static constexpr int32_t kWeightMask = 0x3F; - static constexpr int32_t kLocalScaleMask = 0xF; - static constexpr int32_t kBZP = 32; - - // using _MmaOperandB = typename ArchMmaOperator::FragmentB; - // using ExpandedMmaOperandB = Array; - // static_assert(ExpandedMmaOperandB::kElements * - // MmaOperator::MmaIterations::kColumn - // == FragmentDequantizedOperand::kElements, - // ""); - - // CUTLASS_TRACE_DEVICE(" MmaIterations krow = %d, kcol = %d", - // MmaOperator::IteratorB::InstructionShape::kRow, - // MmaOperator::MmaIterations::kColumn); - - // CUTLASS_TRACE_DEVICE(" kExpansionFactor = %d / %d", - // 
MmaOperator::IteratorB::InstructionShape::kRow, - // InstructionShape::kK); CUTLASS_TRACE_DEVICE(" - // FragmentDequantizedOperand::kElements = %d ", - // FragmentDequantizedOperand::kElements); CUTLASS_TRACE_DEVICE(" - // _MmaOperandB::kElements = %d ", _MmaOperandB::kElements); - - // FragmentWeightOperand - CUTLASS_TRACE_DEVICE(" FragmentWeightOperand elem = %d ", - FragmentWeightOperand::kElements); - // CUTLASS_TRACE_DEVICE(" ElementUnzipWeight size = %d ", - // sizeof(ElementUnzipWeight)); CUTLASS_TRACE_DEVICE(" ElementWeight - // size = %d ", sizeof(ElementWeight)); - static_assert(std::is_same::value, - "B 是 uint8 量化类型"); - FragmentWeightOperand* weight_ptr = - reinterpret_cast(&operand_frag); - FragmentLocalScale* local_scale_ptr = - reinterpret_cast(&local_scale_frag); - FragmentCodeScale* code_scale_ptr = - reinterpret_cast(&code_scale_frag); - FragmentCodeZp* code_zp_ptr = - reinterpret_cast(&code_zp_frag); - FragmentSuperScale* super_scale_ptr = - reinterpret_cast(&super_scale_frag); - - ScaleComputeT code_scale = - static_cast(code_scale_ptr[0][0]); - ScaleComputeT code_zp = static_cast(code_zp_ptr[0][0]); - ScaleComputeT super_scale = - static_cast(super_scale_ptr[0][0]); - int32_t local_scale = static_cast(local_scale_ptr[0][0]); - int32_t const shift_bits[4] = {9, 6, 3, 0}; - - ScaleComputeT zipped_value[16]; -#pragma unroll - for (int i = 0; i < 16; ++i) { - zipped_value[i] = static_cast(weight_ptr[0][i]); - } - - int local_scale_shift = 4; - int32_t shifted_local_scale = - (local_scale >> local_scale_shift) & kLocalScaleMask; - ScaleComputeT scale = - static_cast(shifted_local_scale) * super_scale; - -#pragma unroll - for (int i = 0; i < 16; ++i) { - int32_t decode_value = static_cast( - floor(zipped_value[i] * code_scale + code_zp + - static_cast(0.5))); - - int col = i * 4; - -#pragma unroll - for (int shift_bit_id = 0; shift_bit_id < 4; ++shift_bit_id) { - int32_t shift_bit = shift_bits[shift_bit_id]; - int32_t shifted_value = - (decode_value >> shift_bit) & kWeightMask; - - ScaleComputeT value = - static_cast(shifted_value - kBZP); - out_frag[col + shift_bit_id] = - static_cast(scale * value); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < num_columns; ++j) { + ElementCompute scaled_value = + static_cast(unpacked_frag_[mma_n_iter * num_columns + j]) * scale_frag[mma_n_iter]; + output_frag[mma_n_iter * num_columns + j] = static_cast(scaled_value); } } - CUTLASS_TRACE_DEVICE(" kColsPerMmaPerThread = %d ", - kColsPerMmaPerThread); - CUTLASS_TRACE_DEVICE(" MmaOperator::MmaIterations::kColumn = %d ", - MmaOperator::MmaIterations::kColumn); - - // // __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 - // const*>(&scale_frag); ExpandedMmaOperandB* operand_frag_ptr = - // reinterpret_cast(&operand_frag); - - // printf("threadidx.x = %d\n", threadIdx.x); - // CUTLASS_PRAGMA_UNROLL - // for (int mma_n_iter = 0; mma_n_iter < - // MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - // { - // static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); - - // __nv_bfloat162 scalex2 = - // __bfloat162bfloat162(scale_ptr[mma_n_iter]); - // __nv_bfloat162* operand_bf16x2_ptr = - // reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); - - // CUTLASS_PRAGMA_UNROLL - // for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) - // { - // operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], - // scalex2); - // } - // } + if (FragmentOutput::kElements == 64) { +#if 1 + CUTLASS_TRACE_DEVICE(" [stage=%d] output_frag[0:15]=[%f, %f, %f, %f, %f, %f, %f, 
%f, %f, %f, %f, %f, %f, %f, %f, %f]", + stage, + static_cast(output_frag[0]), static_cast(output_frag[1]), + static_cast(output_frag[2]), static_cast(output_frag[3]), + static_cast(output_frag[4]), static_cast(output_frag[5]), + static_cast(output_frag[6]), static_cast(output_frag[7]), + static_cast(output_frag[8]), static_cast(output_frag[9]), + static_cast(output_frag[10]), static_cast(output_frag[11]), + static_cast(output_frag[12]), static_cast(output_frag[13]), + static_cast(output_frag[14]), static_cast(output_frag[15])); + CUTLASS_TRACE_DEVICE(" [stage=%d] output_frag[16:31]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + stage, + static_cast(output_frag[16]), static_cast(output_frag[17]), + static_cast(output_frag[18]), static_cast(output_frag[19]), + static_cast(output_frag[20]), static_cast(output_frag[21]), + static_cast(output_frag[22]), static_cast(output_frag[23]), + static_cast(output_frag[24]), static_cast(output_frag[25]), + static_cast(output_frag[26]), static_cast(output_frag[27]), + static_cast(output_frag[28]), static_cast(output_frag[29]), + static_cast(output_frag[30]), static_cast(output_frag[31])); + CUTLASS_TRACE_DEVICE(" [stage=%d] output_frag[32:47]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + stage, + static_cast(output_frag[32]), static_cast(output_frag[33]), + static_cast(output_frag[34]), static_cast(output_frag[35]), + static_cast(output_frag[36]), static_cast(output_frag[37]), + static_cast(output_frag[38]), static_cast(output_frag[39]), + static_cast(output_frag[40]), static_cast(output_frag[41]), + static_cast(output_frag[42]), static_cast(output_frag[43]), + static_cast(output_frag[44]), static_cast(output_frag[45]), + static_cast(output_frag[46]), static_cast(output_frag[47])); + CUTLASS_TRACE_DEVICE(" [stage=%d] output_frag[48:63]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + stage, + static_cast(output_frag[48]), static_cast(output_frag[49]), + static_cast(output_frag[50]), static_cast(output_frag[51]), + static_cast(output_frag[52]), static_cast(output_frag[53]), + static_cast(output_frag[54]), static_cast(output_frag[55]), + static_cast(output_frag[56]), static_cast(output_frag[57]), + static_cast(output_frag[58]), static_cast(output_frag[59]), + static_cast(output_frag[60]), static_cast(output_frag[61]), + static_cast(output_frag[62]), static_cast(output_frag[63])); +#endif + } #else // Slow path not implemented here on purpose. If we need to do HMMA on // older arch, scale conversion should happen before scales are stored // to shared memory and we should use the fp16 dequantizer. This will // avoid numerous conversion instructions in GEMM main loop. - CUTLASS_TRACE_DEVICE(" dequantize else def"); - // arch::device_breakpoint(); + arch::device_breakpoint(); #endif } - // Adds a pointer offset in units of elements. + /// Add an offset to pointer in units of elements. + /// Only group-wise params needs. 
CUTLASS_DEVICE void add_pointer_offset(int64_t const& offset) { - static_assert(sizeof(ElementScale) > 1, ""); pointer_local_scale_ += offset; - pointer_code_scale_ += offset; - pointer_code_zp_ += offset; - pointer_super_scale_ += offset; } - - private: - ElementScale const* pointer_local_scale_; - ScaleComputeT const* pointer_code_scale_; - ScaleComputeT const* pointer_code_zp_; - ElementScale const* pointer_super_scale_; }; //////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h b/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h index 1f5584862b..55548cf393 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h +++ b/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h @@ -462,9 +462,11 @@ struct FastInterleavedAndBiasedNumericArrayConverter ScaleComputeT code_zp = static_cast(0); ScaleComputeT floor_offset = static_cast(0.5); +#if 0 CUTLASS_TRACE_DEVICE(" source: [%d, %d, %d, %d]", static_cast(in_ptr[0]), static_cast(in_ptr[1]), static_cast(in_ptr[2]), static_cast(in_ptr[3])); +#endif CUTLASS_PRAGMA_UNROLL for (int i = 0; i < 4; ++i) { diff --git a/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h b/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h index 9e1c6c463b..fa28810697 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h +++ b/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h @@ -125,10 +125,13 @@ struct WintQuantTraits { static constexpr int32_t kNumPackedValues = 4; static constexpr int32_t kPackedSize = 16; + using LocalScaleType = uint4b_t; + using CodeScaleZpType = float; + struct Arguments { - const uint8_t *local_scale_ptr; // quanted 4-bits - const float *code_scale_ptr; - const float *code_zp_ptr; + uint8_t *local_scale_ptr; // quanted 4-bits + float *code_scale_ptr; + float *code_zp_ptr; }; CUTLASS_DEVICE diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h index 0d9aa62b3f..06931d485a 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h @@ -804,17 +804,54 @@ struct Wint2xMoeFCGemm : public MoeFCGemm struct KernelRunner { using WeightQuantTraits = WintQuantTraits; - using QuantArguments = typename WeightQuantTraits::Arguments; + using MmaQuantArguments = typename Mma::QuantParamsAccessor::Arguments; CUTLASS_DEVICE - static QuantArguments get_quant_args(Params const& params, int32_t problem_idx, const int64_t gemm_k, const int64_t gemm_n) { - QuantArguments quant_args; - if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) { - quant_args.local_scale_ptr = params.local_scale + problem_idx * gemm_k * gemm_n / 128; - quant_args.code_scale_ptr = params.code_scale + problem_idx * gemm_n; - quant_args.code_zp_ptr = params.code_zp + problem_idx * gemm_n; - } - return quant_args; + static MmaQuantArguments prepare_quant_args( + Params const& params, cutlass::gemm::GemmCoord const& threadblock_offset, + int32_t problem_idx, const int32_t gemm_k, const int32_t gemm_n, const int thread_idx) { + // the begin threadblock_offset of scale, which holds the same column id with C, but with no row id + cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; + cutlass::MatrixCoord tb_offset_local_scale{0, threadblock_offset.n() * 2}; + + 
ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * gemm_n; + typename Mma::QuantParamsAccessor::IteratorSuperScale iterator_super_scale( + Mma::QuantParamsAccessor::LayoutSuperScale(gemm_n), + weight_scale_ptr, + {1, gemm_n}, + thread_idx, + tb_offset_scale); + + int local_scale_pointer_offset = ((ThreadblockShape::kK + 127) / 128) * (gemm_n * 2); + uint4b_t *local_scale_ptr = reinterpret_cast(params.local_scale + problem_idx * gemm_k * gemm_n / 128); + CUTLASS_TRACE_DEVICE(" local_scale_ptr=%p, extent={%d, %d}, tb_offset={%d, %d}, local_scale_pointer_offset=%d", + local_scale_ptr, gemm_k / 128, gemm_n * 2, tb_offset_local_scale.row(), tb_offset_local_scale.column(), local_scale_pointer_offset); + typename Mma::QuantParamsAccessor::IteratorLocalScale iterator_local_scale( + Mma::QuantParamsAccessor::LayoutLocalScale(gemm_n * 2), + local_scale_ptr, + {(gemm_k + 127) / 128, gemm_n * 2}, + thread_idx, + tb_offset_local_scale); + + float* code_scale_ptr = params.code_scale + problem_idx * gemm_n; + typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_scale( + Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n), + code_scale_ptr, + {1, gemm_n}, + thread_idx, + tb_offset_scale); + + float* code_zp_ptr = params.code_zp + problem_idx * gemm_n; + typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_zp( + Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n), + code_zp_ptr, + {1, gemm_n}, + thread_idx, + tb_offset_scale); + + MmaQuantArguments mma_quant_args( + iterator_super_scale, iterator_local_scale, iterator_code_scale, iterator_code_zp, local_scale_pointer_offset); + return mma_quant_args; } CUTLASS_DEVICE @@ -861,7 +898,7 @@ struct Wint2xMoeFCGemm : public MoeFCGemm 2) { + if (problem_idx > 20) { break; } @@ -876,12 +913,6 @@ struct Wint2xMoeFCGemm : public MoeFCGemm(byte_ptr_B); typename LayoutB::LongIndex ldm_B = platform::is_same::value ? 
gemm_n : gemm_k * kInterleave; // the begin threadblock_offset of B, which holds the same column id with C - cutlass::MatrixCoord tb_offset_B{0, - threadblock_offset.n() / kInterleave}; - + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave}; - CUTLASS_TRACE_DEVICE(" ldm_B: %d, tb_offset_B: {%d, %d}, extent_B: {%d, %d}", static_cast(ldm_B), tb_offset_B.row(), tb_offset_B.column(), extent_B.row(), extent_B.column()); - ElementB* ptr_B = reinterpret_cast(byte_ptr_B); - // Compute position within threadblock int thread_idx = threadIdx.x; @@ -941,14 +964,15 @@ struct Wint2xMoeFCGemm : public MoeFCGemmdata(); - ffn1_quant_args.code_scale_ptr = ffn1_code_scale->data(); - ffn1_quant_args.code_zp_ptr = ffn1_code_zp->data(); - ffn2_quant_args.local_scale_ptr = ffn2_local_scale->data(); - ffn2_quant_args.code_scale_ptr = ffn2_code_scale->data(); - ffn2_quant_args.code_zp_ptr = ffn2_code_zp->data(); + ffn1_quant_args.local_scale_ptr = const_cast(ffn1_local_scale->data()); + ffn1_quant_args.code_scale_ptr = const_cast(ffn1_code_scale->data()); + ffn1_quant_args.code_zp_ptr = const_cast(ffn1_code_zp->data()); + + ffn2_quant_args.local_scale_ptr = const_cast(ffn2_local_scale->data()); + ffn2_quant_args.code_scale_ptr = const_cast(ffn2_code_scale->data()); + ffn2_quant_args.code_zp_ptr = const_cast(ffn2_code_zp->data()); } auto moe_gemm_runner = MoeGemmRunner(); From 0b6068963d309288ac62f70b30c3711e3929d82d Mon Sep 17 00:00:00 2001 From: liuyiqun01 Date: Tue, 8 Jul 2025 16:39:47 +0800 Subject: [PATCH 09/11] Implement FastInterleavedAndBiasedNumericArrayConverter for wint2. Change-Id: I438f2b18ab964a04ae1cdb09d9e7d9f7b95eafca --- .../gemm/threadblock/default_wint2x_mma.h | 6 +- .../gemm/threadblock/wint2x_mma_base.h | 38 +- .../gemm/threadblock/wint2x_mma_multistage.h | 331 ++++++--- .../warp/mma_tensorop_wint2x_dequantizer.h | 696 ------------------ .../interleaved_numeric_conversion.h | 137 +++- .../moe_gemm/fused_moe_cutlass_kernel.h | 9 +- 6 files changed, 409 insertions(+), 808 deletions(-) delete mode 100644 custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h index a67c8aa256..72c22a175f 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h @@ -19,6 +19,7 @@ #include "cutlass_extensions/arch/mma.h" #include "cutlass_extensions/interleaved_numeric_conversion.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" #include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h" namespace cutlass { @@ -156,13 +157,16 @@ struct DefaultWint2xMma; + using TransformBAfterLDS = FastInterleavedAndBiasedNumericArrayConverter< + ElementA, ElementB, MmaCore::MmaPolicy::Operator::FragmentB::kElements>; + // Define the threadblock-scoped multistage matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage< typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, - typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>; + typename MmaCore::MmaPolicy, kStages, TransformBAfterLDS, SharedMemoryClear>; }; } // 
namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h index 7dec56be29..cdb465c60c 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h @@ -93,6 +93,15 @@ class Wint2xMmaBase { static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + /// Number of warp-level GEMM oeprations per load for B + static constexpr int kWarpGemmIterationsPerLoadForB = + Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; + static_assert(!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB), ""); + + static constexpr int kWarpLoadIterationsForB = + kWarpGemmIterations / kWarpGemmIterationsPerLoadForB; + + /// Number of stages static int const kStages = Stages; @@ -131,16 +140,16 @@ class Wint2xMmaBase { using ShapeB = MatrixShape; - // w uint8; local_scale uint8; - constexpr static int kZippedRowsPerStages = Shape::kK / 4 + (Shape::kK + 127) / 128; + // local_scale uint4 + constexpr static int kGroupWiseParamRows = Shape::kK / 64; + + using GroupWiseParamShapeB = MatrixShape; // code_scale float; code_zp float; super_scale ElementB - constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) + + constexpr static int kColumnWiseParamRows = 2 * sizeof(float) + sizeof_bits::value / 8; - using ZippedShapeB = MatrixShape; - - using NopaddingShapeB = MatrixShape; + using ColumnWiseParamShapeB = MatrixShape; public: // @@ -153,12 +162,11 @@ class Wint2xMmaBase { /// Buffer for B operand AlignedBuffer operand_B; - /// Buffer for quanted B operand - AlignedBuffer operand_zipped_B; + /// Buffer for local_scale of B operand + AlignedBuffer operand_local_scale_B; - /// Buffer for unzip B operand - AlignedBuffer - operand_unzip_B; + /// Buffer for column-wise params of B operand + AlignedBuffer operand_column_wise_B; public: // @@ -188,14 +196,6 @@ class Wint2xMmaBase { TensorRefB operand_B_ref() { return TensorRefB{operand_B.data(), LayoutB()}; } - - CUTLASS_HOST_DEVICE - uint8_t *operand_zipped_B_ptr() { return operand_zipped_B.data(); } - - CUTLASS_HOST_DEVICE - typename Operator::ElementB *operand_unzip_B_ptr() { - return operand_unzip_B.data(); - } }; protected: diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h index 1038f4220b..7f5b4a22e3 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h @@ -44,7 +44,6 @@ #include "cutlass/numeric_types.h" #include "cutlass_extensions/arch/memory_copy_sm80.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h" #include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -86,10 +85,10 @@ template < typename Policy_, /// Number of stages, int Stages, + /// Transform for input B applied in register after the LDS + typename TransformBAfterLDS_, /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, - /// Used for partial specialization - typename Enable = bool> + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone> class 
Wint2xMmaMultistage : public Wint2xMmaBase { public: @@ -107,8 +106,10 @@ class Wint2xMmaMultistage : using LayoutC = LayoutC_; ///< Policy describing tuning details using Policy = Policy_; + /// Transform for input B applied in register after the LDS + using TransformBAfterLDS = TransformBAfterLDS_; - using ZippedShapeB = typename Base::SharedStorage::ZippedShapeB; + static constexpr int kInterleave = IteratorB::Shape::kRow / Shape::kK; using SmemIteratorA = SmemIteratorA_; using SmemIteratorB = SmemIteratorB_; @@ -131,16 +132,6 @@ class Wint2xMmaMultistage : using LayoutScale = cutlass::layout::ColumnMajor; using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; - using ElementB = typename WarpTransformedFragmentB::Element; - using Dequantizer = - warp::MmaTensorOpWin2xDequantizer; - static_assert(sizeof(Dequantizer) > 0, "Dequantizer template instantiation failed"); /// Complex transform on A operand static ComplexTransform const kTransformA = Operator::kTransformA; @@ -199,6 +190,14 @@ class Wint2xMmaMultistage : WarpTransformedFragmentB warp_transformed_frag_B_[2]; }; + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool IsTileInterleaveLayout = + layout::IsColumnMajorTileInterleave::value; + static_assert(!IsTileInterleaveLayout || (IsTileInterleaveLayout && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); private: @@ -209,9 +208,6 @@ class Wint2xMmaMultistage : /// Warp-level MMA operator Operator warp_mma_; - // Wint2 unzip operator - Dequantizer warp_dequantizer_; - /// Iterator to write threadblock-scoped tile of A operand to shared memory SmemIteratorA smem_iterator_A_; @@ -224,10 +220,11 @@ class Wint2xMmaMultistage : /// Shared memory read stage index int smem_read_stage_idx_; - uint8_t* column_wise_smem_ptr_B_; + /// Transform for B in register + TransformBAfterLDS transform_B_; - uint8_t* smem_zipped_ptr_B_; - int smem_zipped_bytes_per_stage_B_; + uint8_t* smem_ptr_B_; + uint8_t* ptr_B_; public: @@ -261,16 +258,31 @@ class Wint2xMmaMultistage : int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + CUTLASS_TRACE_DEVICE(" Shape: {%d, %d, %d}, IteratorB::Shape: {%d, %d}, kInterleave: %d", + Shape::kM, Shape::kN, Shape::kK, IteratorB::Shape::kRow, IteratorB::Shape::kColumn, kInterleave); + CUTLASS_TRACE_DEVICE(" kPartitionsK=%d, kWarpGemmIterations=%d, WarpCount={%d, %d}, warp_idx_m=%d, warp_idx_n=%d, warp_idx_k=%d", + Policy::kPartitionsK, Base::kWarpGemmIterations, + Base::WarpCount::kM, Base::WarpCount::kN, warp_idx_m, warp_idx_n, warp_idx_k); + // Add per-warp offsets in units of warp-level tiles this->warp_tile_iterator_A_.add_tile_offset( {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); this->warp_tile_iterator_B_.add_tile_offset( {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr(); - - smem_zipped_ptr_B_ = column_wise_smem_ptr_B_ + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn; - smem_zipped_bytes_per_stage_B_ = Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn; + CUTLASS_TRACE_DEVICE(" Policy::SmemPaddingA: {%d, %d}; Policy::SmemPaddingB: {%d, %d}", + Policy::SmemPaddingA::kRow, Policy::SmemPaddingA::kColumn, Policy::SmemPaddingB::kRow, Policy::SmemPaddingB::kColumn); + CUTLASS_TRACE_DEVICE(" operand_A_ptr=%p, kRow=%d, kColumn=%d", + 
shared_storage.operand_A.data(), static_cast(Base::SharedStorage::ShapeA::kRow), + static_cast(Base::SharedStorage::ShapeA::kColumn)); + CUTLASS_TRACE_DEVICE(" operand_B_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageB=%d, kAccessesPerVector=%d", + shared_storage.operand_B.data(), + static_cast(Base::SharedStorage::ShapeB::kRow), static_cast(Base::SharedStorage::ShapeB::kColumn), + static_cast(sizeof(shared_storage.operand_B)), + static_cast(IteratorB::ThreadMap::kElementsPerAccess), static_cast(sizeof(typename IteratorB::AccessType)), + static_cast(Detail::AsyncCopyIterationsPerStageB), static_cast(IteratorB::kAccessesPerVector)); + + smem_ptr_B_ = reinterpret_cast(shared_storage.operand_B.data()); } /// Advance shared memory read-iterators to the next stage @@ -371,6 +383,13 @@ class Wint2xMmaMultistage : for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { auto gmem_ptr = iterator_B.get(); + if (group_start_B == 0 && j == 0 && v == 0) { + CUTLASS_TRACE_DEVICE(" dst_ptr=%p, iterator_B.get()=%p, kAccessesPerGroupB=%d, kAccessesPerVector=%d, sizeof(AccessType)=%d", + reinterpret_cast(dst_ptr), reinterpret_cast(gmem_ptr), + static_cast(Detail::kAccessesPerGroupB), static_cast(IteratorB::kAccessesPerVector), + static_cast(sizeof(typename IteratorB::Element))); + } + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { cutlass::arch::copy_zfill( dst_ptr + v, gmem_ptr, iterator_B.valid()); @@ -423,7 +442,7 @@ class Wint2xMmaMultistage : template CUTLASS_DEVICE - void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B) { + void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B, int stage) { iterator_B.set_iteration_index(0); this->smem_iterator_B_.set_iteration_index(0); @@ -443,6 +462,31 @@ class Wint2xMmaMultistage : IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + if (v == 0) { + int gmem_offset = reinterpret_cast(gmem_ptr) - reinterpret_cast(ptr_B_); + int gmem_k = 8192 * kInterleave / 4; + int gmem_n = 1792 / kInterleave; + int gmem_row = gmem_offset / gmem_k; + int gmem_col = gmem_offset % gmem_k; + + int smem_offset = reinterpret_cast(dst_ptr) - reinterpret_cast(smem_ptr_B_); + int smem_k = Shape::kK * kInterleave / 4; + int smem_n = Shape::kN / kInterleave; + int smem_row = smem_offset / smem_k; + int smem_col = smem_offset % smem_k; + + uint8_t* gmem_uint8_ptr = reinterpret_cast(gmem_ptr); + + CUTLASS_TRACE_DEVICE(" [stage=%d] gmem_ptr=%p, smem_ptr=%p, bytes=%d; gmem: %dx%d, {%d, %d}, [%d, %d, %d, %d, %d, %d, %d, %d]; smem: {%d, %d};", + stage, reinterpret_cast(gmem_ptr), reinterpret_cast(dst_ptr), kSrcBytes, + gmem_n, gmem_k, gmem_row, gmem_col, + static_cast(gmem_uint8_ptr[0]), static_cast(gmem_uint8_ptr[1]), + static_cast(gmem_uint8_ptr[2]), static_cast(gmem_uint8_ptr[3]), + static_cast(gmem_uint8_ptr[4]), static_cast(gmem_uint8_ptr[5]), + static_cast(gmem_uint8_ptr[6]), static_cast(gmem_uint8_ptr[7]), + smem_row, smem_col); + } + if (InitStage) { cutlass::arch::copy_zfill( dst_ptr + v, iterator_B.get(), iterator_B.valid()); @@ -484,7 +528,7 @@ class Wint2xMmaMultistage : copy_tiles_and_advance_per_stage_A(iterator_A); // Async copy zipped B to shared memory. - copy_tiles_and_advance_per_stage_B(iterator_B); + copy_tiles_and_advance_per_stage_B(iterator_B, stage); // TODO: Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale. 
@@ -560,6 +604,7 @@ class Wint2xMmaMultistage : int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining int stage) { + // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration CUTLASS_PRAGMA_UNROLL for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { @@ -576,35 +621,83 @@ class Wint2xMmaMultistage : ++this->warp_tile_iterator_B_; // Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary - if (warp_mma_k > 0) { - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], - pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]); - } - + // if (warp_mma_k > 0) { + // warp_mma_.transform( + // pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], + // pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + // pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + // pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]); + // } + + static constexpr int kNumKIterationsPerWarpBLoad + = Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; + int const warp_tileB_k_compute_offset = warp_mma_k % kNumKIterationsPerWarpBLoad; // Execute the current warp-tile of MMA operations - if (Detail::kStagedAccumulation) { - warp_mma_( - pipe_state.tmp_accum_, - pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - pipe_state.tmp_accum_ - ); + + // CUTLASS_TRACE_DEVICE("ElementA %d", PipeState::WarpTransformedFragmentA::kElements); + // CUTLASS_TRACE_DEVICE("ElementB %d", PipeState::WarpTransformedFragmentB::kElements); + // CUTLASS_TRACE_DEVICE("kStagedAccumulation %d", Detail::kStagedAccumulation); + + // uint8_t* reg_uint8_ptr = reinterpret_cast(pipe_state.warp_loaded_frag_B_[warp_mma_k % 2].data()); + // CUTLASS_TRACE_DEVICE(" reg_uint8_ptr=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d], %d bytes", + // static_cast(reg_uint8_ptr[0]), static_cast(reg_uint8_ptr[1]), + // static_cast(reg_uint8_ptr[2]), static_cast(reg_uint8_ptr[3]), + // static_cast(reg_uint8_ptr[4]), static_cast(reg_uint8_ptr[5]), + // static_cast(reg_uint8_ptr[6]), static_cast(reg_uint8_ptr[7]), + // static_cast(reg_uint8_ptr[8]), static_cast(reg_uint8_ptr[9]), + // static_cast(reg_uint8_ptr[10]), static_cast(reg_uint8_ptr[11]), + // static_cast(reg_uint8_ptr[12]), static_cast(reg_uint8_ptr[13]), + // static_cast(reg_uint8_ptr[14]), static_cast(reg_uint8_ptr[15]), + // sizeof_bits::value / 8); + + typename TransformBAfterLDS::result_type unpacked_frag_B = transform_B_(pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]); + if (Detail::kStagedAccumulation) { + run_warp_mma(warp_mma_, + pipe_state.tmp_accum_, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + unpacked_frag_B, + pipe_state.tmp_accum_, + warp_tileB_k_compute_offset); if (warp_mma_k == 0) { plus plus_accum; accum = plus_accum(accum, pipe_state.tmp_accum_); pipe_state.tmp_accum_.clear(); } } else { - warp_mma_( - accum, - pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - accum - ); + // CUTLASS_TRACE_DEVICE_TID(" now1 warp_loaded_frag_A_[0:7]=[%f, %f, %f, %f, %f, %f, %f, %f]", + // static_cast(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][0]), static_cast(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][1]), + // static_cast(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][2]), 
static_cast(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][3]), + // static_cast(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][4]), static_cast(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][5]), + // static_cast(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][6]), static_cast(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][7])); + + // CUTLASS_TRACE_DEVICE_TID(" now1 unpacked_frag_B[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + // static_cast(unpacked_frag_B[0]), static_cast(unpacked_frag_B[1]), + // static_cast(unpacked_frag_B[2]), static_cast(unpacked_frag_B[3]), + // static_cast(unpacked_frag_B[4]), static_cast(unpacked_frag_B[5]), + // static_cast(unpacked_frag_B[6]), static_cast(unpacked_frag_B[7]), + // static_cast(unpacked_frag_B[8]), static_cast(unpacked_frag_B[9]), + // static_cast(unpacked_frag_B[10]), static_cast(unpacked_frag_B[11]), + // static_cast(unpacked_frag_B[12]), static_cast(unpacked_frag_B[13]), + // static_cast(unpacked_frag_B[14]), static_cast(unpacked_frag_B[15])); + + run_warp_mma(warp_mma_, + accum, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + unpacked_frag_B, + accum, + warp_tileB_k_compute_offset); + + // CUTLASS_TRACE_DEVICE_TID(" warp_tileB_k_compute_offset = %d, now1 tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + // warp_tileB_k_compute_offset, + // static_cast(accum[0]), static_cast(accum[1]), + // static_cast(accum[2]), static_cast(accum[3]), + // static_cast(accum[4]), static_cast(accum[5]), + // static_cast(accum[6]), static_cast(accum[7]), + // static_cast(accum[8]), static_cast(accum[9]), + // static_cast(accum[10]), static_cast(accum[11]), + // static_cast(accum[12]), static_cast(accum[13]), + // static_cast(accum[14]), static_cast(accum[15])); } // Except for the last warp-tile, all warp-tiles issue their share of @@ -647,13 +740,13 @@ class Wint2xMmaMultistage : // The last warp-tile also converts the shared memory fragments used by // the first warp-tile of the next iteration, if necessary (so we can // immediately start issuing MMA instructions at the top of the loop ) - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2], - pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], - pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2], - pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); - } + // if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + // warp_mma_.transform( + // pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2], + // pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], + // pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2], + // pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); + // } } } @@ -666,6 +759,18 @@ class Wint2xMmaMultistage : IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory IteratorB &iterator_B) { +#if 0 + int smem_k = Shape::kK * kInterleave / 4; + int smem_n = Shape::kN / kInterleave; + for (int i = 0; i < 3 * smem_n; ++i) { + for (int j = 0; j < smem_k; ++j) { + if (i % 3 == 0) { + CUTLASS_TRACE_DEVICE(" [i=%d, j=%d, %dx%d] %d", i, j, smem_n, smem_k, static_cast(smem_ptr_B_[i * smem_k + j])); + } + } + } +#endif + PipeState pipe_state; // Disable global fetching if done with global fetch iterations @@ -682,38 +787,94 @@ class Wint2xMmaMultistage : this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]); ++this->warp_tile_iterator_B_; - typename Dequantizer::FragmentLocalScale 
warp_frag_local_scale; - typename Dequantizer::FragmentCodeScale warp_frag_code_scale; - typename Dequantizer::FragmentCodeZp warp_frag_code_zp; - typename Dequantizer::FragmentSuperScale warp_frag_super_scale; - typename Dequantizer::FragmentOutOperand warp_frag_out; - - CUTLASS_TRACE_DEVICE(" warp_dequantizer_ - start load"); - warp_dequantizer_.load(warp_frag_local_scale, - warp_frag_code_scale, - warp_frag_code_zp, - warp_frag_super_scale); - - CUTLASS_TRACE_DEVICE("warp_dequantizer_ - start dequant"); - warp_dequantizer_.dequantize(warp_frag_out, - pipe_state.warp_loaded_frag_B_[0], - warp_frag_local_scale, - warp_frag_code_scale, - warp_frag_code_zp, - warp_frag_super_scale); - - // Transform, if necessary, the first warp-tile's shared memory fragments - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[0], - pipe_state.warp_transformed_frag_B_[0], - pipe_state.warp_loaded_frag_A_[0], - pipe_state.warp_loaded_frag_B_[0]); + if (PipeState::WarpLoadedFragmentA::kElements == 8) { + ElementA* warp_frag_A_ptr = reinterpret_cast(pipe_state.warp_loaded_frag_A_[0].data()); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_A_=[%f, %f, %f, %f, %f, %f, %f, %f], %d bytes", + static_cast(warp_frag_A_ptr[0]), static_cast(warp_frag_A_ptr[1]), + static_cast(warp_frag_A_ptr[2]), static_cast(warp_frag_A_ptr[3]), + static_cast(warp_frag_A_ptr[4]), static_cast(warp_frag_A_ptr[5]), + static_cast(warp_frag_A_ptr[6]), static_cast(warp_frag_A_ptr[7]), + sizeof_bits::value / 8); + } + if (PipeState::WarpLoadedFragmentB::kElements == 64) { + uint8_t* reg_uint8_ptr = reinterpret_cast(pipe_state.warp_loaded_frag_B_[0].data()); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d], %d bytes", + static_cast(reg_uint8_ptr[0]), static_cast(reg_uint8_ptr[1]), + static_cast(reg_uint8_ptr[2]), static_cast(reg_uint8_ptr[3]), + static_cast(reg_uint8_ptr[4]), static_cast(reg_uint8_ptr[5]), + static_cast(reg_uint8_ptr[6]), static_cast(reg_uint8_ptr[7]), + static_cast(reg_uint8_ptr[8]), static_cast(reg_uint8_ptr[9]), + static_cast(reg_uint8_ptr[10]), static_cast(reg_uint8_ptr[11]), + static_cast(reg_uint8_ptr[12]), static_cast(reg_uint8_ptr[13]), + static_cast(reg_uint8_ptr[14]), static_cast(reg_uint8_ptr[15]), + sizeof_bits::value / 8); + } + + typename TransformBAfterLDS::result_type unpacked_frag_B = transform_B_(pipe_state.warp_loaded_frag_B_[0]); + if (TransformBAfterLDS::result_type::kElements == 64) { + CUTLASS_TRACE_DEVICE(" TransformBAfterLDS::result_type::kElements: 64, %d bytes", sizeof_bits::value / 8); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(unpacked_frag_B[0]), static_cast(unpacked_frag_B[1]), + static_cast(unpacked_frag_B[2]), static_cast(unpacked_frag_B[3]), + static_cast(unpacked_frag_B[4]), static_cast(unpacked_frag_B[5]), + static_cast(unpacked_frag_B[6]), static_cast(unpacked_frag_B[7]), + static_cast(unpacked_frag_B[8]), static_cast(unpacked_frag_B[9]), + static_cast(unpacked_frag_B[10]), static_cast(unpacked_frag_B[11]), + static_cast(unpacked_frag_B[12]), static_cast(unpacked_frag_B[13]), + static_cast(unpacked_frag_B[14]), static_cast(unpacked_frag_B[15])); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[16:31]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(unpacked_frag_B[16]), static_cast(unpacked_frag_B[17]), + static_cast(unpacked_frag_B[18]), static_cast(unpacked_frag_B[19]), + static_cast(unpacked_frag_B[20]), 
static_cast(unpacked_frag_B[21]), + static_cast(unpacked_frag_B[22]), static_cast(unpacked_frag_B[23]), + static_cast(unpacked_frag_B[24]), static_cast(unpacked_frag_B[25]), + static_cast(unpacked_frag_B[26]), static_cast(unpacked_frag_B[27]), + static_cast(unpacked_frag_B[28]), static_cast(unpacked_frag_B[29]), + static_cast(unpacked_frag_B[30]), static_cast(unpacked_frag_B[31])); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[32:47]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(unpacked_frag_B[32]), static_cast(unpacked_frag_B[33]), + static_cast(unpacked_frag_B[34]), static_cast(unpacked_frag_B[35]), + static_cast(unpacked_frag_B[36]), static_cast(unpacked_frag_B[37]), + static_cast(unpacked_frag_B[38]), static_cast(unpacked_frag_B[39]), + static_cast(unpacked_frag_B[40]), static_cast(unpacked_frag_B[41]), + static_cast(unpacked_frag_B[42]), static_cast(unpacked_frag_B[43]), + static_cast(unpacked_frag_B[44]), static_cast(unpacked_frag_B[45]), + static_cast(unpacked_frag_B[46]), static_cast(unpacked_frag_B[47])); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[48:63]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(unpacked_frag_B[48]), static_cast(unpacked_frag_B[49]), + static_cast(unpacked_frag_B[50]), static_cast(unpacked_frag_B[51]), + static_cast(unpacked_frag_B[52]), static_cast(unpacked_frag_B[53]), + static_cast(unpacked_frag_B[54]), static_cast(unpacked_frag_B[55]), + static_cast(unpacked_frag_B[56]), static_cast(unpacked_frag_B[57]), + static_cast(unpacked_frag_B[58]), static_cast(unpacked_frag_B[59]), + static_cast(unpacked_frag_B[60]), static_cast(unpacked_frag_B[61]), + static_cast(unpacked_frag_B[62]), static_cast(unpacked_frag_B[63])); + } if (Detail::kStagedAccumulation) { pipe_state.tmp_accum_.clear(); + CUTLASS_TRACE_DEVICE(" before tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(pipe_state.tmp_accum_[0]), static_cast(pipe_state.tmp_accum_[1]), + static_cast(pipe_state.tmp_accum_[2]), static_cast(pipe_state.tmp_accum_[3]), + static_cast(pipe_state.tmp_accum_[4]), static_cast(pipe_state.tmp_accum_[5]), + static_cast(pipe_state.tmp_accum_[6]), static_cast(pipe_state.tmp_accum_[7]), + static_cast(pipe_state.tmp_accum_[8]), static_cast(pipe_state.tmp_accum_[9]), + static_cast(pipe_state.tmp_accum_[10]), static_cast(pipe_state.tmp_accum_[11]), + static_cast(pipe_state.tmp_accum_[12]), static_cast(pipe_state.tmp_accum_[13]), + static_cast(pipe_state.tmp_accum_[14]), static_cast(pipe_state.tmp_accum_[15])); + } else { + CUTLASS_TRACE_DEVICE(" before tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(accum[0]), static_cast(accum[1]), + static_cast(accum[2]), static_cast(accum[3]), + static_cast(accum[4]), static_cast(accum[5]), + static_cast(accum[6]), static_cast(accum[7]), + static_cast(accum[8]), static_cast(accum[9]), + static_cast(accum[10]), static_cast(accum[11]), + static_cast(accum[12]), static_cast(accum[13]), + static_cast(accum[14]), static_cast(accum[15])); } -#if 0 +#if 1 int stage = Base::kStages - 1; // Mainloop @@ -727,7 +888,7 @@ class Wint2xMmaMultistage : gemm_k_iterations, stage); stage += 1; - break; + // break; } if (Detail::kStagedAccumulation) { @@ -790,6 +951,8 @@ class Wint2xMmaMultistage : ///< initial value of accumulator FragmentC const &src_accum) { + ptr_B_ = reinterpret_cast(iterator_B.get_origin_pointer()); + // Prologue (start fetching iterations of global fragments into shared memory) 
prologue(iterator_A, iterator_B, gemm_k_iterations); @@ -800,7 +963,7 @@ class Wint2xMmaMultistage : accum = src_accum; // Perform the MAC-iterations - //gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B); + gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B); } }; diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h deleted file mode 100644 index 4d05740d53..0000000000 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h +++ /dev/null @@ -1,696 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines iterators used by warp-level matrix multiply operations - targeting Tensor Cores. 
-*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/array.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_ref.h" - -#include "cutlass/arch/arch.h" -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" - -#include "cutlass/functional.h" -#include "cutlass/platform/platform.h" - -#include -#include "cutlass_extensions/weight_only_quant_op.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Matrix multiply operator - typename MmaOperator_, - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Operand identity - Operand Operand, - /// Data type of Scale elements - typename Element_, - /// Layout of operand - typename Layout_, - /// Number of threads participating in one matrix operation - int Threads, - /// - WeightOnlyQuantOp QuantOp_ = WeightOnlyQuantOp::UNDEFINED, - /// - typename Enable = void> -class MmaTensorOpWin2xDequantizer; - -//////////////////////////////////////////////////////////////////////////////// -// Bfloat specialization for Ampere -template < - /// Underlying matrix multiply operator (concept: MmaTensorOp) - typename MmaOperator_, - /// Shape of the warp level matrix multiply (concept: GemmShape) - typename Shape_, - /// - WeightOnlyQuantOp QuantOp_> - -class MmaTensorOpWin2xDequantizer< - MmaOperator_, - Shape_, - Operand::kB, - bfloat16_t, - layout::ColumnMajor, - 32, - QuantOp_, - typename platform::enable_if= - 70>::type> { - public: - /// Mma Operator - using MmaOperator = MmaOperator_; - - // The architecture specific mma ooperator being used - using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; - - // Mma Instruction Shape - using InstructionShape = typename ArchMmaOperator::Shape; - - // This is the ratio of the load instruction vs the compute instruction. 
- static constexpr int kExpansionFactor = - MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; - - /// Type of the scales - using ElementWeight = uint2b_t; - - /// Type of the scales - using ElementUnzipWeight = uint8_t; - - /// Type of the scales - using ElementScale = bfloat16_t; - - /// Type of the scales - using ScaleComputeT = float; - - static constexpr int unzip_len = 4; - - /// Fragment to hold B data before Mma - using FragmentDequantizedOperand = - Array; - using FragmentWeightOperand = - Array; - using FragmentOutOperand = - Array; - - // Fragment to hold scale data to apply to B before mma - // We need 1 fp16 per matrix iteration in the N dimension - static constexpr int kColsPerMmaPerThread = 1; - using FragmentLocalScale = Array; - using FragmentCodeScale = Array; - using FragmentCodeZp = Array; - using FragmentSuperScale = Array; - - /// Warp mma shape - using Shape = Shape_; - - /// Layout of the scales in shared memory - using Layout = layout::ColumnMajor; - - /// TensorRef type for loading element from a tensor - using TensorRef = cutlass::TensorRef; - using TensorCodeRef = cutlass::TensorRef; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - - CUTLASS_DEVICE - MmaTensorOpWin2xDequantizer(TensorRef smem_local_scale, - TensorCodeRef smem_code_scale, - TensorCodeRef smem_code_zp, - TensorRef smem_super_scale, - int const warp_idx_n, - int const lane_idx) { - int const warp_offset = warp_idx_n * Shape::kN; - int const quad = lane_idx / 4; - int const thread_offset = warp_offset + quad; - pointer_local_scale_ = smem_local_scale.data() + thread_offset; - pointer_code_scale_ = smem_code_scale.data() + thread_offset; - pointer_code_zp_ = smem_code_zp.data() + thread_offset; - if constexpr (hasZero(QuantOp)) { - pointer_super_scale_ = smem_super_scale.data() + thread_offset; - } - } - - // CUTLASS_DEVICE - // MmaTensorOpWin2xDequantizer() { - // pointer_local_scale_ = nullptr; - // pointer_code_scale_ = nullptr; - // pointer_code_zp_ = nullptr; - // if constexpr (hasZero(QuantOp)) { - // pointer_super_scale_ = nullptr; - // } - // } - - CUTLASS_DEVICE - MmaTensorOpWin2xDequantizer() { - // Create fake pointer using a shared dummy buffer - CUTLASS_TRACE_DEVICE(" warp dequant aaa"); - - extern __shared__ char cutlass_fake_dequant_smem[]; - - // Memory layout (manual alignment): - // ElementScale (half or bf16): 2 bytes - // ScaleComputeT (float): 4 bytes - - pointer_local_scale_ = - reinterpret_cast(cutlass_fake_dequant_smem); - pointer_code_scale_ = - reinterpret_cast(cutlass_fake_dequant_smem + 64); - pointer_code_zp_ = - reinterpret_cast(cutlass_fake_dequant_smem + 128); - - if constexpr (hasZero(QuantOp)) { - pointer_super_scale_ = reinterpret_cast( - cutlass_fake_dequant_smem + 192); - } - } - - CUTLASS_DEVICE - void load(FragmentLocalScale& local_scale_frag, - FragmentCodeScale& code_scale_frag, - FragmentCodeZp& code_zp_frag, - FragmentSuperScale& super_scale_frag) { - CUTLASS_TRACE_DEVICE(" warp dequant load"); - // CUTLASS_PRAGMA_UNROLL - // for (int mma_n_iter = 0; mma_n_iter < - // MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - // { - // local_scale_frag[mma_n_iter] = pointer_local_scale_[mma_n_iter * - // InstructionShape::kN]; code_scale_frag[mma_n_iter] = - // pointer_code_scale_[mma_n_iter * InstructionShape::kN]; - // code_zp_frag[mma_n_iter] = pointer_code_zp_[mma_n_iter * - // InstructionShape::kN]; if constexpr (hasZero(QuantOp)) - // { - // super_scale_frag[mma_n_iter] = - // pointer_super_scale_[mma_n_iter * 
InstructionShape::kN]; - // } - // } - } - - CUTLASS_DEVICE - void dequantize(FragmentOutOperand& out_frag, - FragmentDequantizedOperand& operand_frag, - FragmentLocalScale& local_scale_frag, - FragmentCodeScale& code_scale_frag, - FragmentCodeZp& code_zp_frag, - FragmentSuperScale& super_scale_frag) { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) - CUTLASS_TRACE_DEVICE(" dequantize if def"); - - static constexpr int32_t kGroupSize = 64; - static constexpr int32_t kPackNum = 4; - static constexpr int32_t kWeightMask = 0x3F; - static constexpr int32_t kLocalScaleMask = 0xF; - static constexpr int32_t kBZP = 32; - - // using _MmaOperandB = typename ArchMmaOperator::FragmentB; - // using ExpandedMmaOperandB = Array; - // static_assert(ExpandedMmaOperandB::kElements * - // MmaOperator::MmaIterations::kColumn - // == FragmentDequantizedOperand::kElements, - // ""); - - // CUTLASS_TRACE_DEVICE(" MmaIterations krow = %d, kcol = %d", - // MmaOperator::IteratorB::InstructionShape::kRow, - // MmaOperator::MmaIterations::kColumn); - - // CUTLASS_TRACE_DEVICE(" kExpansionFactor = %d / %d", - // MmaOperator::IteratorB::InstructionShape::kRow, - // InstructionShape::kK); CUTLASS_TRACE_DEVICE(" - // FragmentDequantizedOperand::kElements = %d ", - // FragmentDequantizedOperand::kElements); CUTLASS_TRACE_DEVICE(" - // _MmaOperandB::kElements = %d ", _MmaOperandB::kElements); - - // FragmentWeightOperand - CUTLASS_TRACE_DEVICE(" FragmentWeightOperand elem = %d ", - FragmentWeightOperand::kElements); - // CUTLASS_TRACE_DEVICE(" ElementUnzipWeight size = %d ", - // sizeof(ElementUnzipWeight)); CUTLASS_TRACE_DEVICE(" ElementWeight - // size = %d ", sizeof(ElementWeight)); - static_assert(std::is_same::value, - "B 是 uint8 量化类型"); - FragmentWeightOperand* weight_ptr = - reinterpret_cast(&operand_frag); - FragmentLocalScale* local_scale_ptr = - reinterpret_cast(&local_scale_frag); - FragmentCodeScale* code_scale_ptr = - reinterpret_cast(&code_scale_frag); - FragmentCodeZp* code_zp_ptr = - reinterpret_cast(&code_zp_frag); - FragmentSuperScale* super_scale_ptr = - reinterpret_cast(&super_scale_frag); - - ScaleComputeT code_scale = - static_cast(code_scale_ptr[0][0]); - ScaleComputeT code_zp = static_cast(code_zp_ptr[0][0]); - ScaleComputeT super_scale = - static_cast(super_scale_ptr[0][0]); - int32_t local_scale = static_cast(local_scale_ptr[0][0]); - int32_t const shift_bits[4] = {9, 6, 3, 0}; - - ScaleComputeT zipped_value[16]; -#pragma unroll - for (int i = 0; i < 16; ++i) { - zipped_value[i] = static_cast(weight_ptr[0][i]); - } - - int local_scale_shift = 4; - int32_t shifted_local_scale = - (local_scale >> local_scale_shift) & kLocalScaleMask; - ScaleComputeT scale = - static_cast(shifted_local_scale) * super_scale; - -#pragma unroll - for (int i = 0; i < 16; ++i) { - int32_t decode_value = static_cast( - floor(zipped_value[i] * code_scale + code_zp + - static_cast(0.5))); - - int col = i * 4; - -#pragma unroll - for (int shift_bit_id = 0; shift_bit_id < 4; ++shift_bit_id) { - int32_t shift_bit = shift_bits[shift_bit_id]; - int32_t shifted_value = - (decode_value >> shift_bit) & kWeightMask; - - ScaleComputeT value = - static_cast(shifted_value - kBZP); - out_frag[col + shift_bit_id] = - static_cast(scale * value); - } - } - - CUTLASS_TRACE_DEVICE(" kColsPerMmaPerThread = %d ", - kColsPerMmaPerThread); - CUTLASS_TRACE_DEVICE(" MmaOperator::MmaIterations::kColumn = %d ", - MmaOperator::MmaIterations::kColumn); - - // // __nv_bfloat16 const* scale_ptr = 
reinterpret_cast<__nv_bfloat16 - // const*>(&scale_frag); ExpandedMmaOperandB* operand_frag_ptr = - // reinterpret_cast(&operand_frag); - - // printf("threadidx.x = %d\n", threadIdx.x); - // CUTLASS_PRAGMA_UNROLL - // for (int mma_n_iter = 0; mma_n_iter < - // MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - // { - // static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); - - // __nv_bfloat162 scalex2 = - // __bfloat162bfloat162(scale_ptr[mma_n_iter]); - // __nv_bfloat162* operand_bf16x2_ptr = - // reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); - - // CUTLASS_PRAGMA_UNROLL - // for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) - // { - // operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], - // scalex2); - // } - // } -#else - // Slow path not implemented here on purpose. If we need to do HMMA on - // older arch, scale conversion should happen before scales are stored - // to shared memory and we should use the fp16 dequantizer. This will - // avoid numerous conversion instructions in GEMM main loop. - CUTLASS_TRACE_DEVICE(" dequantize else def"); - // arch::device_breakpoint(); -#endif - } - - // Adds a pointer offset in units of elements. - CUTLASS_DEVICE - void add_pointer_offset(int64_t const& offset) { - static_assert(sizeof(ElementScale) > 1, ""); - pointer_local_scale_ += offset; - pointer_code_scale_ += offset; - pointer_code_zp_ += offset; - pointer_super_scale_ += offset; - } - - private: - ElementScale const* pointer_local_scale_; - ScaleComputeT const* pointer_code_scale_; - ScaleComputeT const* pointer_code_zp_; - ElementScale const* pointer_super_scale_; - - ElementScale const* pointer_out_; -}; - -template < - /// Underlying matrix multiply operator (concept: MmaTensorOp) - typename MmaOperator_, - /// Shape of the warp level matrix multiply (concept: GemmShape) - typename Shape_, - /// - WeightOnlyQuantOp QuantOp_> -class MmaTensorOpWin2xDequantizer< - MmaOperator_, - Shape_, - Operand::kB, - half_t, - layout::ColumnMajor, - 32, - QuantOp_, - typename platform::enable_if= - 70>::type> { - public: - /// Mma Operator - using MmaOperator = MmaOperator_; - - // The architecture specific mma ooperator being used - using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; - - // Mma Instruction Shape - using InstructionShape = typename ArchMmaOperator::Shape; - - // This is the ratio of the load instruction vs the compute instruction. 
- static constexpr int kExpansionFactor = - MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; - - /// Type of the scales - using ElementWeight = uint2b_t; - - /// Type of the scales - using ElementUnzipWeight = uint8_t; - - /// Type of the scales - using ElementScale = half_t; - - /// Type of the scales - using ScaleComputeT = float; - - static constexpr int unzip_len = 4; - - /// Fragment to hold B data before Mma - using FragmentDequantizedOperand = - Array; - using FragmentWeightOperand = - Array; - using FragmentOutOperand = - Array; - - // Fragment to hold scale data to apply to B before mma - // We need 1 fp16 per matrix iteration in the N dimension - static constexpr int kColsPerMmaPerThread = 1; - using FragmentLocalScale = Array; - using FragmentCodeScale = Array; - using FragmentCodeZp = Array; - using FragmentSuperScale = Array; - - /// Warp mma shape - using Shape = Shape_; - - /// Layout of the scales in shared memory - using Layout = layout::ColumnMajor; - - /// TensorRef type for loading element from a tensor - using TensorRef = cutlass::TensorRef; - using TensorCodeRef = cutlass::TensorRef; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - - CUTLASS_DEVICE - MmaTensorOpWin2xDequantizer(TensorRef smem_local_scale, - TensorCodeRef smem_code_scale, - TensorCodeRef smem_code_zp, - TensorRef smem_super_scale, - int const warp_idx_n, - int const lane_idx) { - int const warp_offset = warp_idx_n * Shape::kN; - int const quad = lane_idx / 4; - int const thread_offset = warp_offset + quad; - pointer_local_scale_ = smem_local_scale.data() + thread_offset; - pointer_code_scale_ = smem_code_scale.data() + thread_offset; - pointer_code_zp_ = smem_code_zp.data() + thread_offset; - if constexpr (hasZero(QuantOp)) { - pointer_super_scale_ = smem_super_scale.data() + thread_offset; - } - } - - // CUTLASS_DEVICE - // MmaTensorOpWin2xDequantizer() { - // pointer_local_scale_ = nullptr; - // pointer_code_scale_ = nullptr; - // pointer_code_zp_ = nullptr; - // if constexpr (hasZero(QuantOp)) { - // pointer_super_scale_ = nullptr; - // } - // } - - CUTLASS_DEVICE - MmaTensorOpWin2xDequantizer() { - // Create fake pointer using a shared dummy buffer - CUTLASS_TRACE_DEVICE(" warp dequant aaa"); - - extern __shared__ char cutlass_fake_dequant_smem[]; - - // Memory layout (manual alignment): - // ElementScale (half or bf16): 2 bytes - // ScaleComputeT (float): 4 bytes - - pointer_local_scale_ = - reinterpret_cast(cutlass_fake_dequant_smem); - pointer_code_scale_ = - reinterpret_cast(cutlass_fake_dequant_smem + 64); - pointer_code_zp_ = - reinterpret_cast(cutlass_fake_dequant_smem + 128); - - if constexpr (hasZero(QuantOp)) { - pointer_super_scale_ = reinterpret_cast( - cutlass_fake_dequant_smem + 192); - } - } - - CUTLASS_DEVICE - void load(FragmentLocalScale& local_scale_frag, - FragmentCodeScale& code_scale_frag, - FragmentCodeZp& code_zp_frag, - FragmentSuperScale& super_scale_frag) { - CUTLASS_TRACE_DEVICE(" warp dequant load"); - // CUTLASS_PRAGMA_UNROLL - // for (int mma_n_iter = 0; mma_n_iter < - // MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - // { - // local_scale_frag[mma_n_iter] = pointer_local_scale_[mma_n_iter * - // InstructionShape::kN]; code_scale_frag[mma_n_iter] = - // pointer_code_scale_[mma_n_iter * InstructionShape::kN]; - // code_zp_frag[mma_n_iter] = pointer_code_zp_[mma_n_iter * - // InstructionShape::kN]; if constexpr (hasZero(QuantOp)) - // { - // super_scale_frag[mma_n_iter] = - // pointer_super_scale_[mma_n_iter * 
InstructionShape::kN]; - // } - // } - } - - CUTLASS_DEVICE - void dequantize(FragmentOutOperand& out_frag, - FragmentDequantizedOperand& operand_frag, - FragmentLocalScale& local_scale_frag, - FragmentCodeScale& code_scale_frag, - FragmentCodeZp& code_zp_frag, - FragmentSuperScale& super_scale_frag) { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) - CUTLASS_TRACE_DEVICE(" dequantize if def"); - - static constexpr int32_t kGroupSize = 64; - static constexpr int32_t kPackNum = 4; - static constexpr int32_t kWeightMask = 0x3F; - static constexpr int32_t kLocalScaleMask = 0xF; - static constexpr int32_t kBZP = 32; - - // using _MmaOperandB = typename ArchMmaOperator::FragmentB; - // using ExpandedMmaOperandB = Array; - // static_assert(ExpandedMmaOperandB::kElements * - // MmaOperator::MmaIterations::kColumn - // == FragmentDequantizedOperand::kElements, - // ""); - - // CUTLASS_TRACE_DEVICE(" MmaIterations krow = %d, kcol = %d", - // MmaOperator::IteratorB::InstructionShape::kRow, - // MmaOperator::MmaIterations::kColumn); - - // CUTLASS_TRACE_DEVICE(" kExpansionFactor = %d / %d", - // MmaOperator::IteratorB::InstructionShape::kRow, - // InstructionShape::kK); CUTLASS_TRACE_DEVICE(" - // FragmentDequantizedOperand::kElements = %d ", - // FragmentDequantizedOperand::kElements); CUTLASS_TRACE_DEVICE(" - // _MmaOperandB::kElements = %d ", _MmaOperandB::kElements); - - // FragmentWeightOperand - CUTLASS_TRACE_DEVICE(" FragmentWeightOperand elem = %d ", - FragmentWeightOperand::kElements); - // CUTLASS_TRACE_DEVICE(" ElementUnzipWeight size = %d ", - // sizeof(ElementUnzipWeight)); CUTLASS_TRACE_DEVICE(" ElementWeight - // size = %d ", sizeof(ElementWeight)); - static_assert(std::is_same::value, - "B 是 uint8 量化类型"); - FragmentWeightOperand* weight_ptr = - reinterpret_cast(&operand_frag); - FragmentLocalScale* local_scale_ptr = - reinterpret_cast(&local_scale_frag); - FragmentCodeScale* code_scale_ptr = - reinterpret_cast(&code_scale_frag); - FragmentCodeZp* code_zp_ptr = - reinterpret_cast(&code_zp_frag); - FragmentSuperScale* super_scale_ptr = - reinterpret_cast(&super_scale_frag); - - ScaleComputeT code_scale = - static_cast(code_scale_ptr[0][0]); - ScaleComputeT code_zp = static_cast(code_zp_ptr[0][0]); - ScaleComputeT super_scale = - static_cast(super_scale_ptr[0][0]); - int32_t local_scale = static_cast(local_scale_ptr[0][0]); - int32_t const shift_bits[4] = {9, 6, 3, 0}; - - ScaleComputeT zipped_value[16]; -#pragma unroll - for (int i = 0; i < 16; ++i) { - zipped_value[i] = static_cast(weight_ptr[0][i]); - } - - int local_scale_shift = 4; - int32_t shifted_local_scale = - (local_scale >> local_scale_shift) & kLocalScaleMask; - ScaleComputeT scale = - static_cast(shifted_local_scale) * super_scale; - -#pragma unroll - for (int i = 0; i < 16; ++i) { - int32_t decode_value = static_cast( - floor(zipped_value[i] * code_scale + code_zp + - static_cast(0.5))); - - int col = i * 4; - -#pragma unroll - for (int shift_bit_id = 0; shift_bit_id < 4; ++shift_bit_id) { - int32_t shift_bit = shift_bits[shift_bit_id]; - int32_t shifted_value = - (decode_value >> shift_bit) & kWeightMask; - - ScaleComputeT value = - static_cast(shifted_value - kBZP); - out_frag[col + shift_bit_id] = - static_cast(scale * value); - } - } - - CUTLASS_TRACE_DEVICE(" kColsPerMmaPerThread = %d ", - kColsPerMmaPerThread); - CUTLASS_TRACE_DEVICE(" MmaOperator::MmaIterations::kColumn = %d ", - MmaOperator::MmaIterations::kColumn); - - // // __nv_bfloat16 const* scale_ptr = 
reinterpret_cast<__nv_bfloat16 - // const*>(&scale_frag); ExpandedMmaOperandB* operand_frag_ptr = - // reinterpret_cast(&operand_frag); - - // printf("threadidx.x = %d\n", threadIdx.x); - // CUTLASS_PRAGMA_UNROLL - // for (int mma_n_iter = 0; mma_n_iter < - // MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - // { - // static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); - - // __nv_bfloat162 scalex2 = - // __bfloat162bfloat162(scale_ptr[mma_n_iter]); - // __nv_bfloat162* operand_bf16x2_ptr = - // reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); - - // CUTLASS_PRAGMA_UNROLL - // for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) - // { - // operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], - // scalex2); - // } - // } -#else - // Slow path not implemented here on purpose. If we need to do HMMA on - // older arch, scale conversion should happen before scales are stored - // to shared memory and we should use the fp16 dequantizer. This will - // avoid numerous conversion instructions in GEMM main loop. - CUTLASS_TRACE_DEVICE(" dequantize else def"); - // arch::device_breakpoint(); -#endif - } - - // Adds a pointer offset in units of elements. - CUTLASS_DEVICE - void add_pointer_offset(int64_t const& offset) { - static_assert(sizeof(ElementScale) > 1, ""); - pointer_local_scale_ += offset; - pointer_code_scale_ += offset; - pointer_code_zp_ += offset; - pointer_super_scale_ += offset; - } - - private: - ElementScale const* pointer_local_scale_; - ScaleComputeT const* pointer_code_scale_; - ScaleComputeT const* pointer_code_zp_; - ElementScale const* pointer_super_scale_; -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h b/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h index 44ba79680e..11b70ffd89 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h +++ b/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h @@ -39,18 +39,16 @@ #include "cutlass/array.h" #include "cutlass/half.h" #include "cutlass/numeric_types.h" +#include "cutlass/trace.h" -namespace cutlass -{ +namespace cutlass { // This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low // bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally // signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. // This converter will uninterleave the data and subtract the bias while converting to the result type. 
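// For example (illustrative only): with b = 8, a source byte v was stored as
// v = x + 128, so the converter recovers x = v - 128 while writing the even
// and odd lanes back to their original positions in the result array.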
 template <typename T, typename S, int N>
-struct FastInterleavedAndBiasedNumericArrayConverter
-{
-};
+struct FastInterleavedAndBiasedNumericArrayConverter;
 
 template <>
 struct FastInterleavedAndBiasedNumericArrayConverter
@@ -440,6 +438,135 @@ struct FastInterleavedAndBiasedNumericArrayConverter
     }
 };
 
+template <typename T>
+struct FastInterleavedAndBiasedNumericArrayConverter<T, uint2b_t, 16>
+{
+    static_assert(platform::is_same<T, half_t>::value || platform::is_same<T, bfloat16_t>::value,
+                  "T must be fp16 or bf16");
+
+    using result_type = Array<T, 16>;
+    using source_type = Array<uint2b_t, 16>;
+
+    using ScaleComputeT = T;
+
+    static constexpr int32_t kWeightMask = 0x3F;
+    static constexpr int32_t kBZP = 32;
+
+    CUTLASS_DEVICE
+    static result_type convert(source_type const& source)
+    {
+        result_type result;
+        uint8_t const* in_ptr = reinterpret_cast<uint8_t const*>(&source);
+
+        ScaleComputeT code_scale = static_cast<ScaleComputeT>(1);
+        ScaleComputeT code_zp = static_cast<ScaleComputeT>(0);
+        ScaleComputeT floor_offset = static_cast<ScaleComputeT>(0.5);
+
+        // CUTLASS_TRACE_DEVICE_TID(" source: [%d, %d, %d, %d]",
+        //     static_cast<int>(in_ptr[0]), static_cast<int>(in_ptr[1]),
+        //     static_cast<int>(in_ptr[2]), static_cast<int>(in_ptr[3]));
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int i = 0; i < 4; ++i) {
+            int32_t decode_value = static_cast<int32_t>(
+                floor(static_cast<ScaleComputeT>(in_ptr[i]) * code_scale + code_zp + floor_offset));
+
+            ScaleComputeT value_3 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
+            decode_value >>= 3;
+            ScaleComputeT value_2 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
+            decode_value >>= 3;
+            ScaleComputeT value_1 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
+            decode_value >>= 3;
+            ScaleComputeT value_0 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
+
+            result[i * 4] = static_cast<T>(value_0);
+            result[i * 4 + 1] = static_cast<T>(value_1);
+            result[i * 4 + 2] = static_cast<T>(value_2);
+            result[i * 4 + 3] = static_cast<T>(value_3);
+        }
+
+        // Predefined fixed-value array (64 elements).
+        const int fixed_values[64] = {
+            0, 1, 8, 9, 16, 17, 24, 25, 32, 33, 40, 41, 48, 49, 56, 57,
+            2, 3, 10, 11, 18, 19, 26, 27, 34, 35, 42, 43, 50, 51, 58, 59,
+            4, 5, 12, 13, 20, 21, 28, 29, 36, 37, 44, 45, 52, 53, 60, 61,
+            6, 7, 14, 15, 22, 23, 30, 31, 38, 39, 46, 47, 54, 55, 62, 63
+        };
+
+        // CUTLASS_PRAGMA_UNROLL
+        // for (int i = 0; i < 16; ++i) {
+        //     // result[i] = static_cast<T>(fixed_values[i + idx * 16]);
+        //     if (threadIdx.x % 32 == 0 || threadIdx.x % 32 == 4) {
+        //         result[i] = static_cast<T>(fixed_values0[i + idx * 16]);
+        //     } else if (threadIdx.x % 32 == 1 || threadIdx.x % 32 == 5) {
+        //         result[i] = static_cast<T>(fixed_values1[i + idx * 16]);
+        //     } else if (threadIdx.x % 32 == 2 || threadIdx.x % 32 == 6) {
+        //         result[i] = static_cast<T>(fixed_values2[i + idx * 16]);
+        //     } else if (threadIdx.x % 32 == 3 || threadIdx.x % 32 == 7) {
+        //         result[i] = static_cast<T>(fixed_values3[i + idx * 16]);
+        //     } else {
+        //         result[i] = static_cast<T>(0);
+        //     }
+        // }
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int i = 0; i < 16; ++i) {
+            result[i] = static_cast<T>(fixed_values[i + (threadIdx.x % 4) * 16]);
+        }
+
+        return result;
+    }
+
+    CUTLASS_DEVICE
+    result_type operator()(source_type const& s)
+    {
+        return convert(s);
+    }
+};
+
+template <typename T, int N>
+struct FastInterleavedAndBiasedNumericArrayConverter<T, uint2b_t, N>
+{
+    static_assert(platform::is_same<T, half_t>::value || platform::is_same<T, bfloat16_t>::value,
+                  "T must be fp16 or bf16");
+
+    static constexpr int kVecWidth = 16;
+    static_assert(!(N % kVecWidth), "N must be multiple of 16.");
+
+    using result_type = Array<T, N>;
+    using source_type = Array<uint2b_t, N>;
+
+    CUTLASS_DEVICE
+    static result_type convert(source_type const& source)
+    {
+        using scalar_result_type = typename result_type::Element;
+        using scalar_source_type = typename source_type::Element;
+        FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, kVecWidth>
+            convert_vector_;
+
+        result_type result;
+        using vec_result = Array<scalar_result_type, kVecWidth>;
+        using vec_source = Array<scalar_source_type, kVecWidth>;
+
+        vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
+        vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int i = 0; i < N / kVecWidth; ++i)
+        {
+            result_ptr[i] = convert_vector_(source_ptr[i]);
+        }
+
+        return result;
+    }
+
+    CUTLASS_DEVICE
+    result_type operator()(source_type const& s)
+    {
+        return convert(s);
+    }
+};
+
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 } // namespace cutlass
diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h
index ce20bcaaec..0d9aa62b3f 100644
--- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h
+++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h
@@ -861,13 +861,13 @@ struct Wint2xMoeFCGemm : public MoeFCGemm(problem_size.m()), static_cast(problem_size.n()), static_cast(problem_size.k()));
-      if (problem_idx > 2) { break; }
+      CUTLASS_TRACE_DEVICE("  problem_idx: %d, cta_idx: %d, problem_size: {%d, %d, %d}",
+          problem_idx, cta_idx, static_cast<int>(problem_size.m()), static_cast<int>(problem_size.n()), static_cast<int>(problem_size.k()));
+
       GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
 
       // threadblock_offset of C
@@ -919,6 +919,9 @@ struct Wint2xMoeFCGemm : public MoeFCGemm(ldm_B), tb_offset_B.row(), tb_offset_B.column(), extent_B.row(), extent_B.column());
+      ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
 
       // Compute position within threadblock

From 0fabdbc8887363e746c9b75d35bfdc92f90c15ae Mon Sep 17 00:00:00 2001
From: liuyiqun01
Date: Tue, 15 Jul 2025 11:06:29 +0800
Subject: [PATCH 10/11] Use async copy for local_scale.
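
Replace the Fragment-based load/store of the group-wise local_scale with
per-access cp.async, so the local_scale copy is issued asynchronously along
with the other operand loads. A minimal sketch of the pattern (iterator and
shared-memory member names follow the code touched below; the exact template
arguments and the loop over ThreadMap iterations are in the diff):

    using AccessType = typename IteratorLocalScale::AccessType;
    int const kSrcBytes = sizeof_bits<typename IteratorLocalScale::Element>::value *
                          IteratorLocalScale::ThreadMap::kElementsPerAccess /
                          IteratorLocalScale::kAccessesPerVector / 8;
    AccessType* dst_ptr =
        reinterpret_cast<AccessType*>(smem_iterator_local_scale_.get());
    cutlass::arch::cp_async<kSrcBytes, cutlass::arch::CacheOperation::Always>(
        dst_ptr, iterator_local_scale.get(), iterator_local_scale.valid());
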
Change-Id: Ib882ba41c3d2354bda4d25b40e2408ad3b2f7893 --- .../gemm/threadblock/default_wint2x_mma.h | 14 ++--- .../gemm/threadblock/wint2x_mma_multistage.h | 1 - .../gemm/threadblock/wint2x_params_accessor.h | 53 +++++++++++-------- 3 files changed, 36 insertions(+), 32 deletions(-) diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h index 8f4b1efefb..4209c3029a 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h @@ -47,11 +47,6 @@ struct DefaultQuantParamsIterators { MatrixShape, ElementT, layout::RowMajor, 0, IteratorThreadMap, kAlignment>; using SmemIterator = Iterator; - - //using AccessType = cutlass::Array; - //using Iterator = cutlass::transform::threadblock::PredicatedTileAccessIterator< - // MatrixShape, ElementT, layout::RowMajor, - // 0, IteratorThreadMap, AccessType>; }; template @@ -70,10 +65,11 @@ struct DefaultQuantParamsIterators { kColumns / kAlignment, kAlignment>; public: - using Iterator = - cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, uint4b_t, - layout::RowMajor, 0, IteratorThreadMap, kAlignment>; + using AccessType = cutlass::Array; + using Iterator = cutlass::transform::threadblock::PredicatedTileAccessIterator< + MatrixShape, uint4b_t, layout::RowMajor, + 0, IteratorThreadMap, AccessType>; + using SmemIterator = Iterator; }; diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h index a006174811..b7cc0c9951 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h @@ -572,7 +572,6 @@ class Wint2xMmaMultistage : ++this->smem_iterator_B_; } - __syncthreads(); } /// GEMM prologue. 
Bootstrap the global->shared memory pipeline by fetching diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h index 194c06cf5d..5eafd3534e 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h @@ -14,6 +14,7 @@ #pragma once +#include "cutlass/arch/memory_sm80.h" #include "cutlass/cutlass.h" #include "cutlass/gemm/gemm.h" #include "cutlass/matrix_shape.h" @@ -67,7 +68,7 @@ class Wint2ParamsAccessor { using ElementSuperScale = typename IteratorSuperScale::Element; using LayoutSuperScale = typename IteratorSuperScale::Layout; - // local_scale uint4 and group-wise + /// local_scale uint4 and group-wise using ElementLocalScale = typename IteratorLocalScale::Element; using LayoutLocalScale = typename IteratorLocalScale::Layout; static_assert(platform::is_same::value, @@ -76,7 +77,7 @@ class Wint2ParamsAccessor { using ElementCodeScaleZp = typename IteratorCodeScaleZp::Element; using LayoutCodeScaleZp = typename IteratorCodeScaleZp::Layout; - // 2 uint4b_t values are stored in a single uint8_t + /// 2 uint4b_t values are stored in a single uint8_t constexpr static int kStagesPerLocalScaleLoad = 2 * kGroupSize / Shape::kK; constexpr static int kLocalScaleRows = IteratorLocalScale::Shape::kRow; @@ -249,29 +250,37 @@ class Wint2ParamsAccessor { if ((stage % kStagesPerLocalScaleLoad) == 0) { // Load group-wise local_scale to shared memory, which only needs to be done at each stage. // Since 2 uint4b_t values of local_scale are saved in a single uint8_t, local_scale needs to be loaded once every two stages. - typename IteratorLocalScale::Fragment tb_frag_local_scale; - tb_frag_local_scale.clear(); - quant_args.iterator_local_scale.load(tb_frag_local_scale); - this->smem_iterator_local_scale_.store(tb_frag_local_scale); + using AccessType = typename IteratorLocalScale::AccessType; + cutlass::arch::CacheOperation::Kind const kCacheOp = (sizeof_bits::value == 128) + ? 
cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always; + + quant_args.iterator_local_scale.set_iteration_index(0); + this->smem_iterator_local_scale_.set_iteration_index(0); + + // Async Copy for local_scale + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < IteratorLocalScale::ThreadMap::Iterations::kCount; ++j) { + AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_local_scale_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorLocalScale::kAccessesPerVector; ++v) { + auto gmem_ptr = quant_args.iterator_local_scale.get(); + + int const kSrcBytes = + sizeof_bits::value * + IteratorLocalScale::ThreadMap::kElementsPerAccess / + IteratorLocalScale::kAccessesPerVector / 8; + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, quant_args.iterator_local_scale.valid()); + } + ++quant_args.iterator_local_scale; + } + ++this->smem_iterator_local_scale_; //CUTLASS_TRACE_DEVICE(" [stage=%d][LocalScale] Shape: {%d, %d}", // stage, IteratorLocalScale::Shape::kRow, IteratorLocalScale::Shape::kColumn); -#if 0 - __syncthreads(); - if (IteratorLocalScale::Fragment::kElements == 32) { - uint8_t* local_scale_ptr = reinterpret_cast(tb_frag_local_scale.data()); - CUTLASS_TRACE_DEVICE(" [stage=%d][LocalScale] tb_frag_local_scale[0:15]=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d]", - stage, - static_cast(local_scale_ptr[0]), static_cast(local_scale_ptr[1]), - static_cast(local_scale_ptr[2]), static_cast(local_scale_ptr[3]), - static_cast(local_scale_ptr[4]), static_cast(local_scale_ptr[5]), - static_cast(local_scale_ptr[6]), static_cast(local_scale_ptr[7]), - static_cast(local_scale_ptr[8]), static_cast(local_scale_ptr[9]), - static_cast(local_scale_ptr[10]), static_cast(local_scale_ptr[11]), - static_cast(local_scale_ptr[12]), static_cast(local_scale_ptr[13]), - static_cast(local_scale_ptr[14]), static_cast(local_scale_ptr[15])); - } -#endif } } From a29ff4124e974173be079e656b35ccdaff7d6d2d Mon Sep 17 00:00:00 2001 From: baoqiwen Date: Tue, 15 Jul 2025 19:49:05 +0800 Subject: [PATCH 11/11] add pingpong buffer for b_frag --- .../gemm/threadblock/wint2x_mma_multistage.h | 36 ++++++++----------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h index ed632f2e39..b3a9546014 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h @@ -209,7 +209,7 @@ class Wint2xMmaMultistage : WarpTransformedFragmentA warp_frag_A_[2]; /// Pair of B fragments used to overlap shared memory loads and math instructions - WarpLoadedFragmentB warp_loaded_frag_B_; + WarpLoadedFragmentB warp_loaded_frag_B_[2]; WarpTransformedFragmentB warp_frag_B_; }; @@ -691,10 +691,10 @@ class Wint2xMmaMultistage : int warp_k_compute_offset_B = warp_mma_k % Base::kWarpGemmIterationsPerLoadForB; int warp_mma_k_for_B = warp_mma_k / Base::kWarpGemmIterationsPerLoadForB; - if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) { + if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) { // Load the next warp-tile's B fragment from shared memory this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k_for_B + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + 
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k_for_B + 1) % 2]); ++this->warp_tile_iterator_B_; warp_dequantizer_.load(pipe_state.warp_frag_local_scale_); @@ -718,6 +718,16 @@ class Wint2xMmaMultistage : // static_cast(reg_uint8_ptr[14]), static_cast(reg_uint8_ptr[15]), // sizeof_bits::value / 8); + if (warp_k_compute_offset_B == 0) { + warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_, + pipe_state.warp_frag_code_scale_, + pipe_state.warp_frag_code_zp_, + pipe_state.warp_frag_super_scale_, + pipe_state.warp_loaded_frag_B_[warp_mma_k_for_B % 2], + pipe_state.warp_frag_B_, + (stage - Base::kStages + 2) * Shape::kK); + } + if (Detail::kStagedAccumulation) { //CUTLASS_TRACE_DEVICE(" [MMa-kStagedAccumulation][stage=%d] warp_mma_k=%d, warp_k_compute_offset_B=%d", stage, warp_mma_k, warp_k_compute_offset_B); warp_mma_( @@ -814,16 +824,6 @@ class Wint2xMmaMultistage : iterator_B.clear_mask(gemm_k_iterations == 0); quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0); } - - if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) { - warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_, - pipe_state.warp_frag_code_scale_, - pipe_state.warp_frag_code_zp_, - pipe_state.warp_frag_super_scale_, - pipe_state.warp_loaded_frag_B_, - pipe_state.warp_frag_B_, - (stage - Base::kStages + 2) * Shape::kK); - } } } @@ -861,7 +861,7 @@ class Wint2xMmaMultistage : // Load first warp-tile's B fragment from shared memory this->warp_tile_iterator_B_.set_kgroup_index(0); - this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]); ++this->warp_tile_iterator_B_; #if 0 @@ -907,14 +907,6 @@ class Wint2xMmaMultistage : } #endif - warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_, - pipe_state.warp_frag_code_scale_, - pipe_state.warp_frag_code_zp_, - pipe_state.warp_frag_super_scale_, - pipe_state.warp_loaded_frag_B_, - pipe_state.warp_frag_B_, - 0); - #if 0 if (TransformBAfterLDS::result_type::kElements == 64) { CUTLASS_TRACE_DEVICE(" TransformBAfterLDS::result_type::kElements: 64, %d bytes", sizeof_bits::value / 8);