Skip to content

Commit 9c012d5

Browse files
authored
[TRTLLM-5589] feat: Integrate TRT-LLM Gen FP8 Batched GEMM with Pytorch workflow kernel autotuner (#4872)
Signed-off-by: Dom Brown <3886319+DomBrown@users.noreply.github.com>
1 parent 1d4f748 commit 9c012d5

File tree

6 files changed

+618
-174
lines changed

6 files changed

+618
-174
lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,12 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
5555
}
5656
}
5757

58-
TLLM_CHECK_WITH_INFO(mPassingConfigIndices.size() != 0, "No kernel found for the given output type");
58+
TLLM_CHECK_WITH_INFO(!mPassingConfigIndices.empty(), "No kernel found for the given options");
5959
}
6060

6161
size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(int32_t m, int32_t n, int32_t k,
62-
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim)
62+
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
63+
std::optional<int32_t> configIndex)
6364
{
6465
BatchedGemmData gemmData;
6566
gemmData.mProblemDimensions.mNumBatches = numBatches;
@@ -74,13 +75,18 @@ size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(int32_t m, int32_t n,
7475
gemmData.mProblemDimensions.mWorldSize = 1;
7576
gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
7677

77-
selectGemmConfig(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
78-
7978
auto bmm = BatchedGemmInterface();
79+
8080
auto const configs = bmm.getBatchedGemmConfigs();
81-
TLLM_CHECK_WITH_INFO(
82-
mSelectedConfigIndex.has_value(), "No valid kernel found for given param config and problem size");
83-
auto const& config = configs[mSelectedConfigIndex.value()];
81+
82+
if (!configIndex.has_value())
83+
{
84+
mSelectedConfigIndex
85+
= getDefaultValidConfigIndex(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
86+
configIndex = mSelectedConfigIndex;
87+
}
88+
89+
auto const& config = configs[configIndex.value()];
8490
return bmm.getWorkspaceSizeInBytes(config, gemmData);
8591
}
8692

@@ -89,16 +95,22 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
8995
void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC, float const* scaleGateC,
9096
void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
9197
int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
92-
void* workspace, CUstream stream, int device)
98+
void* workspace, CUstream stream, int device, std::optional<int32_t> configIndex)
9399
{
94100
auto bmm = BatchedGemmInterface();
95101

96102
BatchedGemmData gemmData;
97103

98104
auto const configs = bmm.getBatchedGemmConfigs();
99-
TLLM_CHECK_WITH_INFO(
100-
mSelectedConfigIndex.has_value(), "No valid kernel found for given param config and problem size");
101-
auto const& config = configs[mSelectedConfigIndex.value()];
105+
106+
if (!configIndex.has_value())
107+
{
108+
TLLM_CHECK_WITH_INFO(mSelectedConfigIndex.has_value(), "Tried to use default config index but none was set");
109+
110+
configIndex = mSelectedConfigIndex;
111+
}
112+
113+
auto const& config = configs[configIndex.value()];
102114

103115
TLLM_CHECK_WITH_INFO(numBatches > 0, "Batched GEMM requires numBatches > 0");
104116
if (!mOptions.staticBatch)
@@ -170,32 +182,33 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
170182

171183
void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens,
172184
void const* a, void const* sfA, void const* b, void const* sfB, void* c, void* outSfC, void* workspace,
173-
CUstream stream, int device)
185+
CUstream stream, int device, std::optional<int32_t> configIndex)
174186
{
175187
// Dispatch with block scaling factors and with static batching.
176188
run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a, sfA, b, sfB,
177189
/* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr,
178190
/* scaleC */ nullptr, /* scaleGateC */ nullptr, c, outSfC,
179191
/* routeMap */ nullptr, /* totalNumPaddedTokens */ nullptr,
180192
/* ctaIdxXyToBatchIdx */ nullptr, /* ctaIdxXyToMnLimit */ nullptr,
181-
/* numNonExitingCtas */ nullptr, workspace, stream, device);
193+
/* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex);
182194
}
183195

184196
void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens,
185197
void const* a, void const* b, float const* scaleC, float const* scaleGateC, void* c, void* workspace,
186-
CUstream stream, int device)
198+
CUstream stream, int device, std::optional<int32_t> configIndex)
187199
{
188200
// Dispatch with block scaling factors and with static batching.
189201
run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a,
190202
/* sfA */ nullptr, b, /* sfB */ nullptr, /* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr, scaleC,
191203
scaleGateC, c, /* outSfC */ nullptr,
192204
/* routeMap */ nullptr, /* totalNumPaddedTokens */ nullptr,
193205
/* ctaIdxXyToBatchIdx */ nullptr, /* ctaIdxXyToMnLimit */ nullptr,
194-
/* numNonExitingCtas */ nullptr, workspace, stream, device);
206+
/* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex);
195207
}
196208

197-
void TrtllmGenBatchedGemmRunner::selectGemmConfig(int32_t m, int32_t n, int32_t k,
198-
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim)
209+
std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(int32_t m, int32_t n, int32_t k,
210+
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
211+
int32_t maxNumCtasInBatchDim) const
199212
{
200213
auto const bmm = BatchedGemmInterface();
201214
auto const configs = bmm.getBatchedGemmConfigs();
@@ -242,16 +255,30 @@ void TrtllmGenBatchedGemmRunner::selectGemmConfig(int32_t m, int32_t n, int32_t
242255
return optionsA.mTileM > optionsB.mTileM;
243256
});
244257

258+
std::vector<int64_t> validConfigIndices;
245259
for (auto const& configIndex : sortedIndices)
246260
{
247261
auto const& config = configs[configIndex];
248262
auto isValidConfig = bmm.isValidConfig(config, gemmData);
249263
if (isValidConfig)
250264
{
251-
mSelectedConfigIndex = configIndex;
252-
return;
265+
validConfigIndices.push_back(configIndex);
253266
}
254267
}
268+
269+
TLLM_CHECK_WITH_INFO(!validConfigIndices.empty(), "No valid config found for the given problem shape");
270+
271+
return validConfigIndices;
272+
}
273+
274+
int64_t TrtllmGenBatchedGemmRunner::getDefaultValidConfigIndex(int32_t m, int32_t n, int32_t k,
275+
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
276+
int32_t maxNumCtasInBatchDim) const
277+
{
278+
auto const validConfigIndices
279+
= getValidConfigIndices(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
280+
281+
return validConfigIndices[0];
255282
}
256283

257284
} // namespace kernels

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.h

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616

1717
#pragma once
1818

19+
#include <cstdint>
1920
#include <cuda.h>
2021
#include <optional>
22+
#include <vector>
2123

2224
#include "trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h"
2325

@@ -45,29 +47,49 @@ class TrtllmGenBatchedGemmRunner
4547
explicit TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunnerOptions const& options);
4648

4749
[[nodiscard]] size_t getWorkspaceSizeInBytes(int32_t m, int32_t n, int32_t k,
48-
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim);
50+
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
51+
std::optional<int32_t> configIndex = std::nullopt);
4952

5053
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
5154
int32_t numBatches, int32_t maxNumCtasInBatchDim, void const* a, void const* sfA, void const* b,
5255
void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC,
5356
float const* scaleGateC, void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
5457
int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
55-
void* workspace, CUstream stream, int device);
58+
void* workspace, CUstream stream, int device, std::optional<int32_t> configIndex = std::nullopt);
5659

5760
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* sfA,
58-
void const* b, void const* sfB, void* c, void* outSfC, void* workspace, CUstream stream, int device);
61+
void const* b, void const* sfB, void* c, void* outSfC, void* workspace, CUstream stream, int device,
62+
std::optional<int32_t> configIndex = std::nullopt);
5963

6064
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* b,
61-
float const* scaleC, float const* scaleGateC, void* c, void* workspace, CUstream stream, int device);
65+
float const* scaleC, float const* scaleGateC, void* c, void* workspace, CUstream stream, int device,
66+
std::optional<int32_t> configIndex = std::nullopt);
67+
68+
// Get the list of configs that passed the validation based on the constructor options
69+
[[nodiscard]] std::vector<int32_t> getPassingConfigIndices() const
70+
{
71+
return mPassingConfigIndices;
72+
}
73+
74+
// Get the list of config indices that are valid for the given problem shape
75+
[[nodiscard]] std::vector<int64_t> getValidConfigIndices(int32_t m, int32_t n, int32_t k,
76+
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
77+
int32_t maxNumCtasInBatchDim) const;
78+
79+
// Get a default config index that is valid for the given problem shape
80+
// This will be used as the fallback config if using auto-tuning
81+
[[nodiscard]] int64_t getDefaultValidConfigIndex(int32_t m, int32_t n, int32_t k,
82+
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
83+
int32_t maxNumCtasInBatchDim) const;
6284

6385
private:
6486
void selectGemmConfig(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
6587
int32_t numBatches, int32_t maxNumCtasInBatchDim);
6688

6789
private:
6890
TrtllmGenBatchedGemmRunnerOptions mOptions;
69-
std::optional<int> mSelectedConfigIndex;
7091
std::vector<int32_t> mPassingConfigIndices;
92+
std::optional<int32_t> mSelectedConfigIndex;
7193
};
7294
} // namespace kernels
7395
} // namespace tensorrt_llm

0 commit comments

Comments
 (0)