feat: MoE trtllm backend kernel update #5183

Merged
merged 1 commit on Jun 16, 2025
144 changes: 139 additions & 5 deletions cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp
@@ -27,6 +27,138 @@ namespace kernels
{

using namespace batchedGemm::batchedGemm;
using namespace batchedGemm::gemm;
using namespace batchedGemm::trtllm::gen;

std::vector<int64_t> prioritizePredefinedConfigs(int m, int n, int k, std::vector<int64_t> const& sortedIndices,
batchedGemm::batchedGemm::BatchedGemmConfig const* configs)
{

// Function to bubble up the pre-determined config.
auto bubbleUpConfig = [&configs](std::vector<int64_t> const& sortedIndices, auto&& pred) -> std::vector<int64_t>
{
std::vector<int64_t> prioritizedIndices_;
// Copy matching configs to new vector
std::copy_if(sortedIndices.begin(), sortedIndices.end(), std::back_inserter(prioritizedIndices_),
[&configs, &pred](int idx)
{
BatchedGemmConfig const& config = configs[idx];
return (pred(config));
});
// Copy the rest of the configs to new vector, if not already copied
std::copy_if(sortedIndices.begin(), sortedIndices.end(), std::back_inserter(prioritizedIndices_),
[&prioritizedIndices_](int idx) {
return std::find(prioritizedIndices_.begin(), prioritizedIndices_.end(), idx)
== prioritizedIndices_.end();
});
return prioritizedIndices_;
};

// Init empty vector
std::vector<int64_t> prioritizedIndices;

//
// Qwen3
//

// Qwen3_235B_TP1_EP8_MoE_FC1 m=3072 k=4096
if (n /* out_dim */ == 3072 && k /* in_dim */ == 4096)
{
auto pred = [](BatchedGemmConfig const& config)
{
BatchedGemmOptions const& options = config.mOptions;
return options.mNumStages == 4 && options.mNumStagesMma == 1 && options.mTileK == 512
&& options.mTileScheduler == TileScheduler::Static;
};
prioritizedIndices = bubbleUpConfig(sortedIndices, pred);
}
// Qwen3_235B_TP1_EP8_MoE_FC2 m=4096 k=1536
else if (n /* out_dim */ == 4096 && k /* in_dim */ == 1536)
{
auto pred = [](BatchedGemmConfig const& config)
{
BatchedGemmOptions const& options = config.mOptions;
return options.mNumStages == 4 && options.mNumStagesMma == 1 && options.mTileK == 512
&& options.mTileScheduler == TileScheduler::Static;
};
prioritizedIndices = bubbleUpConfig(sortedIndices, pred);
}
// Qwen3_235B_TP2_EP4_MoE_FC1 m=1536 k=4096
else if (n /* out_dim */ == 1536 && k /* in_dim */ == 4096)
{
auto pred = [](BatchedGemmConfig const& config)
{
BatchedGemmOptions const& options = config.mOptions;
return options.mNumStages == 4 && options.mNumStagesMma == 1 && options.mTileK == 512
&& options.mTileScheduler == TileScheduler::Static;
};
prioritizedIndices = bubbleUpConfig(sortedIndices, pred);
}
// Qwen3_235B_TP2_EP4_MoE_FC2 m=4096 k=768
else if (n /* out_dim */ == 4096 && k /* in_dim */ == 768)
{
auto pred = [](BatchedGemmConfig const& config)
{
BatchedGemmOptions const& options = config.mOptions;
return options.mNumStages == 4 && options.mNumStagesMma == 2 && options.mTileK == 512
&& options.mTileScheduler == TileScheduler::Persistent;
};
prioritizedIndices = bubbleUpConfig(sortedIndices, pred);
}
// Qwen3_235B_TP4_EP2_MoE_FC1 m=768 k=4096
else if (n /* out_dim */ == 768 && k /* in_dim */ == 4096)
{
auto pred = [](BatchedGemmConfig const& config)
{
BatchedGemmOptions const& options = config.mOptions;
return options.mNumStages == 4 && options.mNumStagesMma == 1 && options.mTileK == 512
&& options.mTileScheduler == TileScheduler::Static;
};
prioritizedIndices = bubbleUpConfig(sortedIndices, pred);
}
// Qwen3_235B_TP4_EP2_MoE_FC2 m=4096 k=384
else if (n /* out_dim */ == 4096 && k /* in_dim */ == 384)
{
auto pred = [](BatchedGemmConfig const& config)
{
BatchedGemmOptions const& options = config.mOptions;
return options.mNumStages == 4 && options.mNumStagesMma == 2 && options.mTileK == 512
&& options.mTileScheduler == TileScheduler::Persistent;
};
prioritizedIndices = bubbleUpConfig(sortedIndices, pred);
}
// Qwen3_235B_TP8_EP1_MoE_FC1 m=384 k=4096
else if (n /* out_dim */ == 384 && k /* in_dim */ == 4096)
{
auto pred = [](BatchedGemmConfig const& config)
{
BatchedGemmOptions const& options = config.mOptions;
return options.mNumStages == 4 && options.mNumStagesMma == 1 && options.mTileK == 512
&& options.mTileScheduler == TileScheduler::Static;
};
prioritizedIndices = bubbleUpConfig(sortedIndices, pred);
}
// Qwen3_235B_TP8_EP1_MoE_FC2 m=4096 k=192
else if (n /* out_dim */ == 4096 && k /* in_dim */ == 192)
{
auto pred = [](BatchedGemmConfig const& config)
{
BatchedGemmOptions const& options = config.mOptions;
return options.mNumStages == 4 && options.mNumStagesMma == 2 && options.mTileK == 256
&& options.mTileScheduler == TileScheduler::Persistent;
};
prioritizedIndices = bubbleUpConfig(sortedIndices, pred);
}
//
// Fall back
//
else
{
prioritizedIndices = sortedIndices;
}

return prioritizedIndices;
}
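
// Illustrative sketch, not part of this change: the two copy_if passes in
// bubbleUpConfig amount to a stable partition of the index list. A hypothetical
// standalone helper with the same effect (requires <algorithm>; here the
// predicate takes an index rather than a config):
template <typename Pred>
std::vector<int64_t> stablePrioritize(std::vector<int64_t> indices, Pred&& pred)
{
    // Matching indices move to the front and the relative order inside both groups
    // is preserved, e.g. {7, 3, 9, 1} with a predicate matching 9 and 1 becomes {9, 1, 7, 3}.
    std::stable_partition(indices.begin(), indices.end(), pred);
    return indices;
}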

TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunnerOptions const& options_)
: mOptions(options_)
@@ -44,7 +176,8 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
// When we include low-latency kernels we can set transposeMmaOutput via constructor
if (options.mDtypeA == mOptions.eltType && options.mDtypeC == mOptions.outputType
&& options.mUseDeepSeekFp8 == mOptions.deepSeekFp8
&& options.mTransposeMmaOutput == mOptions.transposeMmaOutput && options.mRouteAct == mOptions.routeAct
&& options.mTransposeMmaOutput == mOptions.transposeMmaOutput
&& (!doesRouteImplUseNoRoute(options.mRouteImpl)) == mOptions.routeAct
&& options.mFusedAct == mOptions.fusedAct && options.mIsStaticBatch == mOptions.staticBatch
&& tileSize == mOptions.tileSize)
{
@@ -227,9 +360,9 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(int32_t m
gemmData.mProblemDimensions.mWorldSize = 1;
gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
// Sort configs by options
std::vector<int32_t> sortedIndices = mPassingConfigIndices;
std::vector<int64_t> sortedIndices = mPassingConfigIndices;
std::sort(sortedIndices.begin(), sortedIndices.end(),
[&configs](int32_t idx0, int32_t idx1)
[&configs](int64_t idx0, int64_t idx1)
{
auto const& optionsA = configs[idx0].mOptions;
auto const& optionsB = configs[idx1].mOptions;
@@ -247,16 +380,17 @@
}

// Then by tile scheduler (persistent scheduler is better for FC2 in MoE)
if (!optionsA.mRouteAct)
if (doesRouteImplUseNoRoute(optionsA.mRouteImpl))
{
return optionsA.mTileScheduler == batchedGemm::gemm::TileScheduler::Persistent;
}

return optionsA.mTileM > optionsB.mTileM;
});

std::vector<int64_t> prioritizedIndices = prioritizePredefinedConfigs(m, n, k, sortedIndices, configs);
std::vector<int64_t> validConfigIndices;
for (auto const& configIndex : sortedIndices)
for (auto const& configIndex : prioritizedIndices)
{
auto const& config = configs[configIndex];
auto isValidConfig = bmm.isValidConfig(config, gemmData);
@@ -50,23 +50,26 @@ class TrtllmGenBatchedGemmRunner
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
std::optional<int32_t> configIndex = std::nullopt);

// Generic GEMM interface
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
int32_t numBatches, int32_t maxNumCtasInBatchDim, void const* a, void const* sfA, void const* b,
void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC,
float const* scaleGateC, void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
void* workspace, CUstream stream, int device, std::optional<int32_t> configIndex = std::nullopt);

// NVFP4 per-block scaling GEMM
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* sfA,
void const* b, void const* sfB, void* c, void* outSfC, void* workspace, CUstream stream, int device,
std::optional<int32_t> configIndex = std::nullopt);

// FP8 per-tensor scaling GEMM
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* b,
float const* scaleC, float const* scaleGateC, void* c, void* workspace, CUstream stream, int device,
std::optional<int32_t> configIndex = std::nullopt);

// Get the list of configs that passed the validation based on the constructor options
[[nodiscard]] std::vector<int32_t> getPassingConfigIndices() const
[[nodiscard]] std::vector<int64_t> getPassingConfigIndices() const
{
return mPassingConfigIndices;
}
@@ -88,8 +91,8 @@ class TrtllmGenBatchedGemmRunner

private:
TrtllmGenBatchedGemmRunnerOptions mOptions;
std::vector<int32_t> mPassingConfigIndices;
std::optional<int32_t> mSelectedConfigIndex;
std::vector<int64_t> mPassingConfigIndices;
std::optional<int64_t> mSelectedConfigIndex;
};
} // namespace kernels
} // namespace tensorrt_llm
@@ -0,0 +1,67 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cassert>
#include <string>

namespace batchedGemm
{

namespace batchedGemm
{

////////////////////////////////////////////////////////////////////////////////////////////////////

enum class RouteImpl
{
// No Routing
NoRoute = 0,
// Use LDGSTS to do the routing
Ldgsts = 1,
// Use UTMALDG.GATHER4 to do the routing
Tma = 2
};

////////////////////////////////////////////////////////////////////////////////////////////////////

inline bool doesRouteImplUseNoRoute(RouteImpl mode)
{
return (mode == RouteImpl::NoRoute);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline bool doesRouteImplUseLdgsts(RouteImpl mode)
{
return (mode == RouteImpl::Ldgsts);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline bool doesRouteImplUseTma(RouteImpl mode)
{
return (mode == RouteImpl::Tma);
}
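
// Illustrative usage sketch (not part of this header): dispatching on the routing
// mode with the helpers above. routeImplName is a hypothetical function, added
// only to show the intended branching.
inline char const* routeImplName(RouteImpl mode)
{
    if (doesRouteImplUseNoRoute(mode))
    {
        return "NoRoute";
    }
    if (doesRouteImplUseLdgsts(mode))
    {
        return "Ldgsts";
    }
    // The remaining case is UTMALDG.GATHER4-based routing.
    return "Tma";
}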

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace batchedGemm

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace batchedGemm
@@ -96,12 +96,16 @@ struct BatchedGemmData
// Logical strides are [K, 1].
//
// If batchN:
// If transposeMatrixA is false
// If layoutA is MatrixLayout::MajorK
// Logical shape is [B, divUpMul(M, tileM), K].
// Logical strides are [divUpMul(M, tileM) * K, K, 1].
// If transposeMatrixA is true
// If layoutA is MatrixLayout::MajorMn
// Logical shape is [B, K, divUpMul(M, tileM)].
// Logical strides are [K * divUpMul(M, tileM), divUpMul(M, tileM), 1].
// If layoutA is MatrixLayout::BlockMajorK
// Logical shape is [B, K / blockK, divUpMul(M, tileM), blockK].
// Logical strides are [K * divUpMul(M, tileM), divUpMul(M, tileM) * blockK, blockK, 1].
// where blockK is 128B.
void const* mPtrA{nullptr};
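
// Illustrative worked example (not part of this diff), assuming batchN with
// M = 100, tileM = 128, K = 4096, and a 1-byte element type so blockK = 128 elements.
// divUpMul(100, 128) = 128 (M rounded up to the next multiple of tileM), so:
//   MajorK:       shape [B, 128, 4096],     strides [128 * 4096, 4096, 1]
//   MajorMn:      shape [B, 4096, 128],     strides [4096 * 128, 128, 1]
//   BlockMajorK:  shape [B, 32, 128, 128],  strides [4096 * 128, 128 * 128, 128, 1]
// A minimal divUpMul consistent with the shapes above could be:
//   int64_t divUpMul(int64_t a, int64_t b) { return ((a + b - 1) / b) * b; }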

// The block scaling factors to dequantize A.
@@ -154,12 +158,16 @@ struct BatchedGemmData
// Logical strides are [K, 1].
//
// If batchM:
// If transposeMatrixB is true
// If layoutB is MatrixLayout::MajorK
// Logical shape is [B, divUpMul(N, tileN), K].
// Logical strides are [divUpMul(N, tileN) * K, K, 1].
// If transposeMatrixB is false
// If layoutB is MatrixLayout::MajorMn
// Logical shape is [B, K, divUpMul(N, tileN)].
// Logical strides are [K * divUpMul(N, tileN), divUpMul(N, tileN), 1].
// If layoutB is MatrixLayout::BlockMajorK
// Logical shape is [B, K / blockK, divUpMul(N, tileN), blockK].
// Logical strides are [K * divUpMul(N, tileN), divUpMul(N, tileN) * blockK, blockK, 1].
// where blockK is 128B.
void const* mPtrB{nullptr};

// The scaling factors to dequantize B.
@@ -210,6 +218,21 @@ struct BatchedGemmData
// Logical shape is [sum(divUpMul(N[bi], tileN) for bi in B)]
void const* mPtrPerTokenSfB{nullptr};

// The bias applied after the GEMM and before the activation function.
// The bias is applied before applying the global scaling factor. I.e.
// C = act(A * B + bias') * scaleC
// scaleC = dequantA * dequantB * quantC
// Thus, bias' = bias / (dequantA * dequantB), where bias is the original bias.
//
// If batchM, BiasType must be N, and bias shape is [B, N].
// The bias is broadcasted along the M dimension.
//
// If batchN, BiasType must be M, and bias shape is [B, M].
// The bias is broadcasted along the N dimension.
//
// The dtype is float32.
void const* mPtrBias{nullptr};
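
// Illustrative host-side sketch (not part of this diff): folding the dequantization
// scales into the bias as described above, for the batchM case with bias shape [B, N].
// dequantA and dequantB are assumed to be the per-tensor dequantization scales of A and B.
//   for (int64_t b = 0; b < B; ++b)
//       for (int64_t n = 0; n < N; ++n)
//           biasPrime[b * N + n] = bias[b * N + n] / (dequantA * dequantB);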

// The output tensor scaling factor for MxFp{4,8}, Fp8 and NvFp4 quantization.
// TensorRT-LLM API requires a scaling factor on the device.
// Shape is [B].
@@ -220,6 +243,12 @@
// Shape is [B].
float const* mPtrScaleGate{nullptr};

// The alpha and beta for SwiGlu.
// gatedActivation <- (x0 + beta) * sigmoid(alpha * x1)
// Shape is [B].
float const* mPtrSwiGluAlpha{nullptr};
float const* mPtrSwiGluBeta{nullptr};
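
// Illustrative per-element reference (not part of this diff) for the gated
// activation above, assuming x0 is the linear half and x1 the gate half of the
// fused FC1 output, with alpha and beta read for the current batch:
//   float gate = 1.f / (1.f + expf(-alpha * x1));  // sigmoid(alpha * x1)
//   float out  = (x0 + beta) * gate;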

// Param is used when the kernel is configured with -routeAct true.
// The inputs are not padded, but the outputs are padded to divUpMul(M[bi], tileM) for batchM or
// divUpMul(N[bi], tileN) for batchN.
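// Illustrative example (not part of this diff): with batchM, M[bi] = 100 and
// tileM = 128, the output for that batch is padded to divUpMul(100, 128) = 128 rows,
// while the input keeps its unpadded 100 rows.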
@@ -609,11 +638,13 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa
batchedGemmData.mInputBuffers.mPtrB, batchedGemmData.mOutputBuffers.mPtrC,
batchedGemmData.mInputBuffers.mPtrSfA, batchedGemmData.mInputBuffers.mPtrSfB,
batchedGemmData.mInputBuffers.mPtrPerTokenSfA, batchedGemmData.mInputBuffers.mPtrPerTokenSfB,
batchedGemmData.mOutputBuffers.mPtrSfC, batchedGemmData.mInputBuffers.mPtrScaleC,
batchedGemmData.mInputBuffers.mPtrScaleGate, batchedGemmData.mInputBuffers.mPtrRouteMap, dPtrRowMax,
dPtrRowMaxBars, batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas,
batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens, batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx,
batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit, maxNumCtasInBatchDim);
batchedGemmData.mInputBuffers.mPtrBias, batchedGemmData.mOutputBuffers.mPtrSfC,
batchedGemmData.mInputBuffers.mPtrScaleC, batchedGemmData.mInputBuffers.mPtrScaleGate,
batchedGemmData.mInputBuffers.mPtrSwiGluAlpha, batchedGemmData.mInputBuffers.mPtrSwiGluBeta,
batchedGemmData.mInputBuffers.mPtrRouteMap, dPtrRowMax, dPtrRowMaxBars,
batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas, batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens,
batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx, batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit,
maxNumCtasInBatchDim);

// The size of the grid.
std::vector<int32_t> grid{numCtaX, numCtaY, numCtaZ};