diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h index 40f128b7a0..8f61c6d9c4 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -133,10 +133,18 @@ template template struct LayoutDetailsB= 75>::type> { - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::RowMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; // 64 + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; // 8 + +public: + // using Layout = layout::ColumnMajor; + // static constexpr int ElementsPerAccess = 16; // at least 4-bytes + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; // 64 + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; template diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h index bc395d04db..b50d66380e 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h @@ -18,14 +18,12 @@ #include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" #include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" +#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h" #include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h" -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { //////////////////////////////////////////////////////////////////////////////// @@ -378,38 +376,23 @@ template < struct DefaultMma { - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) ? 
cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; +private: + using Mma = DefaultWint2xMma; +public: // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; + using MmaCore = typename Mma::MmaCore; // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, - AccessTypeA>; + using IteratorA = typename Mma::IteratorA; // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, - AccessTypeB>; + using IteratorB = typename Mma::IteratorB; // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; + using ThreadblockMma = typename Mma::ThreadblockMma; }; template < @@ -441,38 +424,23 @@ struct DefaultMma { - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; +private: + using Mma = DefaultWint2xMma; +public: // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; + using MmaCore = typename Mma::MmaCore; // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, - AccessTypeA>; + using IteratorA = typename Mma::IteratorA; // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, - AccessTypeB>; + using IteratorB = typename Mma::IteratorB; // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; + using ThreadblockMma = typename Mma::ThreadblockMma; }; } // namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h index 5d2c311704..300261c3f0 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -19,7 +19,7 @@ #include "cutlass/gemm/threadblock/default_mma.h" #include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" #include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" -#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h" +#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h" namespace cutlass { namespace gemm { @@ -379,38 +379,23 @@ template < struct DefaultMma { - static 
cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; +private: + using Mma = DefaultWint2xMma; +public: // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; + using MmaCore = typename Mma::MmaCore; // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, - AccessTypeA>; + using IteratorA = typename Mma::IteratorA; // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, - AccessTypeB>; + using IteratorB = typename Mma::IteratorB; // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; + using ThreadblockMma = typename Mma::ThreadblockMma; }; template < @@ -442,38 +427,23 @@ struct DefaultMma { - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; +private: + using Mma = DefaultWint2xMma; +public: // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; + using MmaCore = typename Mma::MmaCore; // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, - AccessTypeA>; + using IteratorA = typename Mma::IteratorA; // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, - AccessTypeB>; + using IteratorB = typename Mma::IteratorB; // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; + using ThreadblockMma = typename Mma::ThreadblockMma; }; } // namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h new file mode 100644 index 0000000000..72c22a175f --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h @@ -0,0 +1,174 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" +#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone> +struct DefaultWint2xMma; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// Operator performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +struct DefaultWint2xMma +{ + + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value, + "Element B must be uint2b_t"); + + 
static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here so that the mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, + AccessTypeA>; + +private: + static constexpr int kColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int kRowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % kColumnsInterleaved), "ThreadblockShape must be divisible by kColumnsInterleaved"); + static_assert(kRowsPerTile == MmaCore::Shape::kK, ""); + + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using WarpArrangement = typename ThreadMapB::Detail::WarpThreadArrangement; + static_assert(!(WarpArrangement::kStrided % kColumnsInterleaved), ""); + + using IteratorShapeB = MatrixShape< + MmaCore::Shape::kK * kColumnsInterleaved, MmaCore::Shape::kN / kColumnsInterleaved>; + using InterleavedThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + ThreadMapB::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + +public: + // Define iterators over tiles from the B operand + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + IteratorShapeB, ElementB, layout::ColumnMajor, 0, InterleavedThreadMapB, + AccessTypeB>; + + using TransformBAfterLDS = FastInterleavedAndBiasedNumericArrayConverter< + ElementA, ElementB, MmaCore::MmaPolicy::Operator::FragmentB::kElements>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage< + typename MmaCore::Shape, + IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, + IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB, + ElementAccumulator, layout::RowMajor, + typename MmaCore::MmaPolicy, kStages, TransformBAfterLDS, SharedMemoryClear>; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h index 6dd55b647a..cdb465c60c 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h @@ -93,6 +93,15 @@ class Wint2xMmaBase { static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + /// Number of warp-level GEMM operations per load for B + static constexpr int kWarpGemmIterationsPerLoadForB = + Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; + static_assert(!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB), ""); + + 
static constexpr int kWarpLoadIterationsForB = + kWarpGemmIterations / kWarpGemmIterationsPerLoadForB; + + /// Number of stages static int const kStages = Stages; @@ -104,8 +113,6 @@ class Wint2xMmaBase { using TensorRefB = TensorRef; - // using TensorRefZippedB = TensorRef; - static_assert(kWarpGemmIterations > 1, "The pipelined structure requires at least two warp-level " "GEMM operations."); @@ -130,20 +137,19 @@ class Wint2xMmaBase { Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; /// Shape of the B matrix operand in shared memory - using ShapeB = MatrixShape; - // w uint8; local_scale uint8; - constexpr static int kZippedRowsPerStages = - Shape::kK / 4 + (Shape::kK + 127) / 128; + // local_scale uint4 + constexpr static int kGroupWiseParamRows = Shape::kK / 64; + + using GroupWiseParamShapeB = MatrixShape; // code_scale float; code_zp float; super_scale ElementB - constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) + + constexpr static int kColumnWiseParamRows = 2 * sizeof(float) + sizeof_bits::value / 8; - using ZippedShapeB = MatrixShape; - - using NopaddingShapeB = MatrixShape; + using ColumnWiseParamShapeB = MatrixShape; public: // @@ -156,12 +162,11 @@ class Wint2xMmaBase { /// Buffer for B operand AlignedBuffer operand_B; - /// Buffer for quanted B operand - AlignedBuffer operand_zipped_B; + /// Buffer for local_scale of B operand + AlignedBuffer operand_local_scale_B; - /// Buffer for unzip B operand - AlignedBuffer - operand_unzip_B; + /// Buffer for column-wise params of B operand + AlignedBuffer operand_column_wise_B; public: // @@ -191,14 +196,6 @@ class Wint2xMmaBase { TensorRefB operand_B_ref() { return TensorRefB{operand_B.data(), LayoutB()}; } - - CUTLASS_HOST_DEVICE - uint8_t *operand_zipped_B_ptr() { return operand_zipped_B.data(); } - - CUTLASS_HOST_DEVICE - typename Operator::ElementB *operand_unzip_B_ptr() { - return operand_unzip_B.data(); - } }; protected: diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h index 38fdcf9fec..ca63f8c1d6 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h @@ -44,8 +44,8 @@ #include "cutlass/numeric_types.h" #include "cutlass_extensions/arch/memory_copy_sm80.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h" #include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h" -#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -86,11 +86,11 @@ template < typename Policy_, /// Number of stages, int Stages, + /// Transform for input B applied in register after the LDS + typename TransformBAfterLDS_, /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, - /// Used for partial specialization - typename Enable = bool> -class Wint2xMmaMultistage : + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone> +class Wint2xMmaMultistage : public Wint2xMmaBase { public: ///< Base class @@ -107,8 +107,10 @@ class Wint2xMmaMultistage : using LayoutC = LayoutC_; ///< Policy describing tuning details using Policy = Policy_; + /// Transform for input B applied in register after the LDS + using TransformBAfterLDS = TransformBAfterLDS_; - using ZippedShapeB 
= typename Base::SharedStorage::ZippedShapeB; + static constexpr int kInterleave = IteratorB::Shape::kRow / Shape::kK; using SmemIteratorA = SmemIteratorA_; using SmemIteratorB = SmemIteratorB_; @@ -129,6 +131,18 @@ class Wint2xMmaMultistage : /// Minimum architecture is Sm80 to support cp.async using ArchTag = arch::Sm80; + using LayoutScale = cutlass::layout::ColumnMajor; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + using Dequantizer = + warp::MmaTensorOpWin2xDequantizer; + static_assert(sizeof(Dequantizer) > 0, "Dequantizer template instantiation failed"); + /// Complex transform on A operand static ComplexTransform const kTransformA = Operator::kTransformA; @@ -186,6 +200,14 @@ class Wint2xMmaMultistage : WarpTransformedFragmentB warp_transformed_frag_B_[2]; }; + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool IsTileInterleaveLayout = + layout::IsColumnMajorTileInterleave::value; + static_assert(!IsTileInterleaveLayout || (IsTileInterleaveLayout && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); private: @@ -196,6 +218,9 @@ class Wint2xMmaMultistage : /// Warp-level MMA operator Operator warp_mma_; + // Wint2 unzip operator + Dequantizer warp_dequantizer_; + /// Iterator to write threadblock-scoped tile of A operand to shared memory SmemIteratorA smem_iterator_A_; @@ -208,10 +233,11 @@ class Wint2xMmaMultistage : /// Shared memory read stage index int smem_read_stage_idx_; - uint8_t* column_wise_smem_ptr_B_; + /// Transform for B in register + TransformBAfterLDS transform_B_; - uint8_t* smem_zipped_ptr_B_; - int smem_zipped_bytes_per_stage_B_; + uint8_t* smem_ptr_B_; + uint8_t* ptr_B_; public: @@ -245,16 +271,31 @@ class Wint2xMmaMultistage : int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + CUTLASS_TRACE_DEVICE(" Shape: {%d, %d, %d}, IteratorB::Shape: {%d, %d}, kInterleave: %d", + Shape::kM, Shape::kN, Shape::kK, IteratorB::Shape::kRow, IteratorB::Shape::kColumn, kInterleave); + CUTLASS_TRACE_DEVICE(" kPartitionsK=%d, kWarpGemmIterations=%d, WarpCount={%d, %d}, warp_idx_m=%d, warp_idx_n=%d, warp_idx_k=%d", + Policy::kPartitionsK, Base::kWarpGemmIterations, + Base::WarpCount::kM, Base::WarpCount::kN, warp_idx_m, warp_idx_n, warp_idx_k); + // Add per-warp offsets in units of warp-level tiles this->warp_tile_iterator_A_.add_tile_offset( {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); this->warp_tile_iterator_B_.add_tile_offset( {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr(); - - smem_zipped_ptr_B_ = column_wise_smem_ptr_B_ + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn; - smem_zipped_bytes_per_stage_B_ = Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn; + CUTLASS_TRACE_DEVICE(" Policy::SmemPaddingA: {%d, %d}; Policy::SmemPaddingB: {%d, %d}", + Policy::SmemPaddingA::kRow, Policy::SmemPaddingA::kColumn, Policy::SmemPaddingB::kRow, Policy::SmemPaddingB::kColumn); + CUTLASS_TRACE_DEVICE(" operand_A_ptr=%p, kRow=%d, kColumn=%d", + shared_storage.operand_A.data(), static_cast(Base::SharedStorage::ShapeA::kRow), + static_cast(Base::SharedStorage::ShapeA::kColumn)); + CUTLASS_TRACE_DEVICE(" operand_B_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageB=%d, 
kAccessesPerVector=%d", + shared_storage.operand_B.data(), + static_cast(Base::SharedStorage::ShapeB::kRow), static_cast(Base::SharedStorage::ShapeB::kColumn), + static_cast(sizeof(shared_storage.operand_B)), + static_cast(IteratorB::ThreadMap::kElementsPerAccess), static_cast(sizeof(typename IteratorB::AccessType)), + static_cast(Detail::AsyncCopyIterationsPerStageB), static_cast(IteratorB::kAccessesPerVector)); + + smem_ptr_B_ = reinterpret_cast(shared_storage.operand_B.data()); } /// Advance shared memory read-iterators to the next stage @@ -266,28 +307,22 @@ class Wint2xMmaMultistage : if (smem_read_stage_idx_ == Base::kStages) { // Wrap back around to the 'start' of the circular buffer in shared memory this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - // this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); smem_read_stage_idx_ = 0; } - this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); } /// Advance global memory read-iterators and shared memory write-iterators to the stage - template CUTLASS_DEVICE - void advance_smem_write_stage( - IteratorA &iterator_A, - IteratorB &iterator_B, - TileDequanterB &tile_dequanter_B) + void advance_smem_write_stage(IteratorA &iterator_A, IteratorB &iterator_B) { // Advance global iterators iterator_A.add_tile_offset({0, 1}); - //iterator_B.add_tile_offset({1, 0}); - tile_dequanter_B.AddTileOffset({1, 0}); + iterator_B.add_tile_offset({1, 0}); // Advance shared iterators smem_iterator_A_.add_tile_offset({0, 1}); - //smem_iterator_B_.add_tile_offset({1, 0}); + smem_iterator_B_.add_tile_offset({1, 0}); // Increment shared memory write stage index ++smem_write_stage_idx_; @@ -295,7 +330,7 @@ class Wint2xMmaMultistage : if (smem_write_stage_idx_ == Base::kStages) { // Wrap back around to the 'start' of the circular buffer in shared memory smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - //smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); smem_write_stage_idx_ = 0; } } @@ -361,6 +396,13 @@ class Wint2xMmaMultistage : for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { auto gmem_ptr = iterator_B.get(); + if (group_start_B == 0 && j == 0 && v == 0) { + CUTLASS_TRACE_DEVICE(" dst_ptr=%p, iterator_B.get()=%p, kAccessesPerGroupB=%d, kAccessesPerVector=%d, sizeof(AccessType)=%d", + reinterpret_cast(dst_ptr), reinterpret_cast(gmem_ptr), + static_cast(Detail::kAccessesPerGroupB), static_cast(IteratorB::kAccessesPerVector), + static_cast(sizeof(typename IteratorB::Element))); + } + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { cutlass::arch::copy_zfill( dst_ptr + v, gmem_ptr, iterator_B.valid()); @@ -413,7 +455,7 @@ class Wint2xMmaMultistage : template CUTLASS_DEVICE - void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B) { + void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B, int stage) { iterator_B.set_iteration_index(0); this->smem_iterator_B_.set_iteration_index(0); @@ -433,6 +475,31 @@ class Wint2xMmaMultistage : IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + if (v == 0) { + int gmem_offset = reinterpret_cast(gmem_ptr) - reinterpret_cast(ptr_B_); + int gmem_k = 8192 * kInterleave / 4; + int gmem_n = 1792 / kInterleave; + int 
gmem_row = gmem_offset / gmem_k; + int gmem_col = gmem_offset % gmem_k; + + int smem_offset = reinterpret_cast(dst_ptr) - reinterpret_cast(smem_ptr_B_); + int smem_k = Shape::kK * kInterleave / 4; + int smem_n = Shape::kN / kInterleave; + int smem_row = smem_offset / smem_k; + int smem_col = smem_offset % smem_k; + + uint8_t* gmem_uint8_ptr = reinterpret_cast(gmem_ptr); + + CUTLASS_TRACE_DEVICE(" [stage=%d] gmem_ptr=%p, smem_ptr=%p, bytes=%d; gmem: %dx%d, {%d, %d}, [%d, %d, %d, %d, %d, %d, %d, %d]; smem: {%d, %d};", + stage, reinterpret_cast(gmem_ptr), reinterpret_cast(dst_ptr), kSrcBytes, + gmem_n, gmem_k, gmem_row, gmem_col, + static_cast(gmem_uint8_ptr[0]), static_cast(gmem_uint8_ptr[1]), + static_cast(gmem_uint8_ptr[2]), static_cast(gmem_uint8_ptr[3]), + static_cast(gmem_uint8_ptr[4]), static_cast(gmem_uint8_ptr[5]), + static_cast(gmem_uint8_ptr[6]), static_cast(gmem_uint8_ptr[7]), + smem_row, smem_col); + } + if (InitStage) { cutlass::arch::copy_zfill( dst_ptr + v, iterator_B.get(), iterator_B.valid()); @@ -456,12 +523,10 @@ class Wint2xMmaMultistage : /// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching /// the global fragments needed by the first kStages-1 threadblock mainloop iterations - template CUTLASS_DEVICE void prologue( IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - TileDequanterB &tile_dequanter_B, int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining { // Issue several complete stages @@ -476,11 +541,12 @@ class Wint2xMmaMultistage : copy_tiles_and_advance_per_stage_A(iterator_A); // Async copy zipped B to shared memory. - tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, - column_wise_smem_ptr_B_, stage); + copy_tiles_and_advance_per_stage_B(iterator_B, stage); + + // TODO: Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale. // Move to the next write stage - advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B); + advance_smem_write_stage(iterator_A, iterator_B); // Defines the boundary of a stage of cp.async. cutlass::arch::cp_async_fence(); @@ -542,14 +608,12 @@ class Wint2xMmaMultistage : } /// Perform a threadblock mainloop iteration of matrix multiply-accumulate - template CUTLASS_DEVICE void mac_loop_iter( PipeState &pipe_state, ///< [in|out] loop-carried pipeline state FragmentC &accum, ///< [in|out] destination accumulator tile IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - TileDequanterB &tile_dequanter_B, ///< [in|out] tile dequantizer for B operand int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining int stage) { @@ -563,16 +627,6 @@ class Wint2xMmaMultistage : this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); ++this->warp_tile_iterator_A_; - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - // Unpack and dequant the first stage of B. - int unpack_stage = stage - Base::kStages + 2; - tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, - column_wise_smem_ptr_B_, unpack_stage); - - // Copy dequatized data to shared memory used by mma core. 
- copy_tiles_and_advance_per_stage_B(iterator_B); - } - // Load the next warp-tile's B fragment from shared memory this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); @@ -614,13 +668,10 @@ class Wint2xMmaMultistage : // global->shared fragment copies if (warp_mma_k < Base::kWarpGemmIterations - 1) { int group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + int group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); - - if (warp_mma_k == 0) { - tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, - column_wise_smem_ptr_B_, stage); - } + copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); } // The second-to-last warp-tile also: @@ -629,8 +680,10 @@ class Wint2xMmaMultistage : if (warp_mma_k + 2 == Base::kWarpGemmIterations) { // Performs the last warp-tile's share of global->shared fragment copies int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + int group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); + copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); // Inserts a memory fence between stages of cp.async instructions. cutlass::arch::cp_async_fence(); @@ -639,13 +692,13 @@ class Wint2xMmaMultistage : gmem_wait(); // Move to the next global fetch stage - advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B); + advance_smem_write_stage(iterator_A, iterator_B); advance_smem_read_stage(); // Disable global fetching when done with global fetch iterations --gemm_k_iterations; iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1)); + iterator_B.clear_mask(gemm_k_iterations == 0); } // The last warp-tile also converts the shared memory fragments used by @@ -663,19 +716,26 @@ class Wint2xMmaMultistage : /// Perform the specified number of threadblock mainloop iterations of matrix /// multiply-accumulate. Assumes prologue has been initiated. - template CUTLASS_DEVICE void gemm_iters( int gemm_k_iterations, ///< number of threadblock mainloop iterations FragmentC &accum, ///< [in|out] accumulator tile IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, - TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand in global memory + IteratorB &iterator_B) { - PipeState pipe_state; +#if 0 + int smem_k = Shape::kK * kInterleave / 4; + int smem_n = Shape::kN / kInterleave; + for (int i = 0; i < 3 * smem_n; ++i) { + for (int j = 0; j < smem_k; ++j) { + if (i % 3 == 0) { + CUTLASS_TRACE_DEVICE(" [i=%d, j=%d, %dx%d] %d", i, j, smem_n, smem_k, static_cast(smem_ptr_B_[i * smem_k + j])); + } + } + } +#endif - // Unpack and dequant the first stage of B. - tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0); + PipeState pipe_state; // Disable global fetching if done with global fetch iterations iterator_A.clear_mask(gemm_k_iterations == 0); @@ -686,14 +746,96 @@ class Wint2xMmaMultistage : this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]); ++this->warp_tile_iterator_A_; - // Copy dequatized data to shared memory used by mma core. 
- copy_tiles_and_advance_per_stage_B(iterator_B); - // Load first warp-tile's B fragment from shared memory this->warp_tile_iterator_B_.set_kgroup_index(0); this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]); ++this->warp_tile_iterator_B_; + if (PipeState::WarpLoadedFragmentA::kElements == 8) { + ElementA* warp_frag_A_ptr = reinterpret_cast(pipe_state.warp_loaded_frag_A_[0].data()); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_A_=[%f, %f, %f, %f, %f, %f, %f, %f], %d bytes", + static_cast(warp_frag_A_ptr[0]), static_cast(warp_frag_A_ptr[1]), + static_cast(warp_frag_A_ptr[2]), static_cast(warp_frag_A_ptr[3]), + static_cast(warp_frag_A_ptr[4]), static_cast(warp_frag_A_ptr[5]), + static_cast(warp_frag_A_ptr[6]), static_cast(warp_frag_A_ptr[7]), + sizeof_bits::value / 8); + } + if (PipeState::WarpLoadedFragmentB::kElements == 64) { + uint8_t* reg_uint8_ptr = reinterpret_cast(pipe_state.warp_loaded_frag_B_[0].data()); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d], %d bytes", + static_cast(reg_uint8_ptr[0]), static_cast(reg_uint8_ptr[1]), + static_cast(reg_uint8_ptr[2]), static_cast(reg_uint8_ptr[3]), + static_cast(reg_uint8_ptr[4]), static_cast(reg_uint8_ptr[5]), + static_cast(reg_uint8_ptr[6]), static_cast(reg_uint8_ptr[7]), + static_cast(reg_uint8_ptr[8]), static_cast(reg_uint8_ptr[9]), + static_cast(reg_uint8_ptr[10]), static_cast(reg_uint8_ptr[11]), + static_cast(reg_uint8_ptr[12]), static_cast(reg_uint8_ptr[13]), + static_cast(reg_uint8_ptr[14]), static_cast(reg_uint8_ptr[15]), + sizeof_bits::value / 8); + } + + typename TransformBAfterLDS::result_type unpacked_frag_B = transform_B_(pipe_state.warp_loaded_frag_B_[0]); + if (TransformBAfterLDS::result_type::kElements == 64) { + CUTLASS_TRACE_DEVICE(" TransformBAfterLDS::result_type::kElements: 64, %d bytes", sizeof_bits::value / 8); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(unpacked_frag_B[0]), static_cast(unpacked_frag_B[1]), + static_cast(unpacked_frag_B[2]), static_cast(unpacked_frag_B[3]), + static_cast(unpacked_frag_B[4]), static_cast(unpacked_frag_B[5]), + static_cast(unpacked_frag_B[6]), static_cast(unpacked_frag_B[7]), + static_cast(unpacked_frag_B[8]), static_cast(unpacked_frag_B[9]), + static_cast(unpacked_frag_B[10]), static_cast(unpacked_frag_B[11]), + static_cast(unpacked_frag_B[12]), static_cast(unpacked_frag_B[13]), + static_cast(unpacked_frag_B[14]), static_cast(unpacked_frag_B[15])); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[16:31]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(unpacked_frag_B[16]), static_cast(unpacked_frag_B[17]), + static_cast(unpacked_frag_B[18]), static_cast(unpacked_frag_B[19]), + static_cast(unpacked_frag_B[20]), static_cast(unpacked_frag_B[21]), + static_cast(unpacked_frag_B[22]), static_cast(unpacked_frag_B[23]), + static_cast(unpacked_frag_B[24]), static_cast(unpacked_frag_B[25]), + static_cast(unpacked_frag_B[26]), static_cast(unpacked_frag_B[27]), + static_cast(unpacked_frag_B[28]), static_cast(unpacked_frag_B[29]), + static_cast(unpacked_frag_B[30]), static_cast(unpacked_frag_B[31])); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[32:47]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(unpacked_frag_B[32]), static_cast(unpacked_frag_B[33]), + static_cast(unpacked_frag_B[34]), static_cast(unpacked_frag_B[35]), + static_cast(unpacked_frag_B[36]), 
static_cast(unpacked_frag_B[37]), + static_cast(unpacked_frag_B[38]), static_cast(unpacked_frag_B[39]), + static_cast(unpacked_frag_B[40]), static_cast(unpacked_frag_B[41]), + static_cast(unpacked_frag_B[42]), static_cast(unpacked_frag_B[43]), + static_cast(unpacked_frag_B[44]), static_cast(unpacked_frag_B[45]), + static_cast(unpacked_frag_B[46]), static_cast(unpacked_frag_B[47])); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[48:63]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(unpacked_frag_B[48]), static_cast(unpacked_frag_B[49]), + static_cast(unpacked_frag_B[50]), static_cast(unpacked_frag_B[51]), + static_cast(unpacked_frag_B[52]), static_cast(unpacked_frag_B[53]), + static_cast(unpacked_frag_B[54]), static_cast(unpacked_frag_B[55]), + static_cast(unpacked_frag_B[56]), static_cast(unpacked_frag_B[57]), + static_cast(unpacked_frag_B[58]), static_cast(unpacked_frag_B[59]), + static_cast(unpacked_frag_B[60]), static_cast(unpacked_frag_B[61]), + static_cast(unpacked_frag_B[62]), static_cast(unpacked_frag_B[63])); + } + + typename Dequantizer::FragmentLocalScale warp_frag_local_scale; + typename Dequantizer::FragmentCodeScale warp_frag_code_scale; + typename Dequantizer::FragmentCodeZp warp_frag_code_zp; + typename Dequantizer::FragmentSuperScale warp_frag_super_scale; + typename Dequantizer::FragmentOutOperand warp_frag_out; + + CUTLASS_TRACE_DEVICE(" warp_dequantizer_ - start load"); + warp_dequantizer_.load(warp_frag_local_scale, + warp_frag_code_scale, + warp_frag_code_zp, + warp_frag_super_scale); + + CUTLASS_TRACE_DEVICE("warp_dequantizer_ - start dequant"); + warp_dequantizer_.dequantize(warp_frag_out, + pipe_state.warp_loaded_frag_B_[0], + warp_frag_local_scale, + warp_frag_code_scale, + warp_frag_code_zp, + warp_frag_super_scale); + +#if 0 // Transform, if necessary, the first warp-tile's shared memory fragments warp_mma_.transform( pipe_state.warp_transformed_frag_A_[0], @@ -715,10 +857,10 @@ class Wint2xMmaMultistage : accum, iterator_A, iterator_B, - tile_dequanter_B, gemm_k_iterations, stage); stage += 1; + break; } if (Detail::kStagedAccumulation) { @@ -730,6 +872,7 @@ class Wint2xMmaMultistage : cutlass::arch::cp_async_fence(); cutlass::arch::cp_async_wait<0>(); __syncthreads(); + #endif } /// Prepares the class for another prologue. @@ -761,14 +904,12 @@ class Wint2xMmaMultistage : else { this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)}); - //this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0}); - this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0}); + this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0}); } smem_read_stage_idx_ = smem_write_stage_idx_; } /// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory. 
- template CUTLASS_DEVICE void operator()( ///< problem size of GEMM @@ -779,13 +920,13 @@ class Wint2xMmaMultistage : IteratorA iterator_A, ///< iterator over B operand in global memory IteratorB iterator_B, - ///< pre-load and dequantize B to shared memory - TileDequanterB tile_dequanter_B, ///< initial value of accumulator FragmentC const &src_accum) { + ptr_B_ = reinterpret_cast(iterator_B.get_origin_pointer()); + // Prologue (start fetching iterations of global fragments into shared memory) - prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations); + prologue(iterator_A, iterator_B, gemm_k_iterations); // Wait until we have at least one completed global fetch stage gmem_wait(); @@ -794,7 +935,7 @@ class Wint2xMmaMultistage : accum = src_accum; // Perform the MAC-iterations - gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B); + gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B); } }; diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h deleted file mode 100644 index cec6bcea03..0000000000 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "cutlass/gemm_coord.h" -#include "cutlass/trace.h" - -#include "cutlass_extensions/gemm/threadblock/wint2x_unzip.h" - -namespace cutlass { -namespace gemm { -namespace threadblock { - -template -struct TileDequanter { - using WeightQuantTraits = WintQuantTraits; - using MmaElementT = typename WeightQuantTraits::MmaWeightType; - using QuantArguments = typename WeightQuantTraits::Arguments; - - using UnzipAndDequantFunctor = - UnzipAndDequantFunctor; - - static constexpr bool kUseSharedMemory = true; - - static constexpr int kRows = Rows; - static constexpr int kColumns = Columns; - static constexpr int kStages = Stages; - - MmaElementT *out_smem_ptr{nullptr}; - - char *pointer{nullptr}; - int64_t ldm{0}; - cutlass::MatrixCoord tb_offset; - cutlass::MatrixCoord extent; - - ScaleElementT *super_scale_ptr{nullptr}; - cutlass::MatrixCoord tb_offset_scale; - - QuantArguments quant_args; - - int64_t block_start_rows[kStages]; - bool need_preload{true}; - UnzipAndDequantFunctor unzip_functor; - - CUTLASS_DEVICE - TileDequanter(MmaElementT *out_smem_ptr, char *pointer, int64_t ldm, - const cutlass::MatrixCoord &extent, - const cutlass::MatrixCoord &tb_offset, - ScaleElementT *super_scale_ptr, - const cutlass::MatrixCoord &tb_offset_scale, - const QuantArguments &quant_args) - : out_smem_ptr(out_smem_ptr), pointer(pointer), ldm(ldm), extent(extent), - tb_offset(tb_offset), super_scale_ptr(super_scale_ptr), - tb_offset_scale(tb_offset_scale), quant_args(quant_args) {} - - CUTLASS_DEVICE - MmaElementT *GetOutPtr() { return out_smem_ptr; } - - CUTLASS_DEVICE - void AddTileOffset(const cutlass::MatrixCoord &tile_offset) { - tb_offset.row() += tile_offset.row() * kRows; - tb_offset.column() += tile_offset.column() * kColumns; - tb_offset_scale.column() += tile_offset.column() * kColumns; - } - - CUTLASS_DEVICE - void Load(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) { - int zipped_row = WeightQuantTraits::CaclPackedDim(tb_offset.row()); - if (tb_offset.row() >= extent.row() || - tb_offset.column() >= extent.column()) { - return; - } - - block_start_rows[stage % kStages] = tb_offset.row(); - - using ZippedT = typename WeightQuantTraits::WeightType; - ZippedT *in_ptr = reinterpret_cast(pointer) + zipped_row * ldm + - tb_offset.column(); - ScaleElementT *scale_ptr = super_scale_ptr + tb_offset_scale.column(); - - if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) { - const uint8_t *local_scale_ptr = quant_args.local_scale_ptr + - (tb_offset.row() / 128) * ldm + - tb_offset_scale.column(); - const float *code_scale_ptr = - quant_args.code_scale_ptr + tb_offset_scale.column(); - const float *code_zp_ptr = - quant_args.code_zp_ptr + tb_offset_scale.column(); - - typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr); - unzip_functor.LoadAsync(in_ptr, local_scale_ptr, code_scale_ptr, code_zp_ptr, - scale_ptr, &args, ldm, need_preload); - need_preload = false; - } else { - // CUTLASS_TRACE_DEVICE("Not Supported!"); - } - } - - CUTLASS_DEVICE - void UnpackAndDequant(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) { - int64_t block_start_row = block_start_rows[stage % kStages]; - if (block_start_row >= extent.row()) { - return; - } - - if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) { - typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr); - unzip_functor.ComputeVectorized(args, out_smem_ptr, block_start_row); - } else { - // CUTLASS_TRACE_DEVICE("Not 
Supported!"); - } - } -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h index 350b247de2..af4298df5e 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h @@ -41,12 +41,9 @@ #include "cutlass_extensions/arch/mma.h" #include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" -namespace cutlass -{ -namespace gemm -{ -namespace warp -{ +namespace cutlass { +namespace gemm { +namespace warp { ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -81,7 +78,7 @@ struct DefaultMmaTensorOp::value; // Shape for loading the narrow data type from shared memory diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h index 7c5088894b..ad1ca710e2 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -295,6 +295,11 @@ class MmaTensorOpComputeBWithF16 assert(0); #endif } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, + FragmentA const &A, FragmentB const &B) const {} }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h new file mode 100644 index 0000000000..4d05740d53 --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h @@ -0,0 +1,696 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations + targeting Tensor Cores. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" + +#include "cutlass/functional.h" +#include "cutlass/platform/platform.h" + +#include +#include "cutlass_extensions/weight_only_quant_op.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Matrix multiply operator + typename MmaOperator_, + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + Operand Operand, + /// Data type of Scale elements + typename Element_, + /// Layout of operand + typename Layout_, + /// Number of threads participating in one matrix operation + int Threads, + /// + WeightOnlyQuantOp QuantOp_ = WeightOnlyQuantOp::UNDEFINED, + /// + typename Enable = void> +class MmaTensorOpWin2xDequantizer; + +//////////////////////////////////////////////////////////////////////////////// +// Bfloat specialization for Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> + +class MmaTensorOpWin2xDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + bfloat16_t, + layout::ColumnMajor, + 32, + QuantOp_, + typename platform::enable_if= + 70>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. 
+ static constexpr int kExpansionFactor = + MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementWeight = uint2b_t; + + /// Type of the scales + using ElementUnzipWeight = uint8_t; + + /// Type of the scales + using ElementScale = bfloat16_t; + + /// Type of the scales + using ScaleComputeT = float; + + static constexpr int unzip_len = 4; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = + Array; + using FragmentWeightOperand = + Array; + using FragmentOutOperand = + Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentLocalScale = Array; + using FragmentCodeScale = Array; + using FragmentCodeZp = Array; + using FragmentSuperScale = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::ColumnMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = cutlass::TensorRef; + using TensorCodeRef = cutlass::TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpWin2xDequantizer(TensorRef smem_local_scale, + TensorCodeRef smem_code_scale, + TensorCodeRef smem_code_zp, + TensorRef smem_super_scale, + int const warp_idx_n, + int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_local_scale_ = smem_local_scale.data() + thread_offset; + pointer_code_scale_ = smem_code_scale.data() + thread_offset; + pointer_code_zp_ = smem_code_zp.data() + thread_offset; + if constexpr (hasZero(QuantOp)) { + pointer_super_scale_ = smem_super_scale.data() + thread_offset; + } + } + + // CUTLASS_DEVICE + // MmaTensorOpWin2xDequantizer() { + // pointer_local_scale_ = nullptr; + // pointer_code_scale_ = nullptr; + // pointer_code_zp_ = nullptr; + // if constexpr (hasZero(QuantOp)) { + // pointer_super_scale_ = nullptr; + // } + // } + + CUTLASS_DEVICE + MmaTensorOpWin2xDequantizer() { + // Create fake pointer using a shared dummy buffer + CUTLASS_TRACE_DEVICE(" warp dequant aaa"); + + extern __shared__ char cutlass_fake_dequant_smem[]; + + // Memory layout (manual alignment): + // ElementScale (half or bf16): 2 bytes + // ScaleComputeT (float): 4 bytes + + pointer_local_scale_ = + reinterpret_cast(cutlass_fake_dequant_smem); + pointer_code_scale_ = + reinterpret_cast(cutlass_fake_dequant_smem + 64); + pointer_code_zp_ = + reinterpret_cast(cutlass_fake_dequant_smem + 128); + + if constexpr (hasZero(QuantOp)) { + pointer_super_scale_ = reinterpret_cast( + cutlass_fake_dequant_smem + 192); + } + } + + CUTLASS_DEVICE + void load(FragmentLocalScale& local_scale_frag, + FragmentCodeScale& code_scale_frag, + FragmentCodeZp& code_zp_frag, + FragmentSuperScale& super_scale_frag) { + CUTLASS_TRACE_DEVICE(" warp dequant load"); + // CUTLASS_PRAGMA_UNROLL + // for (int mma_n_iter = 0; mma_n_iter < + // MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + // { + // local_scale_frag[mma_n_iter] = pointer_local_scale_[mma_n_iter * + // InstructionShape::kN]; code_scale_frag[mma_n_iter] = + // pointer_code_scale_[mma_n_iter * InstructionShape::kN]; + // code_zp_frag[mma_n_iter] = pointer_code_zp_[mma_n_iter * + // InstructionShape::kN]; if constexpr (hasZero(QuantOp)) + // { + // super_scale_frag[mma_n_iter] = + // pointer_super_scale_[mma_n_iter * 
InstructionShape::kN]; + // } + // } + } + + CUTLASS_DEVICE + void dequantize(FragmentOutOperand& out_frag, + FragmentDequantizedOperand& operand_frag, + FragmentLocalScale& local_scale_frag, + FragmentCodeScale& code_scale_frag, + FragmentCodeZp& code_zp_frag, + FragmentSuperScale& super_scale_frag) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) + CUTLASS_TRACE_DEVICE(" dequantize if def"); + + static constexpr int32_t kGroupSize = 64; + static constexpr int32_t kPackNum = 4; + static constexpr int32_t kWeightMask = 0x3F; + static constexpr int32_t kLocalScaleMask = 0xF; + static constexpr int32_t kBZP = 32; + + // using _MmaOperandB = typename ArchMmaOperator::FragmentB; + // using ExpandedMmaOperandB = Array; + // static_assert(ExpandedMmaOperandB::kElements * + // MmaOperator::MmaIterations::kColumn + // == FragmentDequantizedOperand::kElements, + // ""); + + // CUTLASS_TRACE_DEVICE(" MmaIterations krow = %d, kcol = %d", + // MmaOperator::IteratorB::InstructionShape::kRow, + // MmaOperator::MmaIterations::kColumn); + + // CUTLASS_TRACE_DEVICE(" kExpansionFactor = %d / %d", + // MmaOperator::IteratorB::InstructionShape::kRow, + // InstructionShape::kK); CUTLASS_TRACE_DEVICE(" + // FragmentDequantizedOperand::kElements = %d ", + // FragmentDequantizedOperand::kElements); CUTLASS_TRACE_DEVICE(" + // _MmaOperandB::kElements = %d ", _MmaOperandB::kElements); + + // FragmentWeightOperand + CUTLASS_TRACE_DEVICE(" FragmentWeightOperand elem = %d ", + FragmentWeightOperand::kElements); + // CUTLASS_TRACE_DEVICE(" ElementUnzipWeight size = %d ", + // sizeof(ElementUnzipWeight)); CUTLASS_TRACE_DEVICE(" ElementWeight + // size = %d ", sizeof(ElementWeight)); + static_assert(std::is_same::value, + "B must be a uint8 quantized type"); + FragmentWeightOperand* weight_ptr = + reinterpret_cast(&operand_frag); + FragmentLocalScale* local_scale_ptr = + reinterpret_cast(&local_scale_frag); + FragmentCodeScale* code_scale_ptr = + reinterpret_cast(&code_scale_frag); + FragmentCodeZp* code_zp_ptr = + reinterpret_cast(&code_zp_frag); + FragmentSuperScale* super_scale_ptr = + reinterpret_cast(&super_scale_frag); + + ScaleComputeT code_scale = + static_cast(code_scale_ptr[0][0]); + ScaleComputeT code_zp = static_cast(code_zp_ptr[0][0]); + ScaleComputeT super_scale = + static_cast(super_scale_ptr[0][0]); + int32_t local_scale = static_cast(local_scale_ptr[0][0]); + int32_t const shift_bits[4] = {9, 6, 3, 0}; + + ScaleComputeT zipped_value[16]; +#pragma unroll + for (int i = 0; i < 16; ++i) { + zipped_value[i] = static_cast(weight_ptr[0][i]); + } + + int local_scale_shift = 4; + int32_t shifted_local_scale = + (local_scale >> local_scale_shift) & kLocalScaleMask; + ScaleComputeT scale = + static_cast(shifted_local_scale) * super_scale; + +#pragma unroll + for (int i = 0; i < 16; ++i) { + int32_t decode_value = static_cast( + floor(zipped_value[i] * code_scale + code_zp + + static_cast(0.5))); + + int col = i * 4; + +#pragma unroll + for (int shift_bit_id = 0; shift_bit_id < 4; ++shift_bit_id) { + int32_t shift_bit = shift_bits[shift_bit_id]; + int32_t shifted_value = + (decode_value >> shift_bit) & kWeightMask; + + ScaleComputeT value = + static_cast(shifted_value - kBZP); + out_frag[col + shift_bit_id] = + static_cast(scale * value); + } + } + + CUTLASS_TRACE_DEVICE(" kColsPerMmaPerThread = %d ", + kColsPerMmaPerThread); + CUTLASS_TRACE_DEVICE(" MmaOperator::MmaIterations::kColumn = %d ", + MmaOperator::MmaIterations::kColumn); + + // // __nv_bfloat16 const* scale_ptr = 
reinterpret_cast<__nv_bfloat16 + // const*>(&scale_frag); ExpandedMmaOperandB* operand_frag_ptr = + // reinterpret_cast(&operand_frag); + + // printf("threadidx.x = %d\n", threadIdx.x); + // CUTLASS_PRAGMA_UNROLL + // for (int mma_n_iter = 0; mma_n_iter < + // MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + // { + // static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + // __nv_bfloat162 scalex2 = + // __bfloat162bfloat162(scale_ptr[mma_n_iter]); + // __nv_bfloat162* operand_bf16x2_ptr = + // reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + // CUTLASS_PRAGMA_UNROLL + // for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) + // { + // operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], + // scalex2); + // } + // } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on + // older arch, scale conversion should happen before scales are stored + // to shared memory and we should use the fp16 dequantizer. This will + // avoid numerous conversion instructions in GEMM main loop. + CUTLASS_TRACE_DEVICE(" dequantize else def"); + // arch::device_breakpoint(); +#endif + } + + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_local_scale_ += offset; + pointer_code_scale_ += offset; + pointer_code_zp_ += offset; + pointer_super_scale_ += offset; + } + + private: + ElementScale const* pointer_local_scale_; + ScaleComputeT const* pointer_code_scale_; + ScaleComputeT const* pointer_code_zp_; + ElementScale const* pointer_super_scale_; + + ElementScale const* pointer_out_; +}; + +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpWin2xDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + half_t, + layout::ColumnMajor, + 32, + QuantOp_, + typename platform::enable_if= + 70>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. 
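[Reviewer's aside, illustrative only, not part of the patch.] The dequantize() bodies above expand each stored uint8 "zipped" byte into four weights: an affine code_scale/code_zp decode recovers a packed integer, overlapping 6-bit fields are extracted with shifts {9, 6, 3, 0}, the kBZP bias is subtracted, and the result is rescaled by the 4-bit local scale times the super scale. The following host-side sketch mirrors that arithmetic with the same masks and shifts; DecodeZippedByte and its parameter names are hypothetical and exist only for this note.

#include <cmath>
#include <cstdint>

// One packed byte -> 4 dequantized weights, mirroring
// MmaTensorOpWin2xDequantizer::dequantize() above.
void DecodeZippedByte(uint8_t zipped, float code_scale, float code_zp,
                      int32_t local_scale_nibble, float super_scale,
                      float out[4]) {
    constexpr int32_t kWeightMask = 0x3F;          // 6-bit field per decoded weight
    constexpr int32_t kBZP = 32;                   // bias added when the weights were packed
    const int32_t shift_bits[4] = {9, 6, 3, 0};

    // Affine decode recovers the packed integer from the stored byte.
    int32_t decode_value = static_cast<int32_t>(
        std::floor(static_cast<float>(zipped) * code_scale + code_zp + 0.5f));

    // 4-bit local scale (already extracted from its packed byte) times the
    // per-channel super scale.
    float scale = static_cast<float>(local_scale_nibble) * super_scale;

    for (int j = 0; j < 4; ++j) {
        int32_t w = (decode_value >> shift_bits[j]) & kWeightMask;
        out[j] = scale * static_cast<float>(w - kBZP);
    }
}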
+ static constexpr int kExpansionFactor = + MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementWeight = uint2b_t; + + /// Type of the scales + using ElementUnzipWeight = uint8_t; + + /// Type of the scales + using ElementScale = half_t; + + /// Type of the scales + using ScaleComputeT = float; + + static constexpr int unzip_len = 4; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = + Array; + using FragmentWeightOperand = + Array; + using FragmentOutOperand = + Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentLocalScale = Array; + using FragmentCodeScale = Array; + using FragmentCodeZp = Array; + using FragmentSuperScale = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::ColumnMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = cutlass::TensorRef; + using TensorCodeRef = cutlass::TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpWin2xDequantizer(TensorRef smem_local_scale, + TensorCodeRef smem_code_scale, + TensorCodeRef smem_code_zp, + TensorRef smem_super_scale, + int const warp_idx_n, + int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_local_scale_ = smem_local_scale.data() + thread_offset; + pointer_code_scale_ = smem_code_scale.data() + thread_offset; + pointer_code_zp_ = smem_code_zp.data() + thread_offset; + if constexpr (hasZero(QuantOp)) { + pointer_super_scale_ = smem_super_scale.data() + thread_offset; + } + } + + // CUTLASS_DEVICE + // MmaTensorOpWin2xDequantizer() { + // pointer_local_scale_ = nullptr; + // pointer_code_scale_ = nullptr; + // pointer_code_zp_ = nullptr; + // if constexpr (hasZero(QuantOp)) { + // pointer_super_scale_ = nullptr; + // } + // } + + CUTLASS_DEVICE + MmaTensorOpWin2xDequantizer() { + // Create fake pointer using a shared dummy buffer + CUTLASS_TRACE_DEVICE(" warp dequant aaa"); + + extern __shared__ char cutlass_fake_dequant_smem[]; + + // Memory layout (manual alignment): + // ElementScale (half or bf16): 2 bytes + // ScaleComputeT (float): 4 bytes + + pointer_local_scale_ = + reinterpret_cast(cutlass_fake_dequant_smem); + pointer_code_scale_ = + reinterpret_cast(cutlass_fake_dequant_smem + 64); + pointer_code_zp_ = + reinterpret_cast(cutlass_fake_dequant_smem + 128); + + if constexpr (hasZero(QuantOp)) { + pointer_super_scale_ = reinterpret_cast( + cutlass_fake_dequant_smem + 192); + } + } + + CUTLASS_DEVICE + void load(FragmentLocalScale& local_scale_frag, + FragmentCodeScale& code_scale_frag, + FragmentCodeZp& code_zp_frag, + FragmentSuperScale& super_scale_frag) { + CUTLASS_TRACE_DEVICE(" warp dequant load"); + // CUTLASS_PRAGMA_UNROLL + // for (int mma_n_iter = 0; mma_n_iter < + // MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + // { + // local_scale_frag[mma_n_iter] = pointer_local_scale_[mma_n_iter * + // InstructionShape::kN]; code_scale_frag[mma_n_iter] = + // pointer_code_scale_[mma_n_iter * InstructionShape::kN]; + // code_zp_frag[mma_n_iter] = pointer_code_zp_[mma_n_iter * + // InstructionShape::kN]; if constexpr (hasZero(QuantOp)) + // { + // super_scale_frag[mma_n_iter] = + // pointer_super_scale_[mma_n_iter * 
InstructionShape::kN];
+            //         }
+            //     }
+    }
+
+    CUTLASS_DEVICE
+    void dequantize(FragmentOutOperand& out_frag,
+                    FragmentDequantizedOperand& operand_frag,
+                    FragmentLocalScale& local_scale_frag,
+                    FragmentCodeScale& code_scale_frag,
+                    FragmentCodeZp& code_zp_frag,
+                    FragmentSuperScale& super_scale_frag) {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
+        CUTLASS_TRACE_DEVICE(" dequantize if def");
+
+        static constexpr int32_t kGroupSize = 64;
+        static constexpr int32_t kPackNum = 4;
+        static constexpr int32_t kWeightMask = 0x3F;
+        static constexpr int32_t kLocalScaleMask = 0xF;
+        static constexpr int32_t kBZP = 32;
+
+        // using _MmaOperandB = typename ArchMmaOperator::FragmentB;
+        // using ExpandedMmaOperandB = Array;
+        // static_assert(ExpandedMmaOperandB::kElements *
+        // MmaOperator::MmaIterations::kColumn
+        //                   == FragmentDequantizedOperand::kElements,
+        //               "");
+
+        // CUTLASS_TRACE_DEVICE(" MmaIterations krow = %d, kcol = %d",
+        // MmaOperator::IteratorB::InstructionShape::kRow,
+        // MmaOperator::MmaIterations::kColumn);
+
+        // CUTLASS_TRACE_DEVICE(" kExpansionFactor = %d / %d",
+        // MmaOperator::IteratorB::InstructionShape::kRow,
+        // InstructionShape::kK); CUTLASS_TRACE_DEVICE("
+        // FragmentDequantizedOperand::kElements = %d ",
+        // FragmentDequantizedOperand::kElements); CUTLASS_TRACE_DEVICE("
+        // _MmaOperandB::kElements = %d ", _MmaOperandB::kElements);
+
+        // FragmentWeightOperand
+        CUTLASS_TRACE_DEVICE(" FragmentWeightOperand elem = %d ",
+                             FragmentWeightOperand::kElements);
+        // CUTLASS_TRACE_DEVICE(" ElementUnzipWeight size = %d ",
+        // sizeof(ElementUnzipWeight)); CUTLASS_TRACE_DEVICE(" ElementWeight
+        // size = %d ", sizeof(ElementWeight));
+        static_assert(std::is_same<ElementUnzipWeight, uint8_t>::value,
+                      "operand B must be the uint8 quantized type");
+        FragmentWeightOperand* weight_ptr =
+            reinterpret_cast<FragmentWeightOperand*>(&operand_frag);
+        FragmentLocalScale* local_scale_ptr =
+            reinterpret_cast<FragmentLocalScale*>(&local_scale_frag);
+        FragmentCodeScale* code_scale_ptr =
+            reinterpret_cast<FragmentCodeScale*>(&code_scale_frag);
+        FragmentCodeZp* code_zp_ptr =
+            reinterpret_cast<FragmentCodeZp*>(&code_zp_frag);
+        FragmentSuperScale* super_scale_ptr =
+            reinterpret_cast<FragmentSuperScale*>(&super_scale_frag);
+
+        ScaleComputeT code_scale =
+            static_cast<ScaleComputeT>(code_scale_ptr[0][0]);
+        ScaleComputeT code_zp = static_cast<ScaleComputeT>(code_zp_ptr[0][0]);
+        ScaleComputeT super_scale =
+            static_cast<ScaleComputeT>(super_scale_ptr[0][0]);
+        int32_t local_scale = static_cast<int32_t>(local_scale_ptr[0][0]);
+        int32_t const shift_bits[4] = {9, 6, 3, 0};
+
+        ScaleComputeT zipped_value[16];
+#pragma unroll
+        for (int i = 0; i < 16; ++i) {
+            zipped_value[i] = static_cast<ScaleComputeT>(weight_ptr[0][i]);
+        }
+
+        int local_scale_shift = 4;
+        int32_t shifted_local_scale =
+            (local_scale >> local_scale_shift) & kLocalScaleMask;
+        ScaleComputeT scale =
+            static_cast<ScaleComputeT>(shifted_local_scale) * super_scale;
+
+#pragma unroll
+        for (int i = 0; i < 16; ++i) {
+            int32_t decode_value = static_cast<int32_t>(
+                floor(zipped_value[i] * code_scale + code_zp +
+                      static_cast<ScaleComputeT>(0.5)));
+
+            int col = i * 4;
+
+#pragma unroll
+            for (int shift_bit_id = 0; shift_bit_id < 4; ++shift_bit_id) {
+                int32_t shift_bit = shift_bits[shift_bit_id];
+                int32_t shifted_value =
+                    (decode_value >> shift_bit) & kWeightMask;
+
+                ScaleComputeT value =
+                    static_cast<ScaleComputeT>(shifted_value - kBZP);
+                out_frag[col + shift_bit_id] =
+                    static_cast<ElementScale>(scale * value);
+            }
+        }
+
+        CUTLASS_TRACE_DEVICE(" kColsPerMmaPerThread = %d ",
+                             kColsPerMmaPerThread);
+        CUTLASS_TRACE_DEVICE(" MmaOperator::MmaIterations::kColumn = %d ",
+                             MmaOperator::MmaIterations::kColumn);
+
+        // // __nv_bfloat16 const* scale_ptr =
reinterpret_cast<__nv_bfloat16 + // const*>(&scale_frag); ExpandedMmaOperandB* operand_frag_ptr = + // reinterpret_cast(&operand_frag); + + // printf("threadidx.x = %d\n", threadIdx.x); + // CUTLASS_PRAGMA_UNROLL + // for (int mma_n_iter = 0; mma_n_iter < + // MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + // { + // static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + // __nv_bfloat162 scalex2 = + // __bfloat162bfloat162(scale_ptr[mma_n_iter]); + // __nv_bfloat162* operand_bf16x2_ptr = + // reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + // CUTLASS_PRAGMA_UNROLL + // for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) + // { + // operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], + // scalex2); + // } + // } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on + // older arch, scale conversion should happen before scales are stored + // to shared memory and we should use the fp16 dequantizer. This will + // avoid numerous conversion instructions in GEMM main loop. + CUTLASS_TRACE_DEVICE(" dequantize else def"); + // arch::device_breakpoint(); +#endif + } + + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_local_scale_ += offset; + pointer_code_scale_ += offset; + pointer_code_zp_ += offset; + pointer_super_scale_ += offset; + } + + private: + ElementScale const* pointer_local_scale_; + ScaleComputeT const* pointer_code_scale_; + ScaleComputeT const* pointer_code_zp_; + ElementScale const* pointer_super_scale_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h b/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h index 44ba79680e..1f5584862b 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h +++ b/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h @@ -39,18 +39,16 @@ #include "cutlass/array.h" #include "cutlass/half.h" #include "cutlass/numeric_types.h" +#include "cutlass/trace.h" -namespace cutlass -{ +namespace cutlass { // This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low // bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally // signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. // This converter will uninterleave the data and subtract the bias while converting to the result type. 
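[Reviewer's aside, illustrative only, not part of the patch.] The comment above states the baseline contract of these converters: packed values are stored unsigned with a bias of 2**(b-1), so the converter must subtract that bias while un-interleaving. Making no assumption about the interleave pattern (it differs per specialization), the bias removal alone reduces to the sketch below; the uint2b_t specializations added further down layer the warp dequantizer's shift-and-mask decode on top of this, with code_scale fixed to 1 and code_zp to 0.

#include <cstdint>

// Recover a signed b-bit value from its biased unsigned encoding u = s + 2^(b-1).
inline int decode_biased(uint32_t u, int bits) {
    return static_cast<int>(u) - (1 << (bits - 1));  // e.g. bits = 8 -> subtract 128
}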
template -struct FastInterleavedAndBiasedNumericArrayConverter -{ -}; +struct FastInterleavedAndBiasedNumericArrayConverter; template <> struct FastInterleavedAndBiasedNumericArrayConverter @@ -440,6 +438,106 @@ struct FastInterleavedAndBiasedNumericArrayConverter } }; +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static_assert(platform::is_same::value || platform::is_same::value, + "T must be fp16 or bf16"); + + using result_type = Array; + using source_type = Array; + + using ScaleComputeT = T; + + static constexpr int32_t kWeightMask = 0x3F; + static constexpr int32_t kBZP = 32; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; + uint8_t const* in_ptr = reinterpret_cast(&source); + + ScaleComputeT code_scale = static_cast(1); + ScaleComputeT code_zp = static_cast(0); + ScaleComputeT floor_offset = static_cast(0.5); + + CUTLASS_TRACE_DEVICE(" source: [%d, %d, %d, %d]", + static_cast(in_ptr[0]), static_cast(in_ptr[1]), + static_cast(in_ptr[2]), static_cast(in_ptr[3])); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + int32_t decode_value = + static_cast(floor(static_cast(in_ptr[i]) * code_scale + code_zp + floor_offset)); + + ScaleComputeT value_3 = static_cast((decode_value & kWeightMask) - kBZP); + decode_value >>= 3; + ScaleComputeT value_2 = static_cast((decode_value & kWeightMask) - kBZP); + decode_value >>= 3; + ScaleComputeT value_1 = static_cast((decode_value & kWeightMask) - kBZP); + decode_value >>= 3; + ScaleComputeT value_0 = static_cast((decode_value & kWeightMask) - kBZP); + + result[0] = static_cast(value_0); + result[1] = static_cast(value_1); + result[2] = static_cast(value_2); + result[3] = static_cast(value_3); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static_assert(platform::is_same::value || platform::is_same::value, + "T must be fp16 or bf16"); + + static constexpr int kVecWidth = 16; + static_assert(!(N % kVecWidth), "N must be multiple of 16."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / kVecWidth; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h index 356f305968..0d9aa62b3f 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h @@ -43,7 +43,6 @@ #include "cutlass/trace.h" #include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" -#include 
"cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h" #include "cutlass_extensions/tile_interleaved_layout.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -52,6 +51,15 @@ namespace cutlass { namespace gemm { namespace kernel { +template std::string GetCutlassLayoutString() { + if (std::is_same::value) { + return "cutlass::layout::RowMajor"; + } else if (std::is_same::value) { + return "cutlass::layout::ColumnMajor"; + } + return "unknown"; +} + ///////////////////////////////////////////////////////////////////////////////////////////////// // This section exists to that we can use the same kernel code for regular gemm // and dequantizing gemms. It will dispatch to the dequantizing gemm if the Mma @@ -282,6 +290,27 @@ struct MoeFCGemm { platform::is_same::value) { assert(weight_scales); } + + CUTLASS_TRACE_HOST("[Arguments] problem_count: " << problem_count << ", threadblock_count: " << threadblock_count << ", gemm_n: " << gemm_n << ", gemm_k: " << gemm_k); + CUTLASS_TRACE_HOST("[Arguments] ptr_A: " << static_cast(ptr_A)); + CUTLASS_TRACE_HOST("[Arguments] ptr_B: " << static_cast(ptr_B)); + CUTLASS_TRACE_HOST("[Arguments] ptr_C: " << static_cast(ptr_C)); + CUTLASS_TRACE_HOST("[Arguments] ptr_D: " << static_cast(ptr_D)); + CUTLASS_TRACE_HOST("[Arguments] weight_scales: " << static_cast(weight_scales)); + CUTLASS_TRACE_HOST("[Arguments] total_rows_before_expert: " << static_cast(total_rows_before_expert)); + CUTLASS_TRACE_HOST("[Arguments] local_scale: " << static_cast(local_scale)); + CUTLASS_TRACE_HOST("[Arguments] code_scale: " << static_cast(code_scale)); + CUTLASS_TRACE_HOST("[Arguments] code_zp: " << static_cast(code_zp)); + CUTLASS_TRACE_HOST("[Arguments] quant_method: " << static_cast(quant_method)); + CUTLASS_TRACE_HOST("[Arguments] LayoutA: " << GetCutlassLayoutString()); + CUTLASS_TRACE_HOST("[Arguments] LayoutB: " << GetCutlassLayoutString()); + CUTLASS_TRACE_HOST("[Arguments] LayoutC: " << GetCutlassLayoutString()); + CUTLASS_TRACE_HOST("[Arguments] Mma::IteratorA::AccessType::kElements:" << Mma::IteratorA::AccessType::kElements); + CUTLASS_TRACE_HOST("[Arguments] Mma::IteratorB::AccessType::kElements:" << Mma::IteratorB::AccessType::kElements); + CUTLASS_TRACE_HOST("[Arguments] SharedStorage Information:"); + CUTLASS_TRACE_HOST(" - ProblemVisitor::SharedStorage: " << sizeof(typename ProblemVisitor::SharedStorage) << " bytes"); + CUTLASS_TRACE_HOST(" - Mma::SharedStorage: " << sizeof(typename Mma::SharedStorage) << " bytes"); + CUTLASS_TRACE_HOST(" - Epilogue::SharedStorage: " << sizeof(typename Epilogue::SharedStorage) << " bytes"); } }; @@ -814,9 +843,6 @@ struct Wint2xMoeFCGemm : public MoeFCGemm= 1, "B must be row major/col major OR col major interleaved."); - // LayoutB should be RowMajor - using TileDequanterB = cutlass::gemm::threadblock::TileDequanter; - // // Problem visitor. // @@ -835,6 +861,13 @@ struct Wint2xMoeFCGemm : public MoeFCGemm 2) { + break; + } + + CUTLASS_TRACE_DEVICE(" problem_idx: %d, cta_idx: %d, problem_size: {%d, %d, %d}", + problem_idx, cta_idx, static_cast(problem_size.m()), static_cast(problem_size.n()), static_cast(problem_size.k())); + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); // threadblock_offset of C @@ -879,29 +912,17 @@ struct Wint2xMoeFCGemm : public MoeFCGemm::value ? 
gemm_n : gemm_k * kInterleave; - typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns; // the begin threadblock_offset of B, which holds the same column id with C cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave}; - cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns}; - MmaElementB* smem_unzip_B_ptr = nullptr; - if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) { - smem_unzip_B_ptr = shared_storage.main_loop.operand_unzip_B_ptr(); - } - QuantArguments quant_args = get_quant_args(params, problem_idx, gemm_k, gemm_n); - TileDequanterB tile_dequanter_B(smem_unzip_B_ptr, - byte_ptr_B, - ldm_B, - extent_B, - tb_offset_B, - weight_scale_ptr, - tb_offset_scale, - quant_args); - MmaElementB* ptr_B = tile_dequanter_B.GetOutPtr(); + CUTLASS_TRACE_DEVICE(" ldm_B: %d, tb_offset_B: {%d, %d}, extent_B: {%d, %d}", + static_cast(ldm_B), tb_offset_B.row(), tb_offset_B.column(), extent_B.row(), extent_B.column()); + + ElementB* ptr_B = reinterpret_cast(byte_ptr_B); // Compute position within threadblock int thread_idx = threadIdx.x; @@ -914,11 +935,11 @@ struct Wint2xMoeFCGemm : public MoeFCGemm::Op; + CUTLASS_TRACE_HOST("Stages: " << Stages); + // Finally, set up the kernel. using BaseGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< ElementType, @@ -187,6 +189,7 @@ void generic_moe_gemm_kernelLauncher(const T* A, "GroupedGEMM kernel"); } const int threadblock_count = multi_processor_count * occupancy; + CUTLASS_TRACE_HOST("kernel_occupancy: " << kernel_occupancy << ", occupancy: " << occupancy << ", threadblock_count: " << threadblock_count << ", multi_processor_count: " << multi_processor_count); typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f), ElementAccumulator(0.f)); @@ -205,7 +208,7 @@ void generic_moe_gemm_kernelLauncher(const T* A, threadblock_count, epilogue_op, reinterpret_cast(A), - reinterpret_cast(B), + reinterpret_cast(B), reinterpret_cast(weight_scales), reinterpret_cast(biases), reinterpret_cast(C), @@ -443,6 +446,7 @@ void dispatch_gemm_config(const T* A, #define dispatch_gemm_config_macro(AA, BB, CC, DD, EE, FF) \ case CutlassTileConfig:: \ CtaShape##AA##x##BB##x##CC##_WarpShape##DD##x##EE##x##FF: \ + CUTLASS_TRACE_HOST("ThreadblockShape<" << AA << "," << BB << "," << CC << ">, WarpShape<" << DD << "," << EE << "," << FF << ">"); \ dispatch_gemm_config::value) { if constexpr (WeightQuantTraits::kQuantMethod != cutlass::WintQuantMethod::kWeightOnlyInt2) { switch (gemm_config.tile_config) { - dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64); - dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64); + //dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64); + //dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64); case CutlassTileConfig::Undefined: throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined."); break; @@ -563,16 +567,16 @@ void dispatch_moe_gemm_to_cutlass(const T* A, } else { switch (gemm_config.tile_config) { dispatch_gemm_config_macro(16, 128, 64, 16, 32, 64); - dispatch_gemm_config_macro(16, 256, 64, 16, 64, 64); - dispatch_gemm_config_macro(64, 64, 64, 32, 32, 64); - dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64); - dispatch_gemm_config_macro(128, 64, 64, 64, 32, 64); - dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64); - dispatch_gemm_config_macro(128, 128, 64, 64, 64, 64); - dispatch_gemm_config_macro(128, 128, 64, 128, 32, 64); - 
dispatch_gemm_config_macro(128, 256, 64, 64, 64, 64); - dispatch_gemm_config_macro(64, 128, 64, 64, 32, 64); - dispatch_gemm_config_macro(256, 128, 64, 64, 64, 64); + //dispatch_gemm_config_macro(16, 256, 64, 16, 64, 64); + //dispatch_gemm_config_macro(64, 64, 64, 32, 32, 64); + //dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64); + //dispatch_gemm_config_macro(128, 64, 64, 64, 32, 64); + //dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64); + //dispatch_gemm_config_macro(128, 128, 64, 64, 64, 64); + //dispatch_gemm_config_macro(128, 128, 64, 128, 32, 64); + //dispatch_gemm_config_macro(128, 256, 64, 64, 64, 64); + //dispatch_gemm_config_macro(64, 128, 64, 64, 32, 64); + //dispatch_gemm_config_macro(256, 128, 64, 64, 64, 64); case CutlassTileConfig::Undefined: throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined."); break; @@ -614,7 +618,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A, cudaStream_t stream, int* occupancy = nullptr) { switch (gemm_config.tile_config) { - dispatch_gemm_config_macro(128, 128, 8, 64, 64, 8); + //dispatch_gemm_config_macro(128, 128, 8, 64, 64, 8); case CutlassTileConfig::Undefined: throw std::runtime_error( "[dispatch_moe_gemm_to_cutlass][SIMT] gemm config " @@ -711,8 +715,8 @@ void MoeGemmRunner::run_gemm( std::vector candidate_configs = get_candidate_configs(sm_, -1, is_weight_only, only_simt_configs, true); - static constexpr int warm_time = 5; - static constexpr int test_time = 10; + static constexpr int warm_time = 0; + static constexpr int test_time = 1; auto& gemmConfigManager = GemmConfigManager::Instance(); constexpr GemmDataType dtype = getGemmDataType(); constexpr GemmDataType wdtype = getGemmDataType(); @@ -731,8 +735,10 @@ void MoeGemmRunner::run_gemm( std::min(gemmConfigManager.nextPowerOfTwo(actual_total_rows), gemmConfigManager.getMaxProfileM()); bool find_one = false; - size_t num_candidate_configs_size = candidate_configs.size(); - for (size_t ii = 0; ii < num_candidate_configs_size; ++ii) { + size_t num_candidate_configs_size = 2;//candidate_configs.size(); + // for (size_t ii = 0; ii < num_candidate_configs_size; ++ii) + { + size_t ii = 1; try { for (int i = 0; i < warm_time; i++) { dispatch_to_arch(A, @@ -776,7 +782,7 @@ void MoeGemmRunner::run_gemm( check_cuda_error(cudaEventElapsedTime(&elapsed, start, stop)); check_cuda_error(cudaEventDestroy(start)); check_cuda_error(cudaEventDestroy(stop)); - //std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << " ms" << std::endl; + std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << " ms" << std::endl; if (elapsed < best_time) { best_id = ii; best_time = elapsed; @@ -797,6 +803,7 @@ void MoeGemmRunner::run_gemm( } } +#if 0 dispatch_to_arch(A, B, weight_scales, @@ -810,6 +817,7 @@ void MoeGemmRunner::run_gemm( quant_args_B, chosen_config, stream); +#endif } template diff --git a/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu b/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu index fb9d2e69fe..5c8bbd6797 100644 --- a/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu +++ b/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu @@ -20,6 +20,8 @@ #include "moe/fast_hardamard_kernel.h" #include "moe/fused_moe_helper.h" +#define _GROUP_GEMM_ONLY 1 + template void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, const paddle::Tensor& tokens_expert_prefix_sum, @@ -65,7 +67,11 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, reinterpret_cast(ffn1_weight.data()), reinterpret_cast(ffn1_super_scale ? 
ffn1_super_scale->data() : nullptr), reinterpret_cast(ffn1_bias ? ffn1_bias->data() : nullptr), +#if _GROUP_GEMM_ONLY + reinterpret_cast(ffn_out.data()), +#else reinterpret_cast(fc1_out.data()), +#endif const_cast(tokens_expert_prefix_sum.data()), total_rows_in_ll_else_minus1, actual_total_rows, @@ -76,9 +82,12 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, "none", stream); +#if _GROUP_GEMM_ONLY + // do nothing +#else paddle::Tensor act_out; if (used_in_ep_low_latency) { - act_out = GroupSwigluWithMasked(fc1_out, tokens_expert_prefix_sum); + //act_out = GroupSwigluWithMasked(fc1_out, tokens_expert_prefix_sum); } else { act_out = paddle::experimental::swiglu(fc1_out, nullptr); } @@ -96,6 +105,7 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, num_experts, ffn2_quant_args, stream); +#endif } template @@ -198,7 +208,14 @@ paddle::Tensor MoeExpertFFNWint2Func( const bool used_in_ep_low_latency) { const auto dtype = permute_input.dtype(); +#if _GROUP_GEMM_ONLY + auto place = permute_input.place(); + int64_t expanded_active_expert_rows = permute_input.dims()[0]; + int64_t inter_size = ffn1_scale.get().dims()[1]; + auto ffn_out = GetEmptyTensor({expanded_active_expert_rows, inter_size}, dtype, place); +#else auto ffn_out = paddle::empty_like(permute_input, dtype); +#endif switch (dtype) { case paddle::DataType::BFLOAT16: @@ -289,7 +306,14 @@ std::vector> MoeExpertFFNWint2InferShape( const paddle::optional>& ffn2_code_zp_shape, const bool used_in_ep_low_latency) { +#if _GROUP_GEMM_ONLY + int64_t expanded_active_expert_rows = permute_input_shape[0]; + int64_t inter_size = ffn1_scale_shape.get()[1]; + std::cout << "expanded_active_expert_rows: " << expanded_active_expert_rows << ", inter_size: " << inter_size << std::endl; + return {std::vector{expanded_active_expert_rows, inter_size}}; +#else return {permute_input_shape}; +#endif } std::vector MoeExpertFFNWint2InferDtype( @@ -356,7 +380,7 @@ std::vector MoeExpertFFNWint2InferDtype( * Note: * - Low latency mode uses specialized grouped SwiGLU implementation */ -PD_BUILD_STATIC_OP(moe_expert_ffn_wint2) +PD_BUILD_OP(moe_expert_ffn_wint2) .Inputs({"permute_input", "tokens_expert_prefix_sum", "ffn1_weight",