@@ -55,11 +55,12 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
55
55
}
56
56
}
57
57
58
- TLLM_CHECK_WITH_INFO (mPassingConfigIndices .size () != 0 , " No kernel found for the given output type " );
58
+ TLLM_CHECK_WITH_INFO (! mPassingConfigIndices .empty () , " No kernel found for the given options " );
59
59
}
60
60
61
61
size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes (int32_t m, int32_t n, int32_t k,
62
- std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim)
62
+ std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
63
+ std::optional<int32_t > configIndex)
63
64
{
64
65
BatchedGemmData gemmData;
65
66
gemmData.mProblemDimensions .mNumBatches = numBatches;
@@ -74,13 +75,18 @@ size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(int32_t m, int32_t n,
74
75
gemmData.mProblemDimensions .mWorldSize = 1 ;
75
76
gemmData.mProblemDimensions .mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
76
77
77
- selectGemmConfig (m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
78
-
79
78
auto bmm = BatchedGemmInterface ();
79
+
80
80
auto const configs = bmm.getBatchedGemmConfigs ();
81
- TLLM_CHECK_WITH_INFO (
82
- mSelectedConfigIndex .has_value (), " No valid kernel found for given param config and problem size" );
83
- auto const & config = configs[mSelectedConfigIndex .value ()];
81
+
82
+ if (!configIndex.has_value ())
83
+ {
84
+ mSelectedConfigIndex
85
+ = getDefaultValidConfigIndex (m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
86
+ configIndex = mSelectedConfigIndex ;
87
+ }
88
+
89
+ auto const & config = configs[configIndex.value ()];
84
90
return bmm.getWorkspaceSizeInBytes (config, gemmData);
85
91
}
86
92
@@ -89,16 +95,22 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
89
95
void const * sfB, void const * perTokensSfA, void const * perTokensSfB, float const * scaleC, float const * scaleGateC,
90
96
void * c, void * outSfC, int32_t const * routeMap, int32_t const * totalNumPaddedTokens,
91
97
int32_t const * ctaIdxXyToBatchIdx, int32_t const * ctaIdxXyToMnLimit, int32_t const * numNonExitingCtas,
92
- void * workspace, CUstream stream, int device)
98
+ void * workspace, CUstream stream, int device, std::optional< int32_t > configIndex )
93
99
{
94
100
auto bmm = BatchedGemmInterface ();
95
101
96
102
BatchedGemmData gemmData;
97
103
98
104
auto const configs = bmm.getBatchedGemmConfigs ();
99
- TLLM_CHECK_WITH_INFO (
100
- mSelectedConfigIndex .has_value (), " No valid kernel found for given param config and problem size" );
101
- auto const & config = configs[mSelectedConfigIndex .value ()];
105
+
106
+ if (!configIndex.has_value ())
107
+ {
108
+ TLLM_CHECK_WITH_INFO (mSelectedConfigIndex .has_value (), " Tried to use default config index but none was set" );
109
+
110
+ configIndex = mSelectedConfigIndex ;
111
+ }
112
+
113
+ auto const & config = configs[configIndex.value ()];
102
114
103
115
TLLM_CHECK_WITH_INFO (numBatches > 0 , " Batched GEMM requires numBatches > 0" );
104
116
if (!mOptions .staticBatch )
@@ -170,32 +182,33 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
170
182
171
183
void TrtllmGenBatchedGemmRunner::run (int32_t m, int32_t n, int32_t k, std::vector<int32_t > const & batchedTokens,
172
184
void const * a, void const * sfA, void const * b, void const * sfB, void * c, void * outSfC, void * workspace,
173
- CUstream stream, int device)
185
+ CUstream stream, int device, std::optional< int32_t > configIndex )
174
186
{
175
187
// Dispatch with block scaling factors and with static batching.
176
188
run (m, n, k, batchedTokens, /* numTokens */ 0 , batchedTokens.size (), /* maxNumCtasInBatchDim */ 0 , a, sfA, b, sfB,
177
189
/* perTokensSfA */ nullptr , /* perTokensSfB */ nullptr ,
178
190
/* scaleC */ nullptr , /* scaleGateC */ nullptr , c, outSfC,
179
191
/* routeMap */ nullptr , /* totalNumPaddedTokens */ nullptr ,
180
192
/* ctaIdxXyToBatchIdx */ nullptr , /* ctaIdxXyToMnLimit */ nullptr ,
181
- /* numNonExitingCtas */ nullptr , workspace, stream, device);
193
+ /* numNonExitingCtas */ nullptr , workspace, stream, device, configIndex );
182
194
}
183
195
184
196
void TrtllmGenBatchedGemmRunner::run (int32_t m, int32_t n, int32_t k, std::vector<int32_t > const & batchedTokens,
185
197
void const * a, void const * b, float const * scaleC, float const * scaleGateC, void * c, void * workspace,
186
- CUstream stream, int device)
198
+ CUstream stream, int device, std::optional< int32_t > configIndex )
187
199
{
188
200
// Dispatch with block scaling factors and with static batching.
189
201
run (m, n, k, batchedTokens, /* numTokens */ 0 , batchedTokens.size (), /* maxNumCtasInBatchDim */ 0 , a,
190
202
/* sfA */ nullptr , b, /* sfB */ nullptr , /* perTokensSfA */ nullptr , /* perTokensSfB */ nullptr , scaleC,
191
203
scaleGateC, c, /* outSfC */ nullptr ,
192
204
/* routeMap */ nullptr , /* totalNumPaddedTokens */ nullptr ,
193
205
/* ctaIdxXyToBatchIdx */ nullptr , /* ctaIdxXyToMnLimit */ nullptr ,
194
- /* numNonExitingCtas */ nullptr , workspace, stream, device);
206
+ /* numNonExitingCtas */ nullptr , workspace, stream, device, configIndex );
195
207
}
196
208
197
- void TrtllmGenBatchedGemmRunner::selectGemmConfig (int32_t m, int32_t n, int32_t k,
198
- std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim)
209
+ std::vector<int64_t > TrtllmGenBatchedGemmRunner::getValidConfigIndices (int32_t m, int32_t n, int32_t k,
210
+ std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches,
211
+ int32_t maxNumCtasInBatchDim) const
199
212
{
200
213
auto const bmm = BatchedGemmInterface ();
201
214
auto const configs = bmm.getBatchedGemmConfigs ();
@@ -242,16 +255,30 @@ void TrtllmGenBatchedGemmRunner::selectGemmConfig(int32_t m, int32_t n, int32_t
242
255
return optionsA.mTileM > optionsB.mTileM ;
243
256
});
244
257
258
+ std::vector<int64_t > validConfigIndices;
245
259
for (auto const & configIndex : sortedIndices)
246
260
{
247
261
auto const & config = configs[configIndex];
248
262
auto isValidConfig = bmm.isValidConfig (config, gemmData);
249
263
if (isValidConfig)
250
264
{
251
- mSelectedConfigIndex = configIndex;
252
- return ;
265
+ validConfigIndices.push_back (configIndex);
253
266
}
254
267
}
268
+
269
+ TLLM_CHECK_WITH_INFO (!validConfigIndices.empty (), " No valid config found for the given problem shape" );
270
+
271
+ return validConfigIndices;
272
+ }
273
+
274
+ int64_t TrtllmGenBatchedGemmRunner::getDefaultValidConfigIndex (int32_t m, int32_t n, int32_t k,
275
+ std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches,
276
+ int32_t maxNumCtasInBatchDim) const
277
+ {
278
+ auto const validConfigIndices
279
+ = getValidConfigIndices (m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
280
+
281
+ return validConfigIndices[0 ];
255
282
}
256
283
257
284
} // namespace kernels
0 commit comments