diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h index 40f128b7a0..8f61c6d9c4 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -133,10 +133,18 @@ template template struct LayoutDetailsB= 75>::type> { - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::RowMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; // 64 + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; // 8 + +public: + // using Layout = layout::ColumnMajor; + // static constexpr int ElementsPerAccess = 16; // at least 4-bytes + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; // 64 + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; template diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h index bc395d04db..b50d66380e 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h @@ -18,14 +18,12 @@ #include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" #include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" +#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h" #include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h" -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { //////////////////////////////////////////////////////////////////////////////// @@ -378,38 +376,23 @@ template < struct DefaultMma { - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) ? 
cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; +private: + using Mma = DefaultWint2xMma; +public: // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; + using MmaCore = typename Mma::MmaCore; // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, - AccessTypeA>; + using IteratorA = typename Mma::IteratorA; // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, - AccessTypeB>; + using IteratorB = typename Mma::IteratorB; // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; + using ThreadblockMma = typename Mma::ThreadblockMma; }; template < @@ -441,38 +424,23 @@ struct DefaultMma { - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; +private: + using Mma = DefaultWint2xMma; +public: // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; + using MmaCore = typename Mma::MmaCore; // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, - AccessTypeA>; + using IteratorA = typename Mma::IteratorA; // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, - AccessTypeB>; + using IteratorB = typename Mma::IteratorB; // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; + using ThreadblockMma = typename Mma::ThreadblockMma; }; } // namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h index 5d2c311704..300261c3f0 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -19,7 +19,7 @@ #include "cutlass/gemm/threadblock/default_mma.h" #include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" #include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" -#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h" +#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h" namespace cutlass { namespace gemm { @@ -379,38 +379,23 @@ template < struct DefaultMma { - static 
cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; +private: + using Mma = DefaultWint2xMma; +public: // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; + using MmaCore = typename Mma::MmaCore; // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, - AccessTypeA>; + using IteratorA = typename Mma::IteratorA; // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, - AccessTypeB>; + using IteratorB = typename Mma::IteratorB; // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; + using ThreadblockMma = typename Mma::ThreadblockMma; }; template < @@ -442,38 +427,23 @@ struct DefaultMma { - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; +private: + using Mma = DefaultWint2xMma; +public: // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; + using MmaCore = typename Mma::MmaCore; // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, - AccessTypeA>; + using IteratorA = typename Mma::IteratorA; // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, - AccessTypeB>; + using IteratorB = typename Mma::IteratorB; // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; + using ThreadblockMma = typename Mma::ThreadblockMma; }; } // namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h new file mode 100644 index 0000000000..4209c3029a --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h @@ -0,0 +1,245 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" +#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h" +#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultQuantParamsIterators { +private: + static constexpr int kAlignment = 128 / sizeof_bits::value; + static_assert((ThreadblockShape::kN % kAlignment) == 0, ""); + + static constexpr int kRows = + (GroupSize == -1) ? 1 : (ThreadblockShape::kK + GroupSize - 1) / GroupSize; + static constexpr int kColumns = ThreadblockShape::kN; + + using IteratorThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kColumns / kAlignment, kAlignment>; + +public: + using Iterator = cutlass::transform::threadblock::PredicatedTileIterator< + MatrixShape, ElementT, layout::RowMajor, 0, + IteratorThreadMap, kAlignment>; + using SmemIterator = Iterator; +}; + +template +struct DefaultQuantParamsIterators { +private: + static constexpr int kAlignment = 128 / sizeof_bits::value; + static_assert((ThreadblockShape::kN % kAlignment) == 0, ""); + + static constexpr int kRows = + (GroupSize == -1) ? 1 : (ThreadblockShape::kK + 2 * GroupSize - 1) / (2 * GroupSize); + static constexpr int kColumns = + (GroupSize == -1) ? 
ThreadblockShape::kN : ThreadblockShape::kN * 2; + + using IteratorThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kColumns / kAlignment, kAlignment>; + +public: + using AccessType = cutlass::Array; + using Iterator = cutlass::transform::threadblock::PredicatedTileAccessIterator< + MatrixShape, uint4b_t, layout::RowMajor, + 0, IteratorThreadMap, AccessType>; + + using SmemIterator = Iterator; +}; + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone> +struct DefaultWint2xMma; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// Operator performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +struct DefaultWint2xMma +{ +public: + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value, + "Element B must be uint2b_t"); + + static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + using ElementSuperScale = ElementA; + using ElementLocalScale = uint4b_t; + using ElementCodeScaleZp = float; + + static constexpr int kGroupSize = 64; + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, + AccessTypeA>; + +private: + static constexpr int kColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int kRowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % kColumnsInterleaved), "ThreadblockShape must be disivle by kColumnsInterleaved"); + static_assert(kRowsPerTile == MmaCore::Shape::kK, ""); + + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using WarpArrangement = typename ThreadMapB::Detail::WarpThreadArrangement; + static_assert(!(WarpArrangement::kStrided % kColumnsInterleaved), ""); + + using IteratorShapeB = MatrixShape< + MmaCore::Shape::kK * kColumnsInterleaved, MmaCore::Shape::kN / kColumnsInterleaved>; + using InterleavedThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + ThreadMapB::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + +public: + // Define iterators over tiles from the B operand + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + IteratorShapeB, ElementB, layout::ColumnMajor, 0, InterleavedThreadMapB, + AccessTypeB>; + +private: + // Define iterators over tiles from extra quant params for B operand + using IteratorSuperScale = typename DefaultQuantParamsIterators< + ThreadblockShape, ElementSuperScale, -1>::Iterator; + using SmemIteratorSuperScale = typename DefaultQuantParamsIterators< + ThreadblockShape, ElementSuperScale, -1>::SmemIterator; + + using IteratorLocalScale = typename DefaultQuantParamsIterators< + ThreadblockShape, ElementLocalScale, kGroupSize>::Iterator; + using SmemIteratorLocalScale = typename DefaultQuantParamsIterators< + ThreadblockShape, ElementLocalScale, kGroupSize>::SmemIterator; + + using IteratorCodeScaleZp = typename DefaultQuantParamsIterators< + ThreadblockShape, ElementCodeScaleZp, -1>::Iterator; + using SmemIteratorCodeScaleZp = typename DefaultQuantParamsIterators< + ThreadblockShape, ElementCodeScaleZp, -1>::Iterator; + +public: + using QuantParamsAccessor = Wint2ParamsAccessor< + ElementA, ThreadblockShape, IteratorSuperScale, SmemIteratorSuperScale, + IteratorLocalScale, SmemIteratorLocalScale, + IteratorCodeScaleZp, SmemIteratorCodeScaleZp, kStages, kGroupSize>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage< + typename MmaCore::Shape, + IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, + IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB, + ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, + kStages, QuantParamsAccessor, SharedMemoryClear>; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h index 6dd55b647a..4b7d3ac06e 100644 --- 
a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h @@ -63,8 +63,8 @@ template < typename Policy_, /// Number of stages, int Stages, - /// Used for partial specialization - typename Enable = bool> + /// Size of extra quantized params + typename QuantParamsShape> class Wint2xMmaBase { public: ///< Size of the Gemm problem - concept: gemm::GemmShape<> @@ -93,6 +93,14 @@ class Wint2xMmaBase { static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + /// Number of warp-level GEMM oeprations per load for B + static constexpr int kWarpGemmIterationsPerLoadForB = + Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; + static_assert(!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB), ""); + + static constexpr int kWarpLoadIterationsForB = + kWarpGemmIterations / kWarpGemmIterationsPerLoadForB; + /// Number of stages static int const kStages = Stages; @@ -104,8 +112,6 @@ class Wint2xMmaBase { using TensorRefB = TensorRef; - // using TensorRefZippedB = TensorRef; - static_assert(kWarpGemmIterations > 1, "The pipelined structure requires at least two warp-level " "GEMM operations."); @@ -130,20 +136,11 @@ class Wint2xMmaBase { Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; /// Shape of the B matrix operand in shared memory - using ShapeB = MatrixShape; - // w uint8; local_scale uint8; - constexpr static int kZippedRowsPerStages = - Shape::kK / 4 + (Shape::kK + 127) / 128; - - // code_scale float; code_zp float; super_scale ElementB - constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) + - sizeof_bits::value / 8; - - using ZippedShapeB = MatrixShape; - - using NopaddingShapeB = MatrixShape; + /// Shape of all quant params in shared memory + using QuantParamsShapeB = QuantParamsShape; public: // @@ -156,12 +153,8 @@ class Wint2xMmaBase { /// Buffer for B operand AlignedBuffer operand_B; - /// Buffer for quanted B operand - AlignedBuffer operand_zipped_B; - - /// Buffer for unzip B operand - AlignedBuffer - operand_unzip_B; + /// Buffer for extra quant params of B operand + AlignedBuffer operand_quant_params_B; public: // @@ -191,14 +184,6 @@ class Wint2xMmaBase { TensorRefB operand_B_ref() { return TensorRefB{operand_B.data(), LayoutB()}; } - - CUTLASS_HOST_DEVICE - uint8_t *operand_zipped_B_ptr() { return operand_zipped_B.data(); } - - CUTLASS_HOST_DEVICE - typename Operator::ElementB *operand_unzip_B_ptr() { - return operand_unzip_B.data(); - } }; protected: diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h index 38fdcf9fec..b3a9546014 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h @@ -45,7 +45,8 @@ #include "cutlass_extensions/arch/memory_copy_sm80.h" #include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h" -#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h" +#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -86,15 +87,15 @@ template < typename Policy_, /// Number of stages, int Stages, + /// Accessor for extra quantized params + typename 
QuantParamsAccessor_, /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, - /// Used for partial specialization - typename Enable = bool> -class Wint2xMmaMultistage : - public Wint2xMmaBase { + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone> +class Wint2xMmaMultistage : + public Wint2xMmaBase { public: ///< Base class - using Base = Wint2xMmaBase; + using Base = Wint2xMmaBase; ///< Size of the Gemm problem - concept: gemm::GemmShape<> using Shape = Shape_; ///< Iterates over tiles of A operand in global memory @@ -107,8 +108,11 @@ class Wint2xMmaMultistage : using LayoutC = LayoutC_; ///< Policy describing tuning details using Policy = Policy_; + /// Accessor for extra quantized params + using QuantParamsAccessor = QuantParamsAccessor_; + using QuantArguments = typename QuantParamsAccessor::Arguments; - using ZippedShapeB = typename Base::SharedStorage::ZippedShapeB; + static constexpr int kInterleave = IteratorB::Shape::kRow / Shape::kK; using SmemIteratorA = SmemIteratorA_; using SmemIteratorB = SmemIteratorB_; @@ -129,6 +133,18 @@ class Wint2xMmaMultistage : /// Minimum architecture is Sm80 to support cp.async using ArchTag = arch::Sm80; + //using LayoutScale = typename QuantParamsAccessor::IteratorSuperScale::Layout; + using LayoutScale = layout::RowMajor; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + using WarpDequantizer = + warp::MmaTensorOpWin2xDequantizer; + static_assert(sizeof(WarpDequantizer) > 0, "WarpDequantizer template instantiation failed"); + /// Complex transform on A operand static ComplexTransform const kTransformA = Operator::kTransformA; @@ -174,18 +190,37 @@ class Wint2xMmaMultistage : using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + using FragmentSuperScale = typename WarpDequantizer::FragmentSuperScale; + using FragmentCodeScaleZp = typename WarpDequantizer::FragmentCodeScaleZp; + using FragmentLocalScale = typename WarpDequantizer::FragmentLocalScale; + + /// channel-wise quant params + FragmentCodeScaleZp warp_frag_code_scale_; + FragmentCodeScaleZp warp_frag_code_zp_; + FragmentSuperScale warp_frag_super_scale_; + + /// group-wise quant params + FragmentLocalScale warp_frag_local_scale_; + /// Temporary accumulator to facilitate staged-accumulation FragmentC tmp_accum_; /// Pair of A fragments used to overlap shared memory loads and math instructions - WarpLoadedFragmentA warp_loaded_frag_A_[2]; - WarpTransformedFragmentA warp_transformed_frag_A_[2]; + WarpTransformedFragmentA warp_frag_A_[2]; /// Pair of B fragments used to overlap shared memory loads and math instructions WarpLoadedFragmentB warp_loaded_frag_B_[2]; - WarpTransformedFragmentB warp_transformed_frag_B_[2]; + WarpTransformedFragmentB warp_frag_B_; }; + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool IsTileInterleaveLayout = + layout::IsColumnMajorTileInterleave::value; + static_assert(!IsTileInterleaveLayout || (IsTileInterleaveLayout && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); private: @@ -202,16 +237,23 @@ class Wint2xMmaMultistage : /// Iterator to write threadblock-scoped tile of B operand to shared memory SmemIteratorB smem_iterator_B_; + /// Accessor for extra quant params for B + 
QuantParamsAccessor quant_params_accessor_B_; + + // Wint2 unzip operator + WarpDequantizer warp_dequantizer_; + /// Shared memory write stage index int smem_write_stage_idx_; /// Shared memory read stage index int smem_read_stage_idx_; - uint8_t* column_wise_smem_ptr_B_; + ElementA* smem_ptr_A_; + ElementA* ptr_A_; - uint8_t* smem_zipped_ptr_B_; - int smem_zipped_bytes_per_stage_B_; + uint8_t* smem_ptr_B_; + uint8_t* ptr_B_; public: @@ -226,10 +268,15 @@ class Wint2xMmaMultistage : int warp_idx, ///< ID of each thread within a warp int lane_idx - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), + ) : Base(shared_storage, thread_idx, warp_idx, lane_idx), smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + quant_params_accessor_B_(shared_storage.operand_quant_params_B.data(), thread_idx, warp_idx, lane_idx), + warp_dequantizer_(quant_params_accessor_B_.super_scale_ref(), + quant_params_accessor_B_.local_scale_ref(), + quant_params_accessor_B_.code_scale_ref(), + quant_params_accessor_B_.code_zp_ref(), + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx), smem_write_stage_idx_(0), smem_read_stage_idx_(0) { @@ -245,16 +292,35 @@ class Wint2xMmaMultistage : int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + //CUTLASS_TRACE_DEVICE(" Shape: {%d, %d, %d}, IteratorB::Shape: {%d, %d}, kInterleave: %d", + // Shape::kM, Shape::kN, Shape::kK, IteratorB::Shape::kRow, IteratorB::Shape::kColumn, kInterleave); + //CUTLASS_TRACE_DEVICE(" kPartitionsK=%d, kWarpGemmIterations=%d, WarpCount={%d, %d}, warp_idx_m=%d, warp_idx_n=%d, warp_idx_k=%d", + // Policy::kPartitionsK, Base::kWarpGemmIterations, + // Base::WarpCount::kM, Base::WarpCount::kN, warp_idx_m, warp_idx_n, warp_idx_k); + // Add per-warp offsets in units of warp-level tiles this->warp_tile_iterator_A_.add_tile_offset( {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); this->warp_tile_iterator_B_.add_tile_offset( {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr(); - - smem_zipped_ptr_B_ = column_wise_smem_ptr_B_ + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn; - smem_zipped_bytes_per_stage_B_ = Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn; + //CUTLASS_TRACE_DEVICE(" Policy::SmemPaddingA: {%d, %d}; Policy::SmemPaddingB: {%d, %d}", + // Policy::SmemPaddingA::kRow, Policy::SmemPaddingA::kColumn, Policy::SmemPaddingB::kRow, Policy::SmemPaddingB::kColumn); + //CUTLASS_TRACE_DEVICE(" operand_A_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageA=%d, kAccessesPerVectorA=%d", + // shared_storage.operand_A.data(), + // static_cast(Base::SharedStorage::ShapeA::kRow), static_cast(Base::SharedStorage::ShapeA::kColumn), + // static_cast(sizeof(shared_storage.operand_A)), + // static_cast(IteratorA::ThreadMap::kElementsPerAccess), static_cast(sizeof(typename IteratorA::AccessType)), + // static_cast(Detail::AsyncCopyIterationsPerStageA), static_cast(IteratorA::kAccessesPerVector)); + //CUTLASS_TRACE_DEVICE(" operand_B_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageB=%d, kAccessesPerVectorA=%d", + // shared_storage.operand_B.data(), + // static_cast(Base::SharedStorage::ShapeB::kRow), static_cast(Base::SharedStorage::ShapeB::kColumn), + // 
static_cast(sizeof(shared_storage.operand_B)), + // static_cast(IteratorB::ThreadMap::kElementsPerAccess), static_cast(sizeof(typename IteratorB::AccessType)), + // static_cast(Detail::AsyncCopyIterationsPerStageB), static_cast(IteratorB::kAccessesPerVector)); + + smem_ptr_A_ = reinterpret_cast(shared_storage.operand_A.data()); + smem_ptr_B_ = reinterpret_cast(shared_storage.operand_B.data()); } /// Advance shared memory read-iterators to the next stage @@ -266,28 +332,22 @@ class Wint2xMmaMultistage : if (smem_read_stage_idx_ == Base::kStages) { // Wrap back around to the 'start' of the circular buffer in shared memory this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - // this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); smem_read_stage_idx_ = 0; } - this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); } /// Advance global memory read-iterators and shared memory write-iterators to the stage - template CUTLASS_DEVICE - void advance_smem_write_stage( - IteratorA &iterator_A, - IteratorB &iterator_B, - TileDequanterB &tile_dequanter_B) + void advance_smem_write_stage(IteratorA &iterator_A, IteratorB &iterator_B) { // Advance global iterators iterator_A.add_tile_offset({0, 1}); - //iterator_B.add_tile_offset({1, 0}); - tile_dequanter_B.AddTileOffset({1, 0}); + iterator_B.add_tile_offset({1, 0}); // Advance shared iterators smem_iterator_A_.add_tile_offset({0, 1}); - //smem_iterator_B_.add_tile_offset({1, 0}); + smem_iterator_B_.add_tile_offset({1, 0}); // Increment shared memory write stage index ++smem_write_stage_idx_; @@ -295,7 +355,7 @@ class Wint2xMmaMultistage : if (smem_write_stage_idx_ == Base::kStages) { // Wrap back around to the 'start' of the circular buffer in shared memory smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - //smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); smem_write_stage_idx_ = 0; } } @@ -338,7 +398,6 @@ class Wint2xMmaMultistage : } } - template CUTLASS_DEVICE void copy_tiles_and_advance_B(IteratorB &iterator_B, int group_start_B = 0) { iterator_B.set_iteration_index(group_start_B * @@ -361,11 +420,20 @@ class Wint2xMmaMultistage : for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { auto gmem_ptr = iterator_B.get(); +#if 0 + if (group_start_B == 0 && j == 0 && v == 0) { + CUTLASS_TRACE_DEVICE(" dst_ptr=%p, iterator_B.get()=%p, kAccessesPerGroupB=%d, kAccessesPerVector=%d, sizeof(AccessType)=%d", + reinterpret_cast(dst_ptr), reinterpret_cast(gmem_ptr), + static_cast(Detail::kAccessesPerGroupB), static_cast(IteratorB::kAccessesPerVector), + static_cast(sizeof(typename IteratorB::Element))); + } +#endif + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::copy_zfill( + cutlass::arch::cp_async_zfill( dst_ptr + v, gmem_ptr, iterator_B.valid()); } else { - cutlass::arch::copy( + cutlass::arch::cp_async( dst_ptr + v, gmem_ptr, iterator_B.valid()); } @@ -379,7 +447,7 @@ class Wint2xMmaMultistage : } CUTLASS_DEVICE - void copy_tiles_and_advance_per_stage_A(IteratorA &iterator_A) { + void copy_tiles_and_advance_per_stage_A(IteratorA &iterator_A, int stage) { iterator_A.set_iteration_index(0); this->smem_iterator_A_.set_iteration_index(0); @@ -401,6 +469,32 @@ class 
Wint2xMmaMultistage : int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); +#if 0 + if (v == 0) { + int gmem_offset = reinterpret_cast(gmem_ptr) - reinterpret_cast(ptr_A_); + int gmem_k = 8192; + int gmem_m = 16; + int gmem_row = gmem_offset / gmem_k; + int gmem_col = gmem_offset % gmem_k; + + int smem_offset = reinterpret_cast(dst_ptr) - reinterpret_cast(smem_ptr_A_); + int smem_k = Shape::kK; + int smem_m = Shape::kM; + int smem_row = smem_offset / smem_k; + int smem_col = smem_offset % smem_k; + + ElementA* gmem_element_A_ptr = reinterpret_cast(gmem_ptr); + CUTLASS_TRACE_DEVICE(" [stage=%d] gmem_ptr=%p, smem_ptr=%p, bytes=%d; gmem: %dx%d, {%d, %d}, [%f, %f, %f, %f, %f, %f, %f, %f]; smem: {%d, %d};", + stage, reinterpret_cast(gmem_ptr), reinterpret_cast(dst_ptr), kSrcBytes, + gmem_m, gmem_k, gmem_row, gmem_col, + static_cast(gmem_element_A_ptr[0]), static_cast(gmem_element_A_ptr[1]), + static_cast(gmem_element_A_ptr[2]), static_cast(gmem_element_A_ptr[3]), + static_cast(gmem_element_A_ptr[4]), static_cast(gmem_element_A_ptr[5]), + static_cast(gmem_element_A_ptr[6]), static_cast(gmem_element_A_ptr[7]), + smem_row, smem_col); + } +#endif + cutlass::arch::cp_async_zfill( dst_ptr + v, iterator_A.get(), iterator_A.valid()); @@ -411,9 +505,9 @@ class Wint2xMmaMultistage : } } - template + template CUTLASS_DEVICE - void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B) { + void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B, int stage) { iterator_B.set_iteration_index(0); this->smem_iterator_B_.set_iteration_index(0); @@ -433,15 +527,42 @@ class Wint2xMmaMultistage : IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; +#if 0 + if (v == 0) { + int gmem_offset = reinterpret_cast(gmem_ptr) - reinterpret_cast(ptr_B_); + int gmem_k = 8192 * kInterleave / 4; + int gmem_n = 1792 / kInterleave; + int gmem_row = gmem_offset / gmem_k; + int gmem_col = gmem_offset % gmem_k; + + int smem_offset = reinterpret_cast(dst_ptr) - reinterpret_cast(smem_ptr_B_); + int smem_k = Shape::kK * kInterleave / 4; + int smem_n = Shape::kN / kInterleave; + int smem_row = smem_offset / smem_k; + int smem_col = smem_offset % smem_k; + + uint8_t* gmem_uint8_ptr = reinterpret_cast(gmem_ptr); + + CUTLASS_TRACE_DEVICE(" [stage=%d] gmem_ptr=%p, smem_ptr=%p, bytes=%d; gmem: %dx%d, {%d, %d}, [%d, %d, %d, %d, %d, %d, %d, %d]; smem: {%d, %d};", + stage, reinterpret_cast(gmem_ptr), reinterpret_cast(dst_ptr), kSrcBytes, + gmem_n, gmem_k, gmem_row, gmem_col, + static_cast(gmem_uint8_ptr[0]), static_cast(gmem_uint8_ptr[1]), + static_cast(gmem_uint8_ptr[2]), static_cast(gmem_uint8_ptr[3]), + static_cast(gmem_uint8_ptr[4]), static_cast(gmem_uint8_ptr[5]), + static_cast(gmem_uint8_ptr[6]), static_cast(gmem_uint8_ptr[7]), + smem_row, smem_col); + } +#endif + if (InitStage) { - cutlass::arch::copy_zfill( + cutlass::arch::cp_async_zfill( dst_ptr + v, iterator_B.get(), iterator_B.valid()); } else { if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::copy_zfill( + cutlass::arch::cp_async_zfill( dst_ptr + v, gmem_ptr, iterator_B.valid()); } else { - cutlass::arch::copy( + cutlass::arch::cp_async( dst_ptr + v, gmem_ptr, iterator_B.valid()); } } @@ -451,17 +572,15 @@ class Wint2xMmaMultistage : ++this->smem_iterator_B_; } - __syncthreads(); } /// GEMM prologue. 
Bootstrap the global->shared memory pipeline by fetching /// the global fragments needed by the first kStages-1 threadblock mainloop iterations - template CUTLASS_DEVICE void prologue( IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - TileDequanterB &tile_dequanter_B, + QuantArguments &mma_quant_args, ///< iterators for extra quant params for B int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining { // Issue several complete stages @@ -473,14 +592,21 @@ class Wint2xMmaMultistage : iterator_B.clear_mask(gemm_k_iterations == 0); // Async copy zipped B to shared memory. - copy_tiles_and_advance_per_stage_A(iterator_A); + copy_tiles_and_advance_per_stage_A(iterator_A, 0); // Async copy zipped B to shared memory. - tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, - column_wise_smem_ptr_B_, stage); + copy_tiles_and_advance_per_stage_B(iterator_B, stage); + + // Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale. + if (stage == 0) { + quant_params_accessor_B_.copy_tiles_and_advance_per_stage(mma_quant_args, stage); + } else { + quant_params_accessor_B_.copy_tiles_and_advance_per_stage(mma_quant_args, stage); + } // Move to the next write stage - advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B); + advance_smem_write_stage(iterator_A, iterator_B); + quant_params_accessor_B_.advance_smem_write_stage(mma_quant_args); // Defines the boundary of a stage of cp.async. cutlass::arch::cp_async_fence(); @@ -542,84 +668,128 @@ class Wint2xMmaMultistage : } /// Perform a threadblock mainloop iteration of matrix multiply-accumulate - template CUTLASS_DEVICE void mac_loop_iter( PipeState &pipe_state, ///< [in|out] loop-carried pipeline state FragmentC &accum, ///< [in|out] destination accumulator tile IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - TileDequanterB &tile_dequanter_B, ///< [in|out] tile dequantizer for B operand + QuantArguments &mma_quant_args, ///< iterators for extra quant params for B int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining int stage) { + // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration CUTLASS_PRAGMA_UNROLL for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { - // CUTLASS_TRACE_DEVICE(" [MMa] stage=%d, warp_mma_k=%d", stage, warp_mma_k); // Load the next warp-tile's A fragment from shared memory this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[(warp_mma_k + 1) % 2]); ++this->warp_tile_iterator_A_; - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - // Unpack and dequant the first stage of B. 
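Note on the reworked mainloop below: because B is now stored in the ColumnMajorTileInterleave layout (per mixed_gemm_B_layout.h above, for uint2b_t one 128-byte cache line holds 128*8/2 = 512 elements, so with ThreadblockK = 64 that is 8 interleaved columns), a single warp-level load of B appears to cover several warp k-iterations, and the dequantizer only runs when a fresh fragment has just been loaded. A minimal host-side sketch of that indexing, using stand-in values for the instruction shapes and iteration count (they are illustrative, not the kernel's real template parameters):

#include <cstdio>

// Stand-in constants, chosen only for illustration.
constexpr int kInstructionK = 16;            // stands in for Operator::InstructionShape::kK
constexpr int kIteratorBInstructionRow = 64; // stands in for Operator::IteratorB::InstructionShape::kRow
constexpr int kWarpGemmIterations = 8;       // stands in for Base::kWarpGemmIterations

// Same formula as wint2x_mma_base.h in this diff.
constexpr int kWarpGemmIterationsPerLoadForB =
    kIteratorBInstructionRow / kInstructionK;  // 4 with these stand-ins
static_assert(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB == 0,
              "kWarpGemmIterations must be a multiple of the per-load count");

int main() {
  for (int warp_mma_k = 0; warp_mma_k < kWarpGemmIterations; ++warp_mma_k) {
    int warp_k_compute_offset_B = warp_mma_k % kWarpGemmIterationsPerLoadForB;
    int warp_mma_k_for_B = warp_mma_k / kWarpGemmIterationsPerLoadForB;
    // offset == last -> load the next warp tile of B plus the next group-wise local_scale
    // offset == 0    -> dequantize the freshly loaded fragment before issuing MMAs
    std::printf("warp_mma_k=%d load_B=%d dequant=%d buffer=%d\n", warp_mma_k,
                warp_k_compute_offset_B == kWarpGemmIterationsPerLoadForB - 1,
                warp_k_compute_offset_B == 0, warp_mma_k_for_B % 2);
  }
  return 0;
}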
- int unpack_stage = stage - Base::kStages + 2; - tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, - column_wise_smem_ptr_B_, unpack_stage); + int warp_k_compute_offset_B = warp_mma_k % Base::kWarpGemmIterationsPerLoadForB; + int warp_mma_k_for_B = warp_mma_k / Base::kWarpGemmIterationsPerLoadForB; + + if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) { + // Load the next warp-tile's B fragment from shared memory + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k_for_B + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k_for_B + 1) % 2]); + ++this->warp_tile_iterator_B_; - // Copy dequatized data to shared memory used by mma core. - copy_tiles_and_advance_per_stage_B(iterator_B); + warp_dequantizer_.load(pipe_state.warp_frag_local_scale_); } - // Load the next warp-tile's B fragment from shared memory - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_B_; + // Execute the current warp-tile of MMA operations - // Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary - if (warp_mma_k > 0) { - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], - pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]); + // CUTLASS_TRACE_DEVICE("ElementA %d", PipeState::WarpTransformedFragmentA::kElements); + // CUTLASS_TRACE_DEVICE("ElementB %d", PipeState::WarpTransformedFragmentB::kElements); + // CUTLASS_TRACE_DEVICE("kStagedAccumulation %d", Detail::kStagedAccumulation); + + // uint8_t* reg_uint8_ptr = reinterpret_cast(pipe_state.warp_loaded_frag_B_[warp_mma_k % 2].data()); + // CUTLASS_TRACE_DEVICE(" reg_uint8_ptr=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d], %d bytes", + // static_cast(reg_uint8_ptr[0]), static_cast(reg_uint8_ptr[1]), + // static_cast(reg_uint8_ptr[2]), static_cast(reg_uint8_ptr[3]), + // static_cast(reg_uint8_ptr[4]), static_cast(reg_uint8_ptr[5]), + // static_cast(reg_uint8_ptr[6]), static_cast(reg_uint8_ptr[7]), + // static_cast(reg_uint8_ptr[8]), static_cast(reg_uint8_ptr[9]), + // static_cast(reg_uint8_ptr[10]), static_cast(reg_uint8_ptr[11]), + // static_cast(reg_uint8_ptr[12]), static_cast(reg_uint8_ptr[13]), + // static_cast(reg_uint8_ptr[14]), static_cast(reg_uint8_ptr[15]), + // sizeof_bits::value / 8); + + if (warp_k_compute_offset_B == 0) { + warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_, + pipe_state.warp_frag_code_scale_, + pipe_state.warp_frag_code_zp_, + pipe_state.warp_frag_super_scale_, + pipe_state.warp_loaded_frag_B_[warp_mma_k_for_B % 2], + pipe_state.warp_frag_B_, + (stage - Base::kStages + 2) * Shape::kK); } - // Execute the current warp-tile of MMA operations if (Detail::kStagedAccumulation) { + //CUTLASS_TRACE_DEVICE(" [MMa-kStagedAccumulation][stage=%d] warp_mma_k=%d, warp_k_compute_offset_B=%d", stage, warp_mma_k, warp_k_compute_offset_B); warp_mma_( pipe_state.tmp_accum_, - pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - pipe_state.tmp_accum_ + pipe_state.warp_frag_A_[warp_mma_k % 2], + pipe_state.warp_frag_B_, + // unpacked_frag_B, + pipe_state.tmp_accum_, + 
warp_k_compute_offset_B ); - if (warp_mma_k == 0) { - plus plus_accum; - accum = plus_accum(accum, pipe_state.tmp_accum_); - pipe_state.tmp_accum_.clear(); - } } else { + //CUTLASS_TRACE_DEVICE(" [MMa][stage=%d] warp_mma_k=%d, warp_k_compute_offset_B=%d", stage, warp_mma_k, warp_k_compute_offset_B); warp_mma_( accum, - pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - accum + pipe_state.warp_frag_A_[warp_mma_k % 2], + pipe_state.warp_frag_B_, + // unpacked_frag_B, + accum, + warp_k_compute_offset_B ); +#if 0 + CUTLASS_TRACE_DEVICE(" pipe_state.warp_frag_B_=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(pipe_state.warp_frag_B_[0]), static_cast(pipe_state.warp_frag_B_[1]), + static_cast(pipe_state.warp_frag_B_[2]), static_cast(pipe_state.warp_frag_B_[3]), + static_cast(pipe_state.warp_frag_B_[4]), static_cast(pipe_state.warp_frag_B_[5]), + static_cast(pipe_state.warp_frag_B_[6]), static_cast(pipe_state.warp_frag_B_[7]), + static_cast(pipe_state.warp_frag_B_[8]), static_cast(pipe_state.warp_frag_B_[9]), + static_cast(pipe_state.warp_frag_B_[10]), static_cast(pipe_state.warp_frag_B_[11]), + static_cast(pipe_state.warp_frag_B_[12]), static_cast(pipe_state.warp_frag_B_[13]), + static_cast(pipe_state.warp_frag_B_[14]), static_cast(pipe_state.warp_frag_B_[15])); + + if (FragmentC::kElements == 16) { + CUTLASS_TRACE_DEVICE(" tile_C[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(accum[0]), static_cast(accum[1]), + static_cast(accum[2]), static_cast(accum[3]), + static_cast(accum[4]), static_cast(accum[5]), + static_cast(accum[6]), static_cast(accum[7]), + static_cast(accum[8]), static_cast(accum[9]), + static_cast(accum[10]), static_cast(accum[11]), + static_cast(accum[12]), static_cast(accum[13]), + static_cast(accum[14]), static_cast(accum[15])); + } + + // CUTLASS_TRACE_DEVICE_TID(" now1 warp_loaded_frag_A_[0:7]=[%f, %f, %f, %f, %f, %f, %f, %f]", + // static_cast(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][0]), static_cast(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][1]), + // static_cast(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][2]), static_cast(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][3]), + // static_cast(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][4]), static_cast(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][5]), + // static_cast(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][6]), static_cast(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][7])); +#endif } // Except for the last warp-tile, all warp-tiles issue their share of // global->shared fragment copies if (warp_mma_k < Base::kWarpGemmIterations - 1) { int group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + int group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); - + copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); if (warp_mma_k == 0) { - tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, - column_wise_smem_ptr_B_, stage); + quant_params_accessor_B_.copy_tiles_and_advance_per_stage(mma_quant_args, stage); } } @@ -629,8 +799,10 @@ class Wint2xMmaMultistage : if (warp_mma_k + 2 == Base::kWarpGemmIterations) { // Performs the last warp-tile's share of global->shared fragment copies int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + int 
group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); + copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); // Inserts a memory fence between stages of cp.async instructions. cutlass::arch::cp_async_fence(); @@ -639,43 +811,44 @@ class Wint2xMmaMultistage : gmem_wait(); // Move to the next global fetch stage - advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B); + advance_smem_write_stage(iterator_A, iterator_B); + quant_params_accessor_B_.advance_smem_write_stage(mma_quant_args); + advance_smem_read_stage(); + int byte_offset = quant_params_accessor_B_.advance_smem_read_stage(); + warp_dequantizer_.add_pointer_offset(byte_offset); // Disable global fetching when done with global fetch iterations --gemm_k_iterations; iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1)); - } - - // The last warp-tile also converts the shared memory fragments used by - // the first warp-tile of the next iteration, if necessary (so we can - // immediately start issuing MMA instructions at the top of the loop ) - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2], - pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], - pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2], - pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); + iterator_B.clear_mask(gemm_k_iterations == 0); + quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0); } } } /// Perform the specified number of threadblock mainloop iterations of matrix /// multiply-accumulate. Assumes prologue has been initiated. - template CUTLASS_DEVICE void gemm_iters( int gemm_k_iterations, ///< number of threadblock mainloop iterations FragmentC &accum, ///< [in|out] accumulator tile IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, - TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + QuantArguments &mma_quant_args) { - PipeState pipe_state; +#if 0 + CUTLASS_TRACE_DEVICE(" [PipeState] WarpLoadedFragmentA::kElements=%d, %d bytes", + PipeState::WarpLoadedFragmentA::kElements, static_cast(sizeof_bits::value / 8)); + CUTLASS_TRACE_DEVICE(" [PipeState] WarpLoadedFragmentB::kElements=%d, %d bytes", + PipeState::WarpLoadedFragmentB::kElements, static_cast(sizeof_bits::value / 8)); + CUTLASS_TRACE_DEVICE(" [PipeState] WarpTransformedFragmentA::kElements=%d, %d bytes", + PipeState::WarpTransformedFragmentA::kElements, static_cast(sizeof_bits::value / 8)); + CUTLASS_TRACE_DEVICE(" [PipeState] WarpTransformedFragmentB::kElements=%d, %d bytes", + PipeState::WarpTransformedFragmentB::kElements, static_cast(sizeof_bits::value / 8)); +#endif - // Unpack and dequant the first stage of B. 
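Context for gemm_iters() below: the channel-wise fragments (code_scale, code_zp, super_scale) are loaded once up front through warp_dequantizer_.load(...), while the group-wise local_scale is reloaded per stage, and the mainloop keeps running past gemm_k_iterations == 0 so the cp.async pipeline drains. A rough host-side sketch of that stage bookkeeping, assuming illustrative values for kStages and the k-tile count and approximating where the decrement happens:

#include <cstdio>

int main() {
  constexpr int kStages = 3;   // stand-in for Base::kStages
  int gemm_k_iterations = 6;   // stand-in for the threadblock's k-tile count

  // prologue(): issue kStages - 1 complete global->shared stages before any MMA.
  for (int stage = 0; stage < kStages - 1; ++stage, --gemm_k_iterations) {
    std::printf("prologue stage %d (fetch masked: %d)\n", stage, gemm_k_iterations == 0);
  }

  // gemm_iters(): iterate while gemm_k_iterations > -(kStages - 1), i.e. until the
  // pipeline is fully drained; global fetches are masked off once the count hits 0.
  int stage = kStages - 1;
  for (; gemm_k_iterations > (-kStages + 1); --gemm_k_iterations, ++stage) {
    bool fetch_enabled = gemm_k_iterations > 0;
    std::printf("mainloop stage %d, gemm_k_iterations=%d, fetch=%d\n",
                stage, gemm_k_iterations, fetch_enabled);
  }
  return 0;
}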
- tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0); + PipeState pipe_state; // Disable global fetching if done with global fetch iterations iterator_A.clear_mask(gemm_k_iterations == 0); @@ -683,26 +856,120 @@ class Wint2xMmaMultistage : // Load first warp-tile's A fragment from shared memory this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]); + this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[0]); ++this->warp_tile_iterator_A_; - // Copy dequatized data to shared memory used by mma core. - copy_tiles_and_advance_per_stage_B(iterator_B); - // Load first warp-tile's B fragment from shared memory this->warp_tile_iterator_B_.set_kgroup_index(0); this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]); ++this->warp_tile_iterator_B_; - // Transform, if necessary, the first warp-tile's shared memory fragments - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[0], - pipe_state.warp_transformed_frag_B_[0], - pipe_state.warp_loaded_frag_A_[0], - pipe_state.warp_loaded_frag_B_[0]); +#if 0 + if (PipeState::WarpLoadedFragmentA::kElements == 8) { + ElementA* warp_frag_A_ptr = reinterpret_cast(pipe_state.warp_frag_A_[0].data()); + CUTLASS_TRACE_DEVICE(" warp_frag_A_=[%f, %f, %f, %f, %f, %f, %f, %f], %d bytes", + static_cast(warp_frag_A_ptr[0]), static_cast(warp_frag_A_ptr[1]), + static_cast(warp_frag_A_ptr[2]), static_cast(warp_frag_A_ptr[3]), + static_cast(warp_frag_A_ptr[4]), static_cast(warp_frag_A_ptr[5]), + static_cast(warp_frag_A_ptr[6]), static_cast(warp_frag_A_ptr[7]), + sizeof_bits::value / 8); + } +#endif +#if 0 + if (PipeState::WarpLoadedFragmentB::kElements == 64) { + uint8_t* reg_uint8_ptr = reinterpret_cast(pipe_state.warp_loaded_frag_B_.data()); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d], %d bytes", + static_cast(reg_uint8_ptr[0]), static_cast(reg_uint8_ptr[1]), + static_cast(reg_uint8_ptr[2]), static_cast(reg_uint8_ptr[3]), + static_cast(reg_uint8_ptr[4]), static_cast(reg_uint8_ptr[5]), + static_cast(reg_uint8_ptr[6]), static_cast(reg_uint8_ptr[7]), + static_cast(reg_uint8_ptr[8]), static_cast(reg_uint8_ptr[9]), + static_cast(reg_uint8_ptr[10]), static_cast(reg_uint8_ptr[11]), + static_cast(reg_uint8_ptr[12]), static_cast(reg_uint8_ptr[13]), + static_cast(reg_uint8_ptr[14]), static_cast(reg_uint8_ptr[15]), + sizeof_bits::value / 8); + } +#endif + + warp_dequantizer_.load(pipe_state.warp_frag_code_scale_, + pipe_state.warp_frag_code_zp_, + pipe_state.warp_frag_super_scale_); + + warp_dequantizer_.load(pipe_state.warp_frag_local_scale_); + +#if 0 + if (PipeState::FragmentLocalScale::kElements == 4) { + CUTLASS_TRACE_DEVICE(" FragmentLocalScale::kElements=%d, local_scale_frag[0:3]=[%d, %d, %d, %d], sizeof(FragmentLocalScale)=%d", + PipeState::FragmentLocalScale::kElements, + static_cast(pipe_state.warp_frag_local_scale_[0]), static_cast(pipe_state.warp_frag_local_scale_[1]), + static_cast(pipe_state.warp_frag_local_scale_[2]), static_cast(pipe_state.warp_frag_local_scale_[3]), + static_cast(sizeof(PipeState::FragmentLocalScale))); + } +#endif + +#if 0 + if (TransformBAfterLDS::result_type::kElements == 64) { + CUTLASS_TRACE_DEVICE(" TransformBAfterLDS::result_type::kElements: 64, %d bytes", sizeof_bits::value / 8); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(unpacked_frag_B[0]), 
static_cast(unpacked_frag_B[1]), + static_cast(unpacked_frag_B[2]), static_cast(unpacked_frag_B[3]), + static_cast(unpacked_frag_B[4]), static_cast(unpacked_frag_B[5]), + static_cast(unpacked_frag_B[6]), static_cast(unpacked_frag_B[7]), + static_cast(unpacked_frag_B[8]), static_cast(unpacked_frag_B[9]), + static_cast(unpacked_frag_B[10]), static_cast(unpacked_frag_B[11]), + static_cast(unpacked_frag_B[12]), static_cast(unpacked_frag_B[13]), + static_cast(unpacked_frag_B[14]), static_cast(unpacked_frag_B[15])); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[16:31]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(unpacked_frag_B[16]), static_cast(unpacked_frag_B[17]), + static_cast(unpacked_frag_B[18]), static_cast(unpacked_frag_B[19]), + static_cast(unpacked_frag_B[20]), static_cast(unpacked_frag_B[21]), + static_cast(unpacked_frag_B[22]), static_cast(unpacked_frag_B[23]), + static_cast(unpacked_frag_B[24]), static_cast(unpacked_frag_B[25]), + static_cast(unpacked_frag_B[26]), static_cast(unpacked_frag_B[27]), + static_cast(unpacked_frag_B[28]), static_cast(unpacked_frag_B[29]), + static_cast(unpacked_frag_B[30]), static_cast(unpacked_frag_B[31])); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[32:47]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(unpacked_frag_B[32]), static_cast(unpacked_frag_B[33]), + static_cast(unpacked_frag_B[34]), static_cast(unpacked_frag_B[35]), + static_cast(unpacked_frag_B[36]), static_cast(unpacked_frag_B[37]), + static_cast(unpacked_frag_B[38]), static_cast(unpacked_frag_B[39]), + static_cast(unpacked_frag_B[40]), static_cast(unpacked_frag_B[41]), + static_cast(unpacked_frag_B[42]), static_cast(unpacked_frag_B[43]), + static_cast(unpacked_frag_B[44]), static_cast(unpacked_frag_B[45]), + static_cast(unpacked_frag_B[46]), static_cast(unpacked_frag_B[47])); + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[48:63]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(unpacked_frag_B[48]), static_cast(unpacked_frag_B[49]), + static_cast(unpacked_frag_B[50]), static_cast(unpacked_frag_B[51]), + static_cast(unpacked_frag_B[52]), static_cast(unpacked_frag_B[53]), + static_cast(unpacked_frag_B[54]), static_cast(unpacked_frag_B[55]), + static_cast(unpacked_frag_B[56]), static_cast(unpacked_frag_B[57]), + static_cast(unpacked_frag_B[58]), static_cast(unpacked_frag_B[59]), + static_cast(unpacked_frag_B[60]), static_cast(unpacked_frag_B[61]), + static_cast(unpacked_frag_B[62]), static_cast(unpacked_frag_B[63])); + } +#endif if (Detail::kStagedAccumulation) { pipe_state.tmp_accum_.clear(); + CUTLASS_TRACE_DEVICE(" before tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(pipe_state.tmp_accum_[0]), static_cast(pipe_state.tmp_accum_[1]), + static_cast(pipe_state.tmp_accum_[2]), static_cast(pipe_state.tmp_accum_[3]), + static_cast(pipe_state.tmp_accum_[4]), static_cast(pipe_state.tmp_accum_[5]), + static_cast(pipe_state.tmp_accum_[6]), static_cast(pipe_state.tmp_accum_[7]), + static_cast(pipe_state.tmp_accum_[8]), static_cast(pipe_state.tmp_accum_[9]), + static_cast(pipe_state.tmp_accum_[10]), static_cast(pipe_state.tmp_accum_[11]), + static_cast(pipe_state.tmp_accum_[12]), static_cast(pipe_state.tmp_accum_[13]), + static_cast(pipe_state.tmp_accum_[14]), static_cast(pipe_state.tmp_accum_[15])); + } else { + CUTLASS_TRACE_DEVICE(" before tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(accum[0]), 
static_cast(accum[1]), + static_cast(accum[2]), static_cast(accum[3]), + static_cast(accum[4]), static_cast(accum[5]), + static_cast(accum[6]), static_cast(accum[7]), + static_cast(accum[8]), static_cast(accum[9]), + static_cast(accum[10]), static_cast(accum[11]), + static_cast(accum[12]), static_cast(accum[13]), + static_cast(accum[14]), static_cast(accum[15])); } int stage = Base::kStages - 1; @@ -710,12 +977,15 @@ class Wint2xMmaMultistage : // Mainloop CUTLASS_GEMM_LOOP for (; gemm_k_iterations > (-Base::kStages + 1);) { + if (stage > Base::kStages + 1) { + //break; + } mac_loop_iter( pipe_state, accum, iterator_A, iterator_B, - tile_dequanter_B, + mma_quant_args, gemm_k_iterations, stage); stage += 1; @@ -761,14 +1031,12 @@ class Wint2xMmaMultistage : else { this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)}); - //this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0}); - this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0}); + this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0}); } smem_read_stage_idx_ = smem_write_stage_idx_; } /// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory. - template CUTLASS_DEVICE void operator()( ///< problem size of GEMM @@ -779,13 +1047,16 @@ class Wint2xMmaMultistage : IteratorA iterator_A, ///< iterator over B operand in global memory IteratorB iterator_B, - ///< pre-load and dequantize B to shared memory - TileDequanterB tile_dequanter_B, + ///< iterators for extra quant params for B + QuantArguments mma_quant_args, ///< initial value of accumulator FragmentC const &src_accum) { + ptr_A_ = reinterpret_cast(iterator_A.get_origin_pointer()); + ptr_B_ = reinterpret_cast(iterator_B.get_origin_pointer()); + // Prologue (start fetching iterations of global fragments into shared memory) - prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations); + prologue(iterator_A, iterator_B, mma_quant_args, gemm_k_iterations); // Wait until we have at least one completed global fetch stage gmem_wait(); @@ -794,7 +1065,7 @@ class Wint2xMmaMultistage : accum = src_accum; // Perform the MAC-iterations - gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B); + gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, mma_quant_args); } }; diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h new file mode 100644 index 0000000000..3171ba94b2 --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h @@ -0,0 +1,335 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
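The new wint2x_params_accessor.h below replaces the old operand_zipped_B / operand_unzip_B buffers with a single byte buffer holding all extra quant params per threadblock: channel-wise super_scale (fp16/bf16), code_scale and code_zp (float), plus group-wise local_scale (uint4b_t, two values packed per byte, staged across kStages). A byte-offset sketch of that layout for an illustrative tile (kN = 64, kK = 64, GroupSize = 64, kStages = 3, fp16 super scale; the concrete numbers are examples only, the real ones come from the template parameters):

#include <cstdio>

int main() {
  constexpr int kN = 64, kK = 64, kGroupSize = 64, kStages = 3;
  constexpr int kSuperScaleBytes  = 2;  // ElementSuperScale = fp16/bf16
  constexpr int kCodeScaleZpBytes = 4;  // ElementCodeScaleZp = float

  // Channel-wise params: one value per column of the threadblock tile.
  constexpr int kSuperScaleSmemOffset = 0;
  constexpr int kCodeScaleSmemOffset  = kSuperScaleSmemOffset + kN * kSuperScaleBytes;
  constexpr int kCodeZpSmemOffset     = kCodeScaleSmemOffset + kN * kCodeScaleZpBytes;
  constexpr int kLocalScaleSmemOffset = kCodeZpSmemOffset + kN * kCodeScaleZpBytes;

  // Group-wise local_scale: two uint4 values per byte, so one k-group row of the
  // tile costs kN bytes, buffered kStages times. One global load covers
  // 2 * kGroupSize / kK k-tiles (kStagesPerLocalScaleLoad in the header).
  constexpr int kLocalScaleRowBytes = kN;
  constexpr int kStagesPerLocalScaleLoad = 2 * kGroupSize / kK;  // 2 with these numbers
  constexpr int kTotalBytes = kLocalScaleSmemOffset + kLocalScaleRowBytes * kStages;

  std::printf("super@%d code_scale@%d code_zp@%d local_scale@%d total=%dB "
              "(local_scale loaded every %d stages)\n",
              kSuperScaleSmemOffset, kCodeScaleSmemOffset, kCodeZpSmemOffset,
              kLocalScaleSmemOffset, kTotalBytes, kStagesPerLocalScaleLoad);
  return 0;
}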
+ +#pragma once + +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/trace.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template < + /// Original data type + typename T, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterators over super scales in global memory + typename IteratorSuperScale_, + /// Iterators over super scales in shared memory + typename SmemIteratorSuperScale_, + /// Iterators over local scales in global memory + typename IteratorLocalScale_, + /// Iterators over local scales in shared memory + typename SmemIteratorLocalScale_, + /// Iterators over code scales and zps in global memory + typename IteratorCodeScaleZp_, + /// Iterators over code scales and zps in shared memory + typename SmemIteratorCodeScaleZp_, + /// Number of stages, + int Stages_, + /// Group size for quantization + int GroupSize_> +class Wint2ParamsAccessor { +public: + static_assert(platform::is_same::value || platform::is_same::value, + "T must be fp16 or bf16"); + + using ElementType = T; + using Shape = Shape_; + + using IteratorSuperScale = IteratorSuperScale_; + using SmemIteratorSuperScale = SmemIteratorSuperScale_; + + using IteratorLocalScale = IteratorLocalScale_; + using SmemIteratorLocalScale = SmemIteratorLocalScale_; + + using IteratorCodeScaleZp = IteratorCodeScaleZp_; + using SmemIteratorCodeScaleZp = SmemIteratorCodeScaleZp_; + + constexpr static int kStages = Stages_; + constexpr static int kGroupSize = GroupSize_; + + using ElementSuperScale = typename IteratorSuperScale::Element; + using LayoutSuperScale = typename IteratorSuperScale::Layout; + + /// local_scale uint4 and group-wise + using ElementLocalScale = typename IteratorLocalScale::Element; + using LayoutLocalScale = typename IteratorLocalScale::Layout; + static_assert(platform::is_same::value, + "local_scale's type must be uint4b_t."); + + using ElementCodeScaleZp = typename IteratorCodeScaleZp::Element; + using LayoutCodeScaleZp = typename IteratorCodeScaleZp::Layout; + + /// 2 uint4b_t values are stored in a single uint8_t + constexpr static int kStagesPerLocalScaleLoad = 2 * kGroupSize / Shape::kK; + constexpr static int kLocalScaleRows = IteratorLocalScale::Shape::kRow; + + using SmemElement = uint8_t; + constexpr static int kSmemRows = + sizeof(ElementLocalScale) * kLocalScaleRows * kStages + + sizeof(ElementSuperScale) + sizeof(ElementCodeScaleZp) * 2; + constexpr static int kSmemColumns = Shape::kN; + + using QuantParamsShape = MatrixShape; + + constexpr static int kSuperScaleSmemOffset = 0; + constexpr static int kCodeScaleSmemOffset = kSmemColumns * sizeof(ElementSuperScale); + constexpr static int kCodeZpSmemOffset = kCodeScaleSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp); + constexpr static int kLocalScaleSmemOffset = kCodeZpSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp); + + /// TensorRef type for loading element from a tensor + using SuperTensorRef = cutlass::TensorRef; + using LocalTensorRef = cutlass::TensorRef; + using CodeTensorRef = cutlass::TensorRef; + + struct Arguments { + IteratorSuperScale iterator_super_scale; + IteratorLocalScale iterator_local_scale; + IteratorCodeScaleZp iterator_code_scale; + IteratorCodeScaleZp iterator_code_zp; + + int local_scale_pointer_offset; + + CUTLASS_DEVICE + Arguments(IteratorSuperScale iterator_super_scale, + IteratorLocalScale iterator_local_scale, + IteratorCodeScaleZp 
iterator_code_scale, + IteratorCodeScaleZp iterator_code_zp, + int local_scale_pointer_offset) + : iterator_super_scale(iterator_super_scale), + iterator_local_scale(iterator_local_scale), + iterator_code_scale(iterator_code_scale), + iterator_code_zp(iterator_code_zp), + local_scale_pointer_offset(local_scale_pointer_offset) {} + }; + +private: + // + // Data members + // + + /// Begin address of shared memory + uint8_t* smem_pointer_; + + /// Iterator to write threadblock-scoped tile of super scale operand to shared memory + SmemIteratorSuperScale smem_iterator_super_scale_; + /// Iterator to write threadblock-scoped tile of local scale operand to shared memory + SmemIteratorLocalScale smem_iterator_local_scale_; + /// Iterator to write threadblock-scoped tile of code scale operand to shared memory + SmemIteratorCodeScaleZp smem_iterator_code_scale_; + /// Iterator to write threadblock-scoped tile of code zp operand to shared memory + SmemIteratorCodeScaleZp smem_iterator_code_zp_; + + /// Shared memory write stage index + int smem_write_stage_idx_; + + /// Shared memory read stage index + int smem_read_stage_idx_; + + CUTLASS_DEVICE + ElementSuperScale* get_super_scale_smem_ptr() { + return reinterpret_cast(smem_pointer_ + kSuperScaleSmemOffset); + } + + CUTLASS_DEVICE + ElementLocalScale* get_local_scale_smem_ptr() { + return reinterpret_cast(smem_pointer_ + kLocalScaleSmemOffset); + } + + CUTLASS_DEVICE + ElementCodeScaleZp* get_code_scale_smem_ptr() { + return reinterpret_cast(smem_pointer_ + kCodeScaleSmemOffset); + } + + CUTLASS_DEVICE + ElementCodeScaleZp* get_code_zp_smem_ptr() { + return reinterpret_cast(smem_pointer_ + kCodeZpSmemOffset); + } + +public: + /// Construct from tensor references + CUTLASS_DEVICE + Wint2ParamsAccessor( + ///< prointer of shared memory + uint8_t* smem_pointer, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : smem_pointer_(smem_pointer), + smem_iterator_super_scale_(LayoutSuperScale(IteratorSuperScale::Shape::kColumn), + get_super_scale_smem_ptr(), {1, IteratorSuperScale::Shape::kColumn}, thread_idx), + smem_iterator_local_scale_(LayoutLocalScale(IteratorLocalScale::Shape::kColumn), + get_local_scale_smem_ptr(), {1, IteratorLocalScale::Shape::kColumn}, thread_idx), + smem_iterator_code_scale_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn), + get_code_scale_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx), + smem_iterator_code_zp_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn), + get_code_zp_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx), + smem_write_stage_idx_(0), + smem_read_stage_idx_(0) + { + //CUTLASS_TRACE_DEVICE(" Shape: {%d, %d, %d}, kSmemRows=%d, kSmemColumns=%d, kLocalScaleRows=%d, kStagesPerLocalScaleLoad=%d", + // Shape::kM, Shape::kN, Shape::kK, kSmemRows, kSmemColumns, kLocalScaleRows, kStagesPerLocalScaleLoad); + //CUTLASS_TRACE_DEVICE(" IteratorSuperScale::Shape: {%d, %d}, kSuperScaleSmemOffset=%d, smem_ptr=%p", + // IteratorSuperScale::Shape::kRow, IteratorSuperScale::Shape::kColumn, kSuperScaleSmemOffset, get_super_scale_smem_ptr()); + //CUTLASS_TRACE_DEVICE(" IteratorLocalScale::Shape: {%d, %d}, kLocalScaleSmemOffset=%d, smem_ptr=%p", + // IteratorLocalScale::Shape::kRow, IteratorLocalScale::Shape::kColumn, kLocalScaleSmemOffset, get_local_scale_smem_ptr()); + //CUTLASS_TRACE_DEVICE(" IteratorCodeScaleZp::Shape: {%d, %d}, kCodeScaleSmemOffset=%d, smem_ptr=%p", + // 
IteratorCodeScaleZp::Shape::kRow, IteratorCodeScaleZp::Shape::kColumn, kCodeScaleSmemOffset, get_code_scale_smem_ptr()); + //CUTLASS_TRACE_DEVICE(" IteratorCodeScaleZp::Shape: {%d, %d}, kCodeZpSmemOffset=%d, smem_ptr=%p", + // IteratorCodeScaleZp::Shape::kRow, IteratorCodeScaleZp::Shape::kColumn, kCodeZpSmemOffset, get_code_zp_smem_ptr()); + } + + CUTLASS_DEVICE + SuperTensorRef super_scale_ref() { + return {get_super_scale_smem_ptr(), LayoutSuperScale(IteratorSuperScale::Shape::kColumn)}; + } + + CUTLASS_DEVICE + LocalTensorRef local_scale_ref() { + return {get_local_scale_smem_ptr(), LayoutLocalScale(IteratorLocalScale::Shape::kColumn)}; + } + + CUTLASS_DEVICE + CodeTensorRef code_scale_ref() { + return {get_code_scale_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)}; + } + + CUTLASS_DEVICE + CodeTensorRef code_zp_ref() { + return {get_code_zp_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)}; + } + + template + CUTLASS_DEVICE + void copy_tiles_and_advance_per_stage(Arguments &quant_args, int stage) { + if constexpr (IsFirstStage) { + //CUTLASS_TRACE_DEVICE(" [stage=%d][SuperScale] Shape: {%d, %d}, Fragment::kElements=%d", + // stage, IteratorSuperScale::Shape::kRow, IteratorSuperScale::Shape::kColumn, IteratorSuperScale::Fragment::kElements); + //CUTLASS_TRACE_DEVICE(" [stage=%d][CodeScale] Shape: {%d, %d}, Fragment::kElements=%d", + // stage, IteratorCodeScaleZp::Shape::kRow, IteratorCodeScaleZp::Shape::kColumn, IteratorCodeScaleZp::Fragment::kElements); + + // Load channel-wise super_scale to shared memory, which only needs to be done once. + typename IteratorSuperScale::Fragment tb_frag_super_scale; + tb_frag_super_scale.clear(); + quant_args.iterator_super_scale.load(tb_frag_super_scale); + this->smem_iterator_super_scale_.store(tb_frag_super_scale); + + // Load channel-wise code_scale to shared memory, which only needs to be done once. + typename IteratorCodeScaleZp::Fragment tb_frag_code_scale; + tb_frag_code_scale.clear(); + quant_args.iterator_code_scale.load(tb_frag_code_scale); + this->smem_iterator_code_scale_.store(tb_frag_code_scale); + + // Load channel-wise code_zp to shared memory, which only needs to be done once. + typename IteratorCodeScaleZp::Fragment tb_frag_code_zp; + tb_frag_code_zp.clear(); + quant_args.iterator_code_zp.load(tb_frag_code_zp); + this->smem_iterator_code_zp_.store(tb_frag_code_zp); + } + + if ((stage % kStagesPerLocalScaleLoad) == 0) { + // Load group-wise local_scale to shared memory, which only needs to be done at each stage. + // Since 2 uint4b_t values of local_scale are saved in a single uint8_t, local_scale needs to be loaded once every two stages. + using AccessType = typename IteratorLocalScale::AccessType; + cutlass::arch::CacheOperation::Kind const kCacheOp = (sizeof_bits::value == 128) + ? 
cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always; + + quant_args.iterator_local_scale.set_iteration_index(0); + this->smem_iterator_local_scale_.set_iteration_index(0); + + // Async Copy for local_scale + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < IteratorLocalScale::ThreadMap::Iterations::kCount; ++j) { + AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_local_scale_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorLocalScale::kAccessesPerVector; ++v) { + auto gmem_ptr = quant_args.iterator_local_scale.get(); + + int const kSrcBytes = + sizeof_bits::value * + IteratorLocalScale::ThreadMap::kElementsPerAccess / + IteratorLocalScale::kAccessesPerVector / 8; + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, quant_args.iterator_local_scale.valid()); + } + ++quant_args.iterator_local_scale; + } + ++this->smem_iterator_local_scale_; + + //CUTLASS_TRACE_DEVICE(" [stage=%d][LocalScale] Shape: {%d, %d}", + // stage, IteratorLocalScale::Shape::kRow, IteratorLocalScale::Shape::kColumn); + } + } + + CUTLASS_DEVICE + void advance_smem_write_stage(Arguments &quant_args) { + if (smem_write_stage_idx_ % kStagesPerLocalScaleLoad == 0) { + // Advance global iterators + quant_args.iterator_local_scale.add_pointer_offset(quant_args.local_scale_pointer_offset); + + // Advance shared iterators + int smem_pointer_offset = IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn; + smem_iterator_local_scale_.add_pointer_offset(smem_pointer_offset); + } + + // Increment shared memory write stage index + ++smem_write_stage_idx_; + + if (smem_write_stage_idx_ == kStagesPerLocalScaleLoad * kStages) { + // Wrap back around to the 'start' of the circular buffer in shared memory + int pointer_offset = - kStages * IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn; + smem_iterator_local_scale_.add_pointer_offset(pointer_offset); + smem_write_stage_idx_ = 0; + } + //CUTLASS_TRACE_DEVICE(" smem_write_stage_idx_=%d", smem_write_stage_idx_); + } + + CUTLASS_DEVICE + int advance_smem_read_stage() { + int byte_offset = 0; + if (smem_write_stage_idx_ % kStagesPerLocalScaleLoad == 0) { + byte_offset = kLocalScaleRows * kSmemColumns; + } + + ++smem_read_stage_idx_; + + if (smem_read_stage_idx_ == kStagesPerLocalScaleLoad * kStages) { + smem_read_stage_idx_ = 0; + byte_offset = - (kStages - 1) * kLocalScaleRows * kSmemColumns; + } + //CUTLASS_TRACE_DEVICE(" smem_read_stage_idx_=%d, byte_offset=%d", smem_read_stage_idx_, byte_offset); + return byte_offset; + } + + CUTLASS_DEVICE + int clear_mask(Arguments &quant_args, bool cond) { + quant_args.iterator_local_scale.clear_mask(cond); + } +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h deleted file mode 100644 index cec6bcea03..0000000000 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "cutlass/gemm_coord.h" -#include "cutlass/trace.h" - -#include "cutlass_extensions/gemm/threadblock/wint2x_unzip.h" - -namespace cutlass { -namespace gemm { -namespace threadblock { - -template -struct TileDequanter { - using WeightQuantTraits = WintQuantTraits; - using MmaElementT = typename WeightQuantTraits::MmaWeightType; - using QuantArguments = typename WeightQuantTraits::Arguments; - - using UnzipAndDequantFunctor = - UnzipAndDequantFunctor; - - static constexpr bool kUseSharedMemory = true; - - static constexpr int kRows = Rows; - static constexpr int kColumns = Columns; - static constexpr int kStages = Stages; - - MmaElementT *out_smem_ptr{nullptr}; - - char *pointer{nullptr}; - int64_t ldm{0}; - cutlass::MatrixCoord tb_offset; - cutlass::MatrixCoord extent; - - ScaleElementT *super_scale_ptr{nullptr}; - cutlass::MatrixCoord tb_offset_scale; - - QuantArguments quant_args; - - int64_t block_start_rows[kStages]; - bool need_preload{true}; - UnzipAndDequantFunctor unzip_functor; - - CUTLASS_DEVICE - TileDequanter(MmaElementT *out_smem_ptr, char *pointer, int64_t ldm, - const cutlass::MatrixCoord &extent, - const cutlass::MatrixCoord &tb_offset, - ScaleElementT *super_scale_ptr, - const cutlass::MatrixCoord &tb_offset_scale, - const QuantArguments &quant_args) - : out_smem_ptr(out_smem_ptr), pointer(pointer), ldm(ldm), extent(extent), - tb_offset(tb_offset), super_scale_ptr(super_scale_ptr), - tb_offset_scale(tb_offset_scale), quant_args(quant_args) {} - - CUTLASS_DEVICE - MmaElementT *GetOutPtr() { return out_smem_ptr; } - - CUTLASS_DEVICE - void AddTileOffset(const cutlass::MatrixCoord &tile_offset) { - tb_offset.row() += tile_offset.row() * kRows; - tb_offset.column() += tile_offset.column() * kColumns; - tb_offset_scale.column() += tile_offset.column() * kColumns; - } - - CUTLASS_DEVICE - void Load(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) { - int zipped_row = WeightQuantTraits::CaclPackedDim(tb_offset.row()); - if (tb_offset.row() >= extent.row() || - tb_offset.column() >= extent.column()) { - return; - } - - block_start_rows[stage % kStages] = tb_offset.row(); - - using ZippedT = typename WeightQuantTraits::WeightType; - ZippedT *in_ptr = reinterpret_cast(pointer) + zipped_row * ldm + - tb_offset.column(); - ScaleElementT *scale_ptr = super_scale_ptr + tb_offset_scale.column(); - - if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) { - const uint8_t *local_scale_ptr = quant_args.local_scale_ptr + - (tb_offset.row() / 128) * ldm + - tb_offset_scale.column(); - const float *code_scale_ptr = - quant_args.code_scale_ptr + tb_offset_scale.column(); - const float *code_zp_ptr = - quant_args.code_zp_ptr + tb_offset_scale.column(); - - typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr); - unzip_functor.LoadAsync(in_ptr, local_scale_ptr, code_scale_ptr, code_zp_ptr, - scale_ptr, &args, ldm, need_preload); - need_preload = false; - } else { - // CUTLASS_TRACE_DEVICE("Not Supported!"); - } - } - - CUTLASS_DEVICE - void UnpackAndDequant(uint8_t *zipped_smem_ptr, uint8_t 
*column_wise_smem_ptr, int stage) { - int64_t block_start_row = block_start_rows[stage % kStages]; - if (block_start_row >= extent.row()) { - return; - } - - if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) { - typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr); - unzip_functor.ComputeVectorized(args, out_smem_ptr, block_start_row); - } else { - // CUTLASS_TRACE_DEVICE("Not Supported!"); - } - } -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h index 350b247de2..af4298df5e 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h @@ -41,12 +41,9 @@ #include "cutlass_extensions/arch/mma.h" #include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" -namespace cutlass -{ -namespace gemm -{ -namespace warp -{ +namespace cutlass { +namespace gemm { +namespace warp { ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -81,7 +78,7 @@ struct DefaultMmaTensorOp::value; // Shape for loading the narrow data type from shared memory diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h index 7c5088894b..ad1ca710e2 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -295,6 +295,11 @@ class MmaTensorOpComputeBWithF16 assert(0); #endif } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, + FragmentA const &A, FragmentB const &B) const {} }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h new file mode 100644 index 0000000000..fdf6b99947 --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h @@ -0,0 +1,391 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations + targeting Tensor Cores. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/functional.h" +#include "cutlass/platform/platform.h" + +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +namespace cutlass { +namespace gemm { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Matrix multiply operator + typename MmaOperator_, + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + Operand Operand, + /// Data type of Scale elements + typename ElementOperand_, + /// Layout of operand + typename Layout_, + /// Number of threads participating in one matrix operation + int Threads, + /// + typename Enable = void> +class MmaTensorOpWin2xDequantizer { + //static_assert(false, "Not Supported!"); +}; + +//////////////////////////////////////////////////////////////////////////////// +// Bfloat specialization for Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// Data type of Scale elements + typename ElementOperand_> +class MmaTensorOpWin2xDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + ElementOperand_, + layout::RowMajor, + 32> + //typename platform::enable_if= 80 + // && platform::is_same::value>::type> +{ +public: + static_assert(platform::is_same::value || platform::is_same::value, + "T must be fp16 or bf16"); + + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Type of mma operand + using ElementOperand = ElementOperand_; + + /// Type of input + using ElementB = typename MmaOperator::FragmentB::Element; + static_assert(platform::is_same::value, "ElementB must be uint2b_t"); + + /// Type of internal compute + using ElementCompute = float; + + /// Type of the scales + using ElementLocalScale = uint4b_t; + using ElementSuperScale = ElementOperand; + using 
ElementCodeScaleZp = float; + + /// Fragment to hold B data before Mma + using FragmentInput = Array; + + /// Unpack 4 uint2b_t values compreseed in a uint8_t to floating points. + using Uint2Converter = FastInterleavedAndBiasedNumericArrayConverter< + ElementOperand, ElementB, MmaOperator::FragmentB::kElements>; + using FragmentUnpack = typename Uint2Converter::result_type; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + static constexpr int kElements = kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn; + + // 32 bits are loaded to register from shared memory by each thread + static constexpr int kMmaIterationsPerLoad = + 32 / (sizeof_bits::value * ArchMmaOperator::FragmentB::kElements); + + // use uint8_t to save 2 4-bits local scales + using FragmentLocalScale = Array; + using FragmentSuperScale = Array; + using FragmentCodeScaleZp = Array; + + /// Fragment to hold internal scales before Mma + using FragmentCompute = Array; + + /// Fragment of dequantized B + //using FragmentOutput = Array; + using FragmentOutput = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using SuperTensorRef = cutlass::TensorRef; + using LocalTensorRef = cutlass::TensorRef; + using CodeTensorRef = cutlass::TensorRef; + +private: + // + // Data members + // + + uint8_t* pointer_local_scale_; + ElementCodeScaleZp* pointer_code_scale_; + ElementCodeScaleZp* pointer_code_zp_; + ElementSuperScale* pointer_super_scale_; + + FragmentUnpack unpacked_frag_; + +public: + CUTLASS_DEVICE + MmaTensorOpWin2xDequantizer(SuperTensorRef smem_super_scale, + LocalTensorRef smem_local_scale, + CodeTensorRef smem_code_scale, + CodeTensorRef smem_code_zp, + int warp_idx_n, + int lane_idx) { + int warp_offset = warp_idx_n * Shape::kN; + int quad = lane_idx / 4; + int thread_offset = warp_offset + quad; + pointer_super_scale_ = smem_super_scale.data() + thread_offset; + pointer_code_scale_ = smem_code_scale.data() + thread_offset; + pointer_code_zp_ = smem_code_zp.data() + thread_offset; + pointer_local_scale_ = reinterpret_cast(smem_local_scale.data()) + thread_offset; + } + + /// Channel-wise params, need to load just once + CUTLASS_DEVICE + void load(FragmentCodeScaleZp& code_scale_frag, + FragmentCodeScaleZp& code_zp_frag, + FragmentSuperScale& super_scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + super_scale_frag[mma_n_iter] = pointer_super_scale_[mma_n_iter * InstructionShape::kN]; // bank conflict + code_scale_frag[mma_n_iter] = pointer_code_scale_[mma_n_iter * InstructionShape::kN]; + code_zp_frag[mma_n_iter] = pointer_code_zp_[mma_n_iter * InstructionShape::kN]; + } + } + + /// Group-wise params, need to load multiple times + CUTLASS_DEVICE + void load(FragmentLocalScale& local_scale_frag) { + //CUTLASS_TRACE_DEVICE(" pointer_local_scale_=%p", pointer_local_scale_); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + local_scale_frag[mma_n_iter] = pointer_local_scale_[mma_n_iter * InstructionShape::kN]; // bank conflict + } + } + + CUTLASS_DEVICE + void dequantize(const FragmentLocalScale& local_scale_frag, + const FragmentCodeScaleZp& code_scale_frag, + const FragmentCodeScaleZp& code_zp_frag, + const 
FragmentSuperScale& super_scale_frag, + const FragmentInput& input_frag, + FragmentOutput& output_frag, + int tb_offset_k) { + int stage = tb_offset_k / 64; + + //CUTLASS_TRACE_DEVICE(" FragmentInput::kElements=%d, %d bytes", + // FragmentInput::kElements, static_cast(sizeof_bits::value / 8)); + //CUTLASS_TRACE_DEVICE(" FragmentUnpack::kElements=%d, %d bytes", + // FragmentUnpack::kElements, static_cast(sizeof_bits::value / 8)); + //CUTLASS_TRACE_DEVICE(" FragmentOutput::kElements=%d, %d bytes", + // FragmentOutput::kElements, static_cast(sizeof_bits::value / 8)); + + //CUTLASS_TRACE_DEVICE(" MmaOperator::FragmentB::kElements=%d", MmaOperator::FragmentB::kElements); + //CUTLASS_TRACE_DEVICE(" MmaOperator::IteratorB::InstructionShape: %dx%d; InstructionShape: %dx%dx%d; ", + // MmaOperator::IteratorB::InstructionShape::kRow, MmaOperator::IteratorB::InstructionShape::kColumn, + // InstructionShape::kM, InstructionShape::kN, InstructionShape::kK); + //CUTLASS_TRACE_DEVICE(" MmaOperator::MmaIterations: kRow=%d, kColumn=%d", + // MmaOperator::MmaIterations::kRow, MmaOperator::MmaIterations::kColumn); + + unpacked_frag_ = Uint2Converter::convert(input_frag); + // DEBUG CODES + for (int i = 0; i < FragmentUnpack::kElements; ++i) { + unpacked_frag_[i] = static_cast(1); //static_cast((i / 16) * 8 + (threadIdx.x % 32) / 4); + } + +#if 0 + if (FragmentUnpack::kElements == 64) { + CUTLASS_TRACE_DEVICE(" unpacked_frag_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(unpacked_frag_[0]), static_cast(unpacked_frag_[1]), + static_cast(unpacked_frag_[2]), static_cast(unpacked_frag_[3]), + static_cast(unpacked_frag_[4]), static_cast(unpacked_frag_[5]), + static_cast(unpacked_frag_[6]), static_cast(unpacked_frag_[7]), + static_cast(unpacked_frag_[8]), static_cast(unpacked_frag_[9]), + static_cast(unpacked_frag_[10]), static_cast(unpacked_frag_[11]), + static_cast(unpacked_frag_[12]), static_cast(unpacked_frag_[13]), + static_cast(unpacked_frag_[14]), static_cast(unpacked_frag_[15])); + CUTLASS_TRACE_DEVICE(" unpacked_frag_[16:31]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + static_cast(unpacked_frag_[16]), static_cast(unpacked_frag_[17]), + static_cast(unpacked_frag_[18]), static_cast(unpacked_frag_[19]), + static_cast(unpacked_frag_[20]), static_cast(unpacked_frag_[21]), + static_cast(unpacked_frag_[22]), static_cast(unpacked_frag_[23]), + static_cast(unpacked_frag_[24]), static_cast(unpacked_frag_[25]), + static_cast(unpacked_frag_[26]), static_cast(unpacked_frag_[27]), + static_cast(unpacked_frag_[28]), static_cast(unpacked_frag_[29]), + static_cast(unpacked_frag_[30]), static_cast(unpacked_frag_[31])); + } +#endif + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + static constexpr int32_t kGroupSize = 64; + static constexpr int32_t kLocalScaleMask = 0xF; + + // special for TileRows = 64 + int local_scale_shift = (((tb_offset_k / kGroupSize) + 1) & 1) * 4; + FragmentCompute scale_frag; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentLocalScale::kElements; ++i) { + int32_t shifted_local_scale = + (static_cast(local_scale_frag[i]) >> local_scale_shift) & kLocalScaleMask; + scale_frag[i] = + static_cast(shifted_local_scale) * static_cast(super_scale_frag[i]); + } + +#if 0 + if (FragmentCompute::kElements == 4) { + CUTLASS_TRACE_DEVICE(" [stage=%d] tb_offset_k=%d, local_scale_shift=%d, scale_frag[0:3]=[%f, %f, %f, %f], sizeof(FragmentCompute)=%d bytes", + stage, tb_offset_k, local_scale_shift, + static_cast(scale_frag[0]), 
static_cast(scale_frag[1]), + static_cast(scale_frag[2]), static_cast(scale_frag[3]), + static_cast(sizeof(FragmentCompute))); + } +#endif + + //int offset = warp_mma_k * ArchMmaOperator::FragmentB::kElements; + int num_columns = 32 / sizeof_bits::value; + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < num_columns; ++j) { + ElementCompute scaled_value = + static_cast(unpacked_frag_[mma_n_iter * num_columns + j]) * scale_frag[mma_n_iter]; + output_frag[mma_n_iter * num_columns + j] = static_cast(scaled_value); + } + } + + if (FragmentOutput::kElements == 64) { +#if 0 + CUTLASS_TRACE_DEVICE(" [stage=%d] output_frag[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + stage, + static_cast(output_frag[0]), static_cast(output_frag[1]), + static_cast(output_frag[2]), static_cast(output_frag[3]), + static_cast(output_frag[4]), static_cast(output_frag[5]), + static_cast(output_frag[6]), static_cast(output_frag[7]), + static_cast(output_frag[8]), static_cast(output_frag[9]), + static_cast(output_frag[10]), static_cast(output_frag[11]), + static_cast(output_frag[12]), static_cast(output_frag[13]), + static_cast(output_frag[14]), static_cast(output_frag[15])); + CUTLASS_TRACE_DEVICE(" [stage=%d] output_frag[16:31]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + stage, + static_cast(output_frag[16]), static_cast(output_frag[17]), + static_cast(output_frag[18]), static_cast(output_frag[19]), + static_cast(output_frag[20]), static_cast(output_frag[21]), + static_cast(output_frag[22]), static_cast(output_frag[23]), + static_cast(output_frag[24]), static_cast(output_frag[25]), + static_cast(output_frag[26]), static_cast(output_frag[27]), + static_cast(output_frag[28]), static_cast(output_frag[29]), + static_cast(output_frag[30]), static_cast(output_frag[31])); + CUTLASS_TRACE_DEVICE(" [stage=%d] output_frag[32:47]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + stage, + static_cast(output_frag[32]), static_cast(output_frag[33]), + static_cast(output_frag[34]), static_cast(output_frag[35]), + static_cast(output_frag[36]), static_cast(output_frag[37]), + static_cast(output_frag[38]), static_cast(output_frag[39]), + static_cast(output_frag[40]), static_cast(output_frag[41]), + static_cast(output_frag[42]), static_cast(output_frag[43]), + static_cast(output_frag[44]), static_cast(output_frag[45]), + static_cast(output_frag[46]), static_cast(output_frag[47])); + CUTLASS_TRACE_DEVICE(" [stage=%d] output_frag[48:63]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]", + stage, + static_cast(output_frag[48]), static_cast(output_frag[49]), + static_cast(output_frag[50]), static_cast(output_frag[51]), + static_cast(output_frag[52]), static_cast(output_frag[53]), + static_cast(output_frag[54]), static_cast(output_frag[55]), + static_cast(output_frag[56]), static_cast(output_frag[57]), + static_cast(output_frag[58]), static_cast(output_frag[59]), + static_cast(output_frag[60]), static_cast(output_frag[61]), + static_cast(output_frag[62]), static_cast(output_frag[63])); +#endif + } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on + // older arch, scale conversion should happen before scales are stored + // to shared memory and we should use the fp16 dequantizer. This will + // avoid numerous conversion instructions in GEMM main loop. 
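// Editor's note (not part of the patch): for reference, the SM80 path above turns the
// packed group-wise local scale into a per-column dequant scale roughly as follows
// (a group size of 64 and a 64-deep K tile are assumed, matching this specialization):
//
//   int shift = (((tb_offset_k / 64) + 1) & 1) * 4;             // pick the high or low nibble
//   int ls    = (local_scale_byte >> shift) & 0xF;              // unpack the uint4 local scale
//   scale     = static_cast<ElementCompute>(ls) * super_scale;  // combine with channel-wise super scale
//
// The converted 2-bit weights are then multiplied by that scale before being handed to the MMA.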
+ arch::device_breakpoint(); +#endif + + const int fixed_values[64] = { + 0, 1, 8, 9, 16, 17, 24, 25, 32, 33, 40, 41, 48, 49, 56, 57, + 2, 3, 10, 11, 18, 19, 26, 27, 34, 35, 42, 43, 50, 51, 58, 59, + 4, 5, 12, 13, 20, 21, 28, 29, 36, 37, 44, 45, 52, 53, 60, 61, + 6, 7, 14, 15, 22, 23, 30, 31, 38, 39, 46, 47, 54, 55, 62, 63 + }; + for (int i = 0; i < FragmentUnpack::kElements; ++i) { + output_frag[i] = static_cast(fixed_values[(i % 16) + (threadIdx.x % 4) * 16]); + } + } + + /// Add an offset to pointer in units of elements. + /// Only group-wise params needs. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + pointer_local_scale_ += offset; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h b/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h index 44ba79680e..4e1e1781d6 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h +++ b/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h @@ -39,18 +39,16 @@ #include "cutlass/array.h" #include "cutlass/half.h" #include "cutlass/numeric_types.h" +#include "cutlass/trace.h" -namespace cutlass -{ +namespace cutlass { // This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low // bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally // signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. // This converter will uninterleave the data and subtract the bias while converting to the result type. 
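// Editor's note (not part of the patch): the uint2b_t specialization added below does not
// decode plain 2-bit bitfields. Each source byte is first mapped through a per-column affine
// codebook (code_scale / code_zp, hard-coded to 1 and 0 in this revision), and four values
// are then peeled off with a 6-bit mask and 3-bit shifts, kBZP = 32 playing the role of the
// 2**(b-1) bias described above for the 6-bit intermediate values. A standalone sketch of the
// per-byte decode (illustrative only; a free function for clarity, needs <cmath> and <cstdint>):
#if 0
static void decode_wint2x_byte(uint8_t packed, float code_scale, float code_zp, float out[4]) {
    constexpr int kWeightMask = 0x3F;   // low 6 bits of the running value
    constexpr int kBZP = 32;            // bias added at quantization time
    int v = static_cast<int>(std::floor(static_cast<float>(packed) * code_scale + code_zp + 0.5f));
    out[3] = static_cast<float>((v & kWeightMask) - kBZP);  v >>= 3;
    out[2] = static_cast<float>((v & kWeightMask) - kBZP);  v >>= 3;
    out[1] = static_cast<float>((v & kWeightMask) - kBZP);  v >>= 3;
    out[0] = static_cast<float>((v & kWeightMask) - kBZP);
}
#endif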
 template
-struct FastInterleavedAndBiasedNumericArrayConverter
-{
-};
+struct FastInterleavedAndBiasedNumericArrayConverter;

 template <>
 struct FastInterleavedAndBiasedNumericArrayConverter
@@ -440,6 +438,137 @@ struct FastInterleavedAndBiasedNumericArrayConverter
     }
 };

+template
+struct FastInterleavedAndBiasedNumericArrayConverter
+{
+    static_assert(platform::is_same::value || platform::is_same::value,
+        "T must be fp16 or bf16");
+
+    using result_type = Array;
+    using source_type = Array;
+
+    using ScaleComputeT = T;
+
+    static constexpr int32_t kWeightMask = 0x3F;
+    static constexpr int32_t kBZP = 32;
+
+    CUTLASS_DEVICE
+    static result_type convert(source_type const& source)
+    {
+        result_type result;
+        uint8_t const* in_ptr = reinterpret_cast(&source);
+
+        ScaleComputeT code_scale = static_cast(1);
+        ScaleComputeT code_zp = static_cast(0);
+        ScaleComputeT floor_offset = static_cast(0.5);
+
+#if 0
+        CUTLASS_TRACE_DEVICE(" source: [%d, %d, %d, %d]",
+            static_cast(in_ptr[0]), static_cast(in_ptr[1]),
+            static_cast(in_ptr[2]), static_cast(in_ptr[3]));
+#endif
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int i = 0; i < 4; ++i) {
+            int32_t decode_value =
+                static_cast(floor(static_cast(in_ptr[i]) * code_scale + code_zp + floor_offset));
+
+            ScaleComputeT value_3 = static_cast((decode_value & kWeightMask) - kBZP);
+            decode_value >>= 3;
+            ScaleComputeT value_2 = static_cast((decode_value & kWeightMask) - kBZP);
+            decode_value >>= 3;
+            ScaleComputeT value_1 = static_cast((decode_value & kWeightMask) - kBZP);
+            decode_value >>= 3;
+            ScaleComputeT value_0 = static_cast((decode_value & kWeightMask) - kBZP);
+
+            result[i * 4] = static_cast(value_0);
+            result[i * 4 + 1] = static_cast(value_1);
+            result[i * 4 + 2] = static_cast(value_2);
+            result[i * 4 + 3] = static_cast(value_3);
+        }
+
+        // Predefined fixed-value array (64 elements)
+        const int fixed_values[64] = {
+            0, 1, 8, 9, 16, 17, 24, 25, 32, 33, 40, 41, 48, 49, 56, 57,
+            2, 3, 10, 11, 18, 19, 26, 27, 34, 35, 42, 43, 50, 51, 58, 59,
+            4, 5, 12, 13, 20, 21, 28, 29, 36, 37, 44, 45, 52, 53, 60, 61,
+            6, 7, 14, 15, 22, 23, 30, 31, 38, 39, 46, 47, 54, 55, 62, 63
+        };
+
+        // CUTLASS_PRAGMA_UNROLL
+        // for (int i = 0; i < 16; ++i) {
+        //     // result[i] = static_cast(fixed_values[i + idx * 16]);
+        //     if (threadIdx.x % 32 == 0 || threadIdx.x % 32 == 4) {
+        //         result[i] = static_cast(fixed_values0[i + idx * 16]);
+        //     } else if (threadIdx.x % 32 == 1 || threadIdx.x % 32 == 5) {
+        //         result[i] = static_cast(fixed_values1[i + idx * 16]);
+        //     } else if (threadIdx.x % 32 == 2 || threadIdx.x % 32 == 6) {
+        //         result[i] = static_cast(fixed_values2[i + idx * 16]);
+        //     } else if (threadIdx.x % 32 == 3 || threadIdx.x % 32 == 7) {
+        //         result[i] = static_cast(fixed_values3[i + idx * 16]);
+        //     } else {
+        //         result[i] = static_cast(0);
+        //     }
+        // }
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int i = 0; i < 16; ++i) {
+            result[i] = static_cast(fixed_values[i + (threadIdx.x % 4) * 16]);
+        }
+
+        return result;
+    }
+
+    CUTLASS_DEVICE
+    result_type operator()(source_type const& s)
+    {
+        return convert(s);
+    }
+};
+
+template
+struct FastInterleavedAndBiasedNumericArrayConverter
+{
+    static_assert(platform::is_same::value || platform::is_same::value,
+        "T must be fp16 or bf16");
+
+    static constexpr int kVecWidth = 16;
+    static_assert(!(N % kVecWidth), "N must be multiple of 16.");
+
+    using result_type = Array;
+    using source_type = Array;
+
+    CUTLASS_DEVICE
+    static result_type convert(source_type const& source)
+    {
+        using scalar_result_type = typename result_type::Element;
+        using scalar_source_type = typename
source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / kVecWidth; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h b/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h index 9e1c6c463b..fa28810697 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h +++ b/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h @@ -125,10 +125,13 @@ struct WintQuantTraits { static constexpr int32_t kNumPackedValues = 4; static constexpr int32_t kPackedSize = 16; + using LocalScaleType = uint4b_t; + using CodeScaleZpType = float; + struct Arguments { - const uint8_t *local_scale_ptr; // quanted 4-bits - const float *code_scale_ptr; - const float *code_zp_ptr; + uint8_t *local_scale_ptr; // quanted 4-bits + float *code_scale_ptr; + float *code_zp_ptr; }; CUTLASS_DEVICE diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h index 356f305968..611888591e 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h @@ -43,7 +43,6 @@ #include "cutlass/trace.h" #include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" -#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h" #include "cutlass_extensions/tile_interleaved_layout.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -52,6 +51,15 @@ namespace cutlass { namespace gemm { namespace kernel { +template std::string GetCutlassLayoutString() { + if (std::is_same::value) { + return "cutlass::layout::RowMajor"; + } else if (std::is_same::value) { + return "cutlass::layout::ColumnMajor"; + } + return "unknown"; +} + ///////////////////////////////////////////////////////////////////////////////////////////////// // This section exists to that we can use the same kernel code for regular gemm // and dequantizing gemms. 
It will dispatch to the dequantizing gemm if the Mma @@ -282,6 +290,27 @@ struct MoeFCGemm { platform::is_same::value) { assert(weight_scales); } + + CUTLASS_TRACE_HOST("[Arguments] problem_count: " << problem_count << ", threadblock_count: " << threadblock_count << ", gemm_n: " << gemm_n << ", gemm_k: " << gemm_k); + CUTLASS_TRACE_HOST("[Arguments] ptr_A: " << static_cast(ptr_A)); + CUTLASS_TRACE_HOST("[Arguments] ptr_B: " << static_cast(ptr_B)); + CUTLASS_TRACE_HOST("[Arguments] ptr_C: " << static_cast(ptr_C)); + CUTLASS_TRACE_HOST("[Arguments] ptr_D: " << static_cast(ptr_D)); + CUTLASS_TRACE_HOST("[Arguments] weight_scales: " << static_cast(weight_scales)); + CUTLASS_TRACE_HOST("[Arguments] total_rows_before_expert: " << static_cast(total_rows_before_expert)); + CUTLASS_TRACE_HOST("[Arguments] local_scale: " << static_cast(local_scale)); + CUTLASS_TRACE_HOST("[Arguments] code_scale: " << static_cast(code_scale)); + CUTLASS_TRACE_HOST("[Arguments] code_zp: " << static_cast(code_zp)); + CUTLASS_TRACE_HOST("[Arguments] quant_method: " << static_cast(quant_method)); + CUTLASS_TRACE_HOST("[Arguments] LayoutA: " << GetCutlassLayoutString()); + CUTLASS_TRACE_HOST("[Arguments] LayoutB: " << GetCutlassLayoutString()); + CUTLASS_TRACE_HOST("[Arguments] LayoutC: " << GetCutlassLayoutString()); + CUTLASS_TRACE_HOST("[Arguments] Mma::IteratorA::AccessType::kElements:" << Mma::IteratorA::AccessType::kElements); + CUTLASS_TRACE_HOST("[Arguments] Mma::IteratorB::AccessType::kElements:" << Mma::IteratorB::AccessType::kElements); + CUTLASS_TRACE_HOST("[Arguments] SharedStorage Information:"); + CUTLASS_TRACE_HOST(" - ProblemVisitor::SharedStorage: " << sizeof(typename ProblemVisitor::SharedStorage) << " bytes"); + CUTLASS_TRACE_HOST(" - Mma::SharedStorage: " << sizeof(typename Mma::SharedStorage) << " bytes"); + CUTLASS_TRACE_HOST(" - Epilogue::SharedStorage: " << sizeof(typename Epilogue::SharedStorage) << " bytes"); } }; @@ -775,17 +804,54 @@ struct Wint2xMoeFCGemm : public MoeFCGemm struct KernelRunner { using WeightQuantTraits = WintQuantTraits; - using QuantArguments = typename WeightQuantTraits::Arguments; + using MmaQuantArguments = typename Mma::QuantParamsAccessor::Arguments; CUTLASS_DEVICE - static QuantArguments get_quant_args(Params const& params, int32_t problem_idx, const int64_t gemm_k, const int64_t gemm_n) { - QuantArguments quant_args; - if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) { - quant_args.local_scale_ptr = params.local_scale + problem_idx * gemm_k * gemm_n / 128; - quant_args.code_scale_ptr = params.code_scale + problem_idx * gemm_n; - quant_args.code_zp_ptr = params.code_zp + problem_idx * gemm_n; - } - return quant_args; + static MmaQuantArguments prepare_quant_args( + Params const& params, cutlass::gemm::GemmCoord const& threadblock_offset, + int32_t problem_idx, const int32_t gemm_k, const int32_t gemm_n, const int thread_idx) { + // the begin threadblock_offset of scale, which holds the same column id with C, but with no row id + cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; + cutlass::MatrixCoord tb_offset_local_scale{0, threadblock_offset.n() * 2}; + + ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * gemm_n; + typename Mma::QuantParamsAccessor::IteratorSuperScale iterator_super_scale( + Mma::QuantParamsAccessor::LayoutSuperScale(gemm_n), + weight_scale_ptr, + {1, gemm_n}, + thread_idx, + tb_offset_scale); + + int local_scale_pointer_offset = ((ThreadblockShape::kK + 127) / 128) * (gemm_n * 2); + uint4b_t 
*local_scale_ptr = reinterpret_cast(params.local_scale + problem_idx * gemm_k * gemm_n / 128); + CUTLASS_TRACE_DEVICE(" local_scale_ptr=%p, extent={%d, %d}, tb_offset={%d, %d}, local_scale_pointer_offset=%d", + local_scale_ptr, gemm_k / 128, gemm_n * 2, tb_offset_local_scale.row(), tb_offset_local_scale.column(), local_scale_pointer_offset); + typename Mma::QuantParamsAccessor::IteratorLocalScale iterator_local_scale( + Mma::QuantParamsAccessor::LayoutLocalScale(gemm_n * 2), + local_scale_ptr, + {(gemm_k + 127) / 128, gemm_n * 2}, + thread_idx, + tb_offset_local_scale); + + float* code_scale_ptr = params.code_scale + problem_idx * gemm_n; + typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_scale( + Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n), + code_scale_ptr, + {1, gemm_n}, + thread_idx, + tb_offset_scale); + + float* code_zp_ptr = params.code_zp + problem_idx * gemm_n; + typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_zp( + Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n), + code_zp_ptr, + {1, gemm_n}, + thread_idx, + tb_offset_scale); + + MmaQuantArguments mma_quant_args( + iterator_super_scale, iterator_local_scale, iterator_code_scale, iterator_code_zp, local_scale_pointer_offset); + return mma_quant_args; } CUTLASS_DEVICE @@ -814,9 +880,6 @@ struct Wint2xMoeFCGemm : public MoeFCGemm= 1, "B must be row major/col major OR col major interleaved."); - // LayoutB should be RowMajor - using TileDequanterB = cutlass::gemm::threadblock::TileDequanter; - // // Problem visitor. // @@ -835,6 +898,13 @@ struct Wint2xMoeFCGemm : public MoeFCGemm 2) { + break; + } + + CUTLASS_TRACE_DEVICE(" problem_idx: %d, cta_idx: %d, problem_size: {%d, %d, %d}", + problem_idx, cta_idx, static_cast(problem_size.m()), static_cast(problem_size.n()), static_cast(problem_size.k())); + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); // threadblock_offset of C @@ -843,12 +913,6 @@ struct Wint2xMoeFCGemm : public MoeFCGemm(byte_ptr_B); typename LayoutB::LongIndex ldm_B = platform::is_same::value ? gemm_n : gemm_k * kInterleave; - typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns; // the begin threadblock_offset of B, which holds the same column id with C - cutlass::MatrixCoord tb_offset_B{0, - threadblock_offset.n() / kInterleave}; - + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave}; - cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns}; - MmaElementB* smem_unzip_B_ptr = nullptr; - if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) { - smem_unzip_B_ptr = shared_storage.main_loop.operand_unzip_B_ptr(); - } - QuantArguments quant_args = get_quant_args(params, problem_idx, gemm_k, gemm_n); - TileDequanterB tile_dequanter_B(smem_unzip_B_ptr, - byte_ptr_B, - ldm_B, - extent_B, - tb_offset_B, - weight_scale_ptr, - tb_offset_scale, - quant_args); - MmaElementB* ptr_B = tile_dequanter_B.GetOutPtr(); + CUTLASS_TRACE_DEVICE(" ldm_B: %d, tb_offset_B: {%d, %d}, extent_B: {%d, %d}", + static_cast(ldm_B), tb_offset_B.row(), tb_offset_B.column(), extent_B.row(), extent_B.column()); // Compute position within threadblock int thread_idx = threadIdx.x; @@ -914,20 +959,21 @@ struct Wint2xMoeFCGemm : public MoeFCGemm::Op; + CUTLASS_TRACE_HOST("Stages: " << Stages); + // Finally, set up the kernel. 
using BaseGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< ElementType, @@ -187,6 +189,7 @@ void generic_moe_gemm_kernelLauncher(const T* A, "GroupedGEMM kernel"); } const int threadblock_count = multi_processor_count * occupancy; + CUTLASS_TRACE_HOST("kernel_occupancy: " << kernel_occupancy << ", occupancy: " << occupancy << ", threadblock_count: " << threadblock_count << ", multi_processor_count: " << multi_processor_count); typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f), ElementAccumulator(0.f)); @@ -205,7 +208,7 @@ void generic_moe_gemm_kernelLauncher(const T* A, threadblock_count, epilogue_op, reinterpret_cast(A), - reinterpret_cast(B), + reinterpret_cast(B), reinterpret_cast(weight_scales), reinterpret_cast(biases), reinterpret_cast(C), @@ -443,6 +446,7 @@ void dispatch_gemm_config(const T* A, #define dispatch_gemm_config_macro(AA, BB, CC, DD, EE, FF) \ case CutlassTileConfig:: \ CtaShape##AA##x##BB##x##CC##_WarpShape##DD##x##EE##x##FF: \ + CUTLASS_TRACE_HOST("ThreadblockShape<" << AA << "," << BB << "," << CC << ">, WarpShape<" << DD << "," << EE << "," << FF << ">"); \ dispatch_gemm_config::value) { if constexpr (WeightQuantTraits::kQuantMethod != cutlass::WintQuantMethod::kWeightOnlyInt2) { switch (gemm_config.tile_config) { - dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64); - dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64); + //dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64); + //dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64); case CutlassTileConfig::Undefined: throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined."); break; @@ -563,16 +567,16 @@ void dispatch_moe_gemm_to_cutlass(const T* A, } else { switch (gemm_config.tile_config) { dispatch_gemm_config_macro(16, 128, 64, 16, 32, 64); - dispatch_gemm_config_macro(16, 256, 64, 16, 64, 64); - dispatch_gemm_config_macro(64, 64, 64, 32, 32, 64); - dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64); - dispatch_gemm_config_macro(128, 64, 64, 64, 32, 64); - dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64); - dispatch_gemm_config_macro(128, 128, 64, 64, 64, 64); - dispatch_gemm_config_macro(128, 128, 64, 128, 32, 64); - dispatch_gemm_config_macro(128, 256, 64, 64, 64, 64); - dispatch_gemm_config_macro(64, 128, 64, 64, 32, 64); - dispatch_gemm_config_macro(256, 128, 64, 64, 64, 64); + //dispatch_gemm_config_macro(16, 256, 64, 16, 64, 64); + //dispatch_gemm_config_macro(64, 64, 64, 32, 32, 64); + //dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64); + //dispatch_gemm_config_macro(128, 64, 64, 64, 32, 64); + //dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64); + //dispatch_gemm_config_macro(128, 128, 64, 64, 64, 64); + //dispatch_gemm_config_macro(128, 128, 64, 128, 32, 64); + //dispatch_gemm_config_macro(128, 256, 64, 64, 64, 64); + //dispatch_gemm_config_macro(64, 128, 64, 64, 32, 64); + //dispatch_gemm_config_macro(256, 128, 64, 64, 64, 64); case CutlassTileConfig::Undefined: throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined."); break; @@ -614,7 +618,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A, cudaStream_t stream, int* occupancy = nullptr) { switch (gemm_config.tile_config) { - dispatch_gemm_config_macro(128, 128, 8, 64, 64, 8); + //dispatch_gemm_config_macro(128, 128, 8, 64, 64, 8); case CutlassTileConfig::Undefined: throw std::runtime_error( "[dispatch_moe_gemm_to_cutlass][SIMT] gemm config " @@ -711,8 +715,8 @@ void MoeGemmRunner::run_gemm( std::vector candidate_configs = get_candidate_configs(sm_, -1, 
is_weight_only, only_simt_configs, true); - static constexpr int warm_time = 5; - static constexpr int test_time = 10; + static constexpr int warm_time = 0; + static constexpr int test_time = 1; auto& gemmConfigManager = GemmConfigManager::Instance(); constexpr GemmDataType dtype = getGemmDataType(); constexpr GemmDataType wdtype = getGemmDataType(); @@ -731,8 +735,10 @@ void MoeGemmRunner::run_gemm( std::min(gemmConfigManager.nextPowerOfTwo(actual_total_rows), gemmConfigManager.getMaxProfileM()); bool find_one = false; - size_t num_candidate_configs_size = candidate_configs.size(); - for (size_t ii = 0; ii < num_candidate_configs_size; ++ii) { + size_t num_candidate_configs_size = 2;//candidate_configs.size(); + // for (size_t ii = 0; ii < num_candidate_configs_size; ++ii) + { + size_t ii = 1; try { for (int i = 0; i < warm_time; i++) { dispatch_to_arch(A, @@ -776,7 +782,7 @@ void MoeGemmRunner::run_gemm( check_cuda_error(cudaEventElapsedTime(&elapsed, start, stop)); check_cuda_error(cudaEventDestroy(start)); check_cuda_error(cudaEventDestroy(stop)); - //std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << " ms" << std::endl; + std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << " ms" << std::endl; if (elapsed < best_time) { best_id = ii; best_time = elapsed; @@ -797,6 +803,7 @@ void MoeGemmRunner::run_gemm( } } +#if 0 dispatch_to_arch(A, B, weight_scales, @@ -810,6 +817,7 @@ void MoeGemmRunner::run_gemm( quant_args_B, chosen_config, stream); +#endif } template diff --git a/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu b/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu index fb9d2e69fe..616d21ba42 100644 --- a/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu +++ b/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu @@ -20,6 +20,8 @@ #include "moe/fast_hardamard_kernel.h" #include "moe/fused_moe_helper.h" +#define _GROUP_GEMM_ONLY 1 + template void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, const paddle::Tensor& tokens_expert_prefix_sum, @@ -49,12 +51,13 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, typename WeightOnlyTraits::Arguments ffn1_quant_args; typename WeightOnlyTraits::Arguments ffn2_quant_args; if constexpr (QuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) { - ffn1_quant_args.local_scale_ptr = ffn1_local_scale->data(); - ffn1_quant_args.code_scale_ptr = ffn1_code_scale->data(); - ffn1_quant_args.code_zp_ptr = ffn1_code_zp->data(); - ffn2_quant_args.local_scale_ptr = ffn2_local_scale->data(); - ffn2_quant_args.code_scale_ptr = ffn2_code_scale->data(); - ffn2_quant_args.code_zp_ptr = ffn2_code_zp->data(); + ffn1_quant_args.local_scale_ptr = const_cast(ffn1_local_scale->data()); + ffn1_quant_args.code_scale_ptr = const_cast(ffn1_code_scale->data()); + ffn1_quant_args.code_zp_ptr = const_cast(ffn1_code_zp->data()); + + ffn2_quant_args.local_scale_ptr = const_cast(ffn2_local_scale->data()); + ffn2_quant_args.code_scale_ptr = const_cast(ffn2_code_scale->data()); + ffn2_quant_args.code_zp_ptr = const_cast(ffn2_code_zp->data()); } auto moe_gemm_runner = MoeGemmRunner(); @@ -65,7 +68,11 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, reinterpret_cast(ffn1_weight.data()), reinterpret_cast(ffn1_super_scale ? ffn1_super_scale->data() : nullptr), reinterpret_cast(ffn1_bias ? 
ffn1_bias->data() : nullptr), +#if _GROUP_GEMM_ONLY + reinterpret_cast(ffn_out.data()), +#else reinterpret_cast(fc1_out.data()), +#endif const_cast(tokens_expert_prefix_sum.data()), total_rows_in_ll_else_minus1, actual_total_rows, @@ -76,9 +83,12 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, "none", stream); +#if _GROUP_GEMM_ONLY + // do nothing +#else paddle::Tensor act_out; if (used_in_ep_low_latency) { - act_out = GroupSwigluWithMasked(fc1_out, tokens_expert_prefix_sum); + //act_out = GroupSwigluWithMasked(fc1_out, tokens_expert_prefix_sum); } else { act_out = paddle::experimental::swiglu(fc1_out, nullptr); } @@ -96,6 +106,7 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, num_experts, ffn2_quant_args, stream); +#endif } template @@ -198,7 +209,14 @@ paddle::Tensor MoeExpertFFNWint2Func( const bool used_in_ep_low_latency) { const auto dtype = permute_input.dtype(); +#if _GROUP_GEMM_ONLY + auto place = permute_input.place(); + int64_t expanded_active_expert_rows = permute_input.dims()[0]; + int64_t inter_size = ffn1_scale.get().dims()[1]; + auto ffn_out = GetEmptyTensor({expanded_active_expert_rows, inter_size}, dtype, place); +#else auto ffn_out = paddle::empty_like(permute_input, dtype); +#endif switch (dtype) { case paddle::DataType::BFLOAT16: @@ -289,7 +307,14 @@ std::vector> MoeExpertFFNWint2InferShape( const paddle::optional>& ffn2_code_zp_shape, const bool used_in_ep_low_latency) { +#if _GROUP_GEMM_ONLY + int64_t expanded_active_expert_rows = permute_input_shape[0]; + int64_t inter_size = ffn1_scale_shape.get()[1]; + std::cout << "expanded_active_expert_rows: " << expanded_active_expert_rows << ", inter_size: " << inter_size << std::endl; + return {std::vector{expanded_active_expert_rows, inter_size}}; +#else return {permute_input_shape}; +#endif } std::vector MoeExpertFFNWint2InferDtype( @@ -356,7 +381,7 @@ std::vector MoeExpertFFNWint2InferDtype( * Note: * - Low latency mode uses specialized grouped SwiGLU implementation */ -PD_BUILD_STATIC_OP(moe_expert_ffn_wint2) +PD_BUILD_OP(moe_expert_ffn_wint2) .Inputs({"permute_input", "tokens_expert_prefix_sum", "ffn1_weight",