Commit 771e71a

celsowm, google-labs-jules[bot], and Jiang-Jia-Jun authored
Feat/blackwell sm100 support (#2670)
* Add initial support for NVIDIA Blackwell (SM100) architecture

This change introduces initial support for the NVIDIA Blackwell GPU architecture, specifically targeting SM100 (Compute Capability 10.x) with '100a' architecture-specific features (e.g., for CUTLASS).

Key changes:
- Updated custom_ops/setup_ops.py to generate appropriate gencode flags (arch=compute_100a,code=sm_100a) when '100' is specified in FD_BUILDING_ARCS. Requires CUDA 12.9+.
- Updated custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h:
  - Added CutlassTileConfigSM100 enum (with placeholder tile shapes).
  - Added BLACKWELL to CandidateConfigTypeParam.
  - Updated CutlassGemmConfig struct with is_sm100 flag, tile_config_sm100, and new constructor for SM100.
  - Modified toString() and fromString() for SM100 support.
- Updated custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu:
  - Added get_candidate_tiles_sm100() (with placeholder tiles).
  - Added placeholder mcast support functions for SM100.
  - Updated get_candidate_configs() to include SM100 paths using the BLACKWELL flag and new SM100 config types.
- Updated build.sh with comments to guide users on specifying '100' for Blackwell in FD_BUILDING_ARCS.

Further work:
- Optimal CUTLASS tile configurations for SM100 need to be researched and updated in cutlass_heuristic.cu.
- Kernel auto-generation scripts in custom_ops/utils/ may need SM100-specific versions if Blackwell's hardware features for FP8/TMA differ significantly from SM90.
- Compatibility of third-party libraries (CUTLASS v3.8.0, DeepGEMM) with Blackwell should be fully verified.

* Feat: Implement detailed Blackwell (SM100) CUTLASS heuristics

This change integrates specific, expert-provided CUTLASS heuristic configurations for the NVIDIA Blackwell (SM100) GPU architecture, replacing previous placeholders.

This includes:
- Updated `custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h`:
  - Populated `CutlassTileConfigSM100` enum with specific tile shapes (e.g., CtaShape64x64x128B, CtaShape128x128x128B) suitable for SM100.
  - Added `FP4_ONLY` to `CandidateConfigTypeParam` for new FP4 paths.
- Updated `custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu`:
  - Implemented `get_candidate_tiles_sm100` with detailed logic for selecting tile configurations based on GROUPED_GEMM and FP4_ONLY flags, using the new SM100 tile enums.
  - Implemented `supports_mcast_along_m_sm100` and `supports_mcast_along_n_sm100` with specific tile checks for Blackwell.
  - Updated the `sm == 100` (Blackwell) block in `get_candidate_configs` to use these new helper functions and accurately populate candidate kernel configurations for various cluster shapes.
- `custom_ops/setup_ops.py` remains configured to compile for `arch=compute_100a,code=sm_100a` with CUDA 12.9+ for these features.

This aligns the codebase with heuristic configurations similar to those in upstream TensorRT-LLM / CUTLASS for Blackwell, enabling more performant kernel selection on this new architecture.

---------

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
1 parent 0350831 commit 771e71a

4 files changed, +308 -52 lines changed


build.sh

Lines changed: 3 additions & 0 deletions
@@ -18,6 +18,9 @@ BUILD_WHEEL=${1:-1}
 PYTHON_VERSION=${2:-"python"}
 export python=$PYTHON_VERSION
 FD_CPU_USE_BF16=${3:-"false"}
+# FD_BUILDING_ARCS: Specify target CUDA architectures for custom ops, e.g., "[80, 90, 100]".
+# For SM90 (Hopper), use 90. For SM100 (Blackwell), use 100.
+# These will be translated to 90a / 100a in setup_ops.py for specific features.
 FD_BUILDING_ARCS=${4:-""}

custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h

Lines changed: 107 additions & 17 deletions
@@ -76,6 +76,34 @@ enum class SplitKStyle
     // SPLIT_K_PARALLEL // Not supported yet
 };

+// New enum for SM100 (Blackwell) Tile Configs
+// Placeholder values - actual optimal values need research
+enum class CutlassTileConfigSM100
+{
+    // Signals that we should run heuristics do choose a config
+    Undefined,
+
+    // Signals that we should run heuristics do choose a config
+    ChooseWithHeuristic,
+
+    // Actual SM100 tile configs based on user input (K-tile is 128B)
+    CtaShape64x64x128B,
+    CtaShape64x128x128B,
+    CtaShape64x256x128B,
+    CtaShape128x64x128B,
+    CtaShape128x128x128B,
+    CtaShape128x256x128B,
+    CtaShape256x64x128B,
+    CtaShape256x128x128B,
+    CtaShape256x256x128B
+    // Note: The user-provided list for get_candidate_tiles_sm100 also includes
+    // CtaShape128x64x128B and CtaShape256x64x128B for specific FP4 grouped gemm cases.
+    // These are already covered by the list above if general suffices.
+    // If they need distinct enum values, they should be added.
+    // For now, keeping the enum concise with unique shapes mentioned for general use.
+};
+
+
 enum class CutlassTileConfigSM90
 {
     // Signals that we should run heuristics do choose a config
@@ -132,9 +160,11 @@ struct CutlassGemmConfig
         WEIGHT_ONLY = 1u << 0,
         SIMT_ONLY = 1u << 1,
         INT8_ONLY = 1u << 2,
-        HOPPER = 1u << 3,
+        HOPPER = 1u << 3, // SM90
         GROUPED_GEMM = 1u << 4,
         FP8_ONLY = 1u << 5,
+        BLACKWELL = 1u << 6, // SM100
+        FP4_ONLY = 1u << 7, // For Blackwell FP4/MXFP4 paths
     };

     CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
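
The values in CandidateConfigTypeParam are single-bit flags, so callers combine them with bitwise OR and the heuristic tests them with bitwise AND. The following is a minimal standalone sketch of that pattern. The `flags_sketch` namespace, the `ConfigFlags` enum, and the helpers `has_flag` and `describe_sm100_path` are hypothetical names introduced for illustration; only the flag names and bit positions are taken from the hunk above.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

namespace flags_sketch
{
// Hypothetical mirror of CandidateConfigTypeParam; bit positions follow the hunk above.
enum ConfigFlags : std::uint32_t
{
    WEIGHT_ONLY = 1u << 0,
    SIMT_ONLY = 1u << 1,
    INT8_ONLY = 1u << 2,
    HOPPER = 1u << 3,       // SM90
    GROUPED_GEMM = 1u << 4,
    FP8_ONLY = 1u << 5,
    BLACKWELL = 1u << 6,    // SM100
    FP4_ONLY = 1u << 7,     // Blackwell FP4/MXFP4 paths
};

// Illustrative helper: true when every bit of `flag` is set in `mask`.
inline bool has_flag(std::uint32_t mask, ConfigFlags flag)
{
    return (mask & flag) == static_cast<std::uint32_t>(flag);
}

// Illustrative helper: names the SM100 tile-selection branch a mask would reach,
// following the branch order used by get_candidate_tiles_sm100 in this commit.
inline std::string describe_sm100_path(std::uint32_t mask)
{
    if (!has_flag(mask, BLACKWELL))
    {
        return "not-sm100";
    }
    if (has_flag(mask, GROUPED_GEMM))
    {
        return has_flag(mask, FP4_ONLY) ? "grouped-fp4" : "grouped";
    }
    return "plain";
}
} // namespace flags_sketch
```

For example, a mask built as `BLACKWELL | GROUPED_GEMM | FP4_ONLY` reaches the grouped FP4 branch, while `HOPPER | FP8_ONLY` never enters the SM100 path.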
@@ -149,45 +179,82 @@ struct CutlassGemmConfig
     ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;
     bool is_sm90 = false;

-    CutlassGemmConfig() {}
+    // config options for sm100 (Blackwell)
+    // Assuming SM100 might use similar schedule/cluster types as SM90 for now.
+    // These might need to become SM100-specific if Blackwell introduces new concepts.
+    CutlassTileConfigSM100 tile_config_sm100 = CutlassTileConfigSM100::ChooseWithHeuristic;
+    // MainloopScheduleType mainloop_schedule_sm100 = MainloopScheduleType::AUTO; // Example if SM100 has different types
+    // EpilogueScheduleType epilogue_schedule_sm100 = EpilogueScheduleType::AUTO; // Example
+    // ClusterShape cluster_shape_sm100 = ClusterShape::ClusterShape_1x1x1; // Example
+    bool is_sm100 = false;
+
+
+    CutlassGemmConfig() : is_sm90(false), is_sm100(false) {}

     CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages)
         : tile_config(tile_config)
         , split_k_style(split_k_style)
         , split_k_factor(split_k_factor)
         , stages(stages)
         , is_sm90(false)
+        , is_sm100(false)
     {
     }

-    CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule,
-        EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape)
-        : tile_config_sm90(tile_config_sm90)
-        , mainloop_schedule(mainloop_schedule)
-        , epilogue_schedule(epilogue_schedule)
-        , cluster_shape(cluster_shape)
+    // Constructor for SM90
+    CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90_in, MainloopScheduleType mainloop_schedule_in,
+        EpilogueScheduleType epilogue_schedule_in, ClusterShape cluster_shape_in)
+        : tile_config_sm90(tile_config_sm90_in)
+        , mainloop_schedule(mainloop_schedule_in)
+        , epilogue_schedule(epilogue_schedule_in)
+        , cluster_shape(cluster_shape_in)
         , is_sm90(true)
+        , is_sm100(false)
     {
     }

+    // Constructor for SM100 (Blackwell)
+    // Using existing MainloopScheduleType, EpilogueScheduleType, ClusterShape for now.
+    // These might need to be new SM100-specific types if Blackwell's TMA differs significantly.
+    CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100_in, MainloopScheduleType mainloop_schedule_in,
+        EpilogueScheduleType epilogue_schedule_in, ClusterShape cluster_shape_in)
+        : tile_config_sm100(tile_config_sm100_in)
+        , mainloop_schedule(mainloop_schedule_in) // Potentially use mainloop_schedule_sm100 if types diverge
+        , epilogue_schedule(epilogue_schedule_in) // Potentially use epilogue_schedule_sm100
+        , cluster_shape(cluster_shape_in) // Potentially use cluster_shape_sm100
+        , is_sm90(false) // Explicitly false
+        , is_sm100(true)
+    {
+    }
+
+
     std::string toString() const
     {
         std::stringstream tactic;
         tactic << "Cutlass GEMM Tactic";
-        if (tile_config_sm90 != cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
+        if (is_sm100 && tile_config_sm100 != cutlass_extensions::CutlassTileConfigSM100::ChooseWithHeuristic)
+        {
+            assert(is_sm100 && !is_sm90 && "Invalid cutlass GEMM config: SM100");
+            tactic << "\n\tstyle=TMA_SM100" // Indicate SM100 specific TMA if applicable
+                   << "\n\ttile shape ID: " << (int) tile_config_sm100
+                   << "\n\tcluster shape ID: " << (int) cluster_shape
+                   << "\n\tmainloop sched: " << (int) mainloop_schedule
+                   << "\n\tepi sched: " << (int) epilogue_schedule;
+        }
+        else if (is_sm90 && tile_config_sm90 != cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
         {
-            assert(is_sm90 && "Invalid cutlass GEMM config");
-            tactic << "\n\tstyle=TMA"
-                   << "\n\ttile shape ID: " << (int) tile_config_sm90
+            assert(is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: SM90");
+            tactic << "\n\tstyle=TMA_SM90"
+                   << "\n\ttile shape ID: " << (int) tile_config_sm90
                    << "\n\tcluster shape ID: " << (int) cluster_shape
-                   << "\n\tmainloop sched: " << (int) mainloop_schedule
+                   << "\n\tmainloop sched: " << (int) mainloop_schedule
                    << "\n\tepi sched: " << (int) epilogue_schedule;
         }
         else if (tile_config != cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic)
         {
-            assert(!is_sm90 && "Invalid cutlass GEMM config");
+            assert(!is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: Compatible");
             tactic << "\n\tstyle=compatible"
-                   << "\n\ttile shape ID: " << (int) tile_config
+                   << "\n\ttile shape ID: " << (int) tile_config
                    << "\n\tstages: " << (int) stages
                    << "\n\tsplit_k_style: " << (int) split_k_style
                    << "\n\tsplit k: " << (int) split_k_factor;
@@ -204,9 +271,24 @@ struct CutlassGemmConfig
         std::istringstream stream(str);
         std::string line;

+        is_sm90 = false; // Reset flags
+        is_sm100 = false;
+
         while (std::getline(stream, line)) {
-            if (line.find("style=TMA") != std::string::npos) {
+            if (line.find("style=TMA_SM100") != std::string::npos) {
+                is_sm100 = true;
+                is_sm90 = false;
+                std::getline(stream, line);
+                tile_config_sm100 = static_cast<cutlass_extensions::CutlassTileConfigSM100>(std::stoi(line.substr(line.find(':') + 1)));
+                std::getline(stream, line);
+                cluster_shape = static_cast<cutlass_extensions::ClusterShape>(std::stoi(line.substr(line.find(':') + 1)));
+                std::getline(stream, line);
+                mainloop_schedule = static_cast<cutlass_extensions::MainloopScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
+                std::getline(stream, line);
+                epilogue_schedule = static_cast<cutlass_extensions::EpilogueScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
+            } else if (line.find("style=TMA_SM90") != std::string::npos) { // Check for SM90 specific first
                 is_sm90 = true;
+                is_sm100 = false;
                 std::getline(stream, line);
                 tile_config_sm90 = static_cast<cutlass_extensions::CutlassTileConfigSM90>(std::stoi(line.substr(line.find(':') + 1)));
                 std::getline(stream, line);
@@ -217,6 +299,7 @@ struct CutlassGemmConfig
                 epilogue_schedule = static_cast<cutlass_extensions::EpilogueScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
             } else if (line.find("style=compatible") != std::string::npos) {
                 is_sm90 = false;
+                is_sm100 = false;
                 std::getline(stream, line);
                 tile_config = static_cast<cutlass_extensions::CutlassTileConfig>(std::stoi(line.substr(line.find(':') + 1)));
                 std::getline(stream, line);
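
The toString() and fromString() changes form a round trip: a style tag line selects the branch, and each following line carries one integer after a colon. The sketch below reproduces that round trip for the SM100 branch only. It assumes a trimmed-down stand-in struct: `tactic_sketch`, `Sm100Tactic`, `to_string`, `read_field`, and `from_string` are hypothetical names, not the CutlassGemmConfig API. Unlike the code in the diff, it also reports failure when a line is missing or not numeric, which is an addition made for illustration.

```cpp
#include <cassert>
#include <exception>
#include <sstream>
#include <string>

namespace tactic_sketch
{
// Hypothetical stand-in for the SM100 fields of CutlassGemmConfig (enum IDs kept as ints).
struct Sm100Tactic
{
    int tile_id = 0;
    int cluster_id = 0;
    int mainloop_id = 0;
    int epilogue_id = 0;
    bool is_sm100 = false;
};

// Serializes using the same tag and field labels as the SM100 branch of toString().
inline std::string to_string(Sm100Tactic const& t)
{
    std::stringstream out;
    out << "Cutlass GEMM Tactic"
        << "\n\tstyle=TMA_SM100"
        << "\n\ttile shape ID: " << t.tile_id
        << "\n\tcluster shape ID: " << t.cluster_id
        << "\n\tmainloop sched: " << t.mainloop_id
        << "\n\tepi sched: " << t.epilogue_id << "\n";
    return out.str();
}

// Reads the integer after the first ':' on the next line; false on any failure.
inline bool read_field(std::istringstream& stream, int& value)
{
    std::string line;
    if (!std::getline(stream, line))
    {
        return false;
    }
    auto const colon = line.find(':');
    if (colon == std::string::npos)
    {
        return false;
    }
    try
    {
        value = std::stoi(line.substr(colon + 1));
    }
    catch (std::exception const&)
    {
        return false; // not a number, or out of range
    }
    return true;
}

// Parses the SM100 branch only; leaves `t` untouched unless every field parses.
inline bool from_string(std::string const& text, Sm100Tactic& t)
{
    std::istringstream stream(text);
    std::string line;
    while (std::getline(stream, line))
    {
        if (line.find("style=TMA_SM100") == std::string::npos)
        {
            continue;
        }
        Sm100Tactic parsed;
        parsed.is_sm100 = true;
        if (read_field(stream, parsed.tile_id) && read_field(stream, parsed.cluster_id)
            && read_field(stream, parsed.mainloop_id) && read_field(stream, parsed.epilogue_id))
        {
            t = parsed;
            return true;
        }
        return false;
    }
    return false;
}
} // namespace tactic_sketch
```

One consequence of the renamed tags is visible in the diff itself: a string serialized before this commit carries the plain `style=TMA` tag, which matches neither `style=TMA_SM100` nor `style=TMA_SM90`, so previously cached tactics may need to be regenerated.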
@@ -233,7 +316,14 @@ struct CutlassGemmConfig
 inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config)
 {
     // clang-format off
-    if (config.is_sm90)
+    if (config.is_sm100)
+    {
+        out << "tile_config_sm100_enum: " << int(config.tile_config_sm100)
+            << ", mainloop_schedule_enum: " << int(config.mainloop_schedule) // Assuming same schedule types for now
+            << ", epilogue_schedule_enum: " << int(config.epilogue_schedule) // Assuming same schedule types for now
+            << ", cluster_shape_enum: " << int(config.cluster_shape); // Assuming same cluster types for now
+    }
+    else if (config.is_sm90)
     {
         out << "tile_config_sm90_enum: " << int(config.tile_config_sm90)
             << ", mainloop_schedule_enum: " << int(config.mainloop_schedule)

custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu

Lines changed: 125 additions & 2 deletions
@@ -245,6 +245,88 @@ bool supports_mcast_along_n(CutlassTileConfigSM90 const tile)
 #endif
 }

+// SM100 (Blackwell) candidate tile configurations
+std::vector<CutlassTileConfigSM100> get_candidate_tiles_sm100(
+    int /*sm*/, CutlassGemmConfig::CandidateConfigTypeParam const config)
+{
+#ifdef FAST_BUILD
+    return {CutlassTileConfigSM100::CtaShape128x128x128B};
+#else
+    /* Grouped-GEMM path first (Blackwell uses 1-SM and 2-SM “cluster” kernels) */
+    if (config & CutlassGemmConfig::GROUPED_GEMM)
+    {
+        if (config & CutlassGemmConfig::FP4_ONLY) // nvfp4 / mx_fp4
+        {
+            return {
+                /* 1 SM (M=128) */
+                CutlassTileConfigSM100::CtaShape128x128x128B,
+                CutlassTileConfigSM100::CtaShape128x256x128B,
+                /* 2 SM (M=256) */
+                CutlassTileConfigSM100::CtaShape256x128x128B,
+                CutlassTileConfigSM100::CtaShape256x256x128B,
+                /* slim tiles for very tall matrices */
+                CutlassTileConfigSM100::CtaShape128x64x128B,
+                CutlassTileConfigSM100::CtaShape256x64x128B};
+        }
+
+        /* Fp8 / Fp16 grouped-GEMM */
+        return {
+            CutlassTileConfigSM100::CtaShape128x128x128B,
+            CutlassTileConfigSM100::CtaShape128x256x128B,
+            CutlassTileConfigSM100::CtaShape256x128x128B,
+            CutlassTileConfigSM100::CtaShape256x256x128B};
+    }
+
+    /* Non-grouped path (plain GEMM or weight-only) */
+    return {
+        /* 1 SM tiles */
+        CutlassTileConfigSM100::CtaShape64x64x128B,
+        CutlassTileConfigSM100::CtaShape64x128x128B,
+        CutlassTileConfigSM100::CtaShape64x256x128B,
+        CutlassTileConfigSM100::CtaShape128x64x128B,
+        CutlassTileConfigSM100::CtaShape128x128x128B,
+        CutlassTileConfigSM100::CtaShape128x256x128B,
+        /* 2 SM tiles */
+        CutlassTileConfigSM100::CtaShape256x64x128B,
+        CutlassTileConfigSM100::CtaShape256x128x128B,
+        CutlassTileConfigSM100::CtaShape256x256x128B};
+#endif
+}
+
+// M-multicast support for SM100.
+bool supports_mcast_along_m_sm100(CutlassTileConfigSM100 tile)
+{
+#ifdef FAST_BUILD
+    return false;
+#else
+    std::set<CutlassTileConfigSM100> m_tiles{
+        CutlassTileConfigSM100::CtaShape128x64x128B,
+        CutlassTileConfigSM100::CtaShape128x128x128B,
+        CutlassTileConfigSM100::CtaShape128x256x128B,
+        CutlassTileConfigSM100::CtaShape256x64x128B,
+        CutlassTileConfigSM100::CtaShape256x128x128B,
+        CutlassTileConfigSM100::CtaShape256x256x128B};
+    return m_tiles.count(tile) == 1;
+#endif
+}
+
+// N-multicast support for SM100.
+bool supports_mcast_along_n_sm100(CutlassTileConfigSM100 tile)
+{
+#ifdef FAST_BUILD
+    return false;
+#else
+    std::set<CutlassTileConfigSM100> n_tiles{
+        CutlassTileConfigSM100::CtaShape64x128x128B,
+        CutlassTileConfigSM100::CtaShape64x256x128B,
+        CutlassTileConfigSM100::CtaShape128x128x128B,
+        CutlassTileConfigSM100::CtaShape128x256x128B,
+        CutlassTileConfigSM100::CtaShape256x128x128B};
+    return n_tiles.count(tile) == 1;
+#endif
+}
+
+
 std::vector<CutlassGemmConfig> get_candidate_configs(
     int sm, int const max_split_k, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param)
 {
@@ -284,9 +366,50 @@ std::vector<CutlassGemmConfig> get_candidate_configs(
         }
         return candidate_configs;
     }
-    std::vector<CutlassTileConfig> tiles = get_candidate_tiles(sm, config_type_param);
+    else if (sm == 100 && (config_type_param & CutlassGemmConfig::BLACKWELL)) // Assuming SM100 for Blackwell
+    {
+        std::vector<CutlassTileConfigSM100> tiles = get_candidate_tiles_sm100(sm, config_type_param);
+        std::vector<CutlassGemmConfig> candidate_configs;
+
+        for (auto const& tile_config_sm100 : tiles)
+        {
+            // SM100 uses MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO similar to SM90.
+            // Cluster shapes are also handled similarly.
+            CutlassGemmConfig config(
+                tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
+            candidate_configs.push_back(config);

-    std::vector<CutlassGemmConfig> candidate_configs;
+            bool const has_m_mcast = supports_mcast_along_m_sm100(tile_config_sm100);
+            bool const has_n_mcast = supports_mcast_along_n_sm100(tile_config_sm100);
+
+            if (has_m_mcast)
+            {
+                CutlassGemmConfig mcast_m_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
+                    ClusterShape::ClusterShape_2x1x1);
+                candidate_configs.push_back(mcast_m_config);
+            }
+
+            if (has_n_mcast)
+            {
+                CutlassGemmConfig mcast_n_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
+                    ClusterShape::ClusterShape_1x2x1);
+                candidate_configs.push_back(mcast_n_config);
+            }
+
+            if (has_m_mcast && has_n_mcast)
+            {
+                CutlassGemmConfig mcast_mn_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
+                    ClusterShape::ClusterShape_2x2x1);
+                candidate_configs.push_back(mcast_mn_config);
+            }
+        }
+        return candidate_configs;
+    }
+
+    // Fallback to older architecture configurations
+    std::vector<CutlassTileConfig> tiles = get_candidate_tiles(sm, config_type_param);
+    std::vector<CutlassGemmConfig> candidate_configs; //Already declared above for SM90 path, ensure scope is correct or redeclare if necessary.
+    // It's fine here as it's within an else if / else block.
     bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY;
     int const min_stages = int8_configs_only ? 3 : 2;
     int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);
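
The SM100 branch above emits one 1x1x1 candidate per tile, plus a 2x1x1, 1x2x1, and 2x2x1 candidate where the multicast helpers allow it. The sketch below replays that expansion so the size of the search space can be checked by hand. It is a standalone approximation: `heuristic_sketch`, `Tile`, `Cluster`, `mcast_m`, `mcast_n`, and `expand` are hypothetical names, and the membership sets are copied from supports_mcast_along_m_sm100 and supports_mcast_along_n_sm100 as they appear in this commit (non-FAST_BUILD path).

```cpp
#include <cassert>
#include <set>
#include <utility>
#include <vector>

namespace heuristic_sketch
{
// Hypothetical stand-ins for CutlassTileConfigSM100 (MxN only) and ClusterShape.
enum class Tile { T64x64, T64x128, T64x256, T128x64, T128x128, T128x256, T256x64, T256x128, T256x256 };
enum class Cluster { C1x1x1, C2x1x1, C1x2x1, C2x2x1 };

// Tiles listed in supports_mcast_along_m_sm100 in this commit.
inline bool mcast_m(Tile t)
{
    static std::set<Tile> const tiles{Tile::T128x64, Tile::T128x128, Tile::T128x256,
        Tile::T256x64, Tile::T256x128, Tile::T256x256};
    return tiles.count(t) == 1;
}

// Tiles listed in supports_mcast_along_n_sm100 in this commit.
inline bool mcast_n(Tile t)
{
    static std::set<Tile> const tiles{Tile::T64x128, Tile::T64x256, Tile::T128x128,
        Tile::T128x256, Tile::T256x128};
    return tiles.count(t) == 1;
}

// Mirrors the loop in the SM100 branch: one base config per tile, plus multicast variants.
inline std::vector<std::pair<Tile, Cluster>> expand(std::vector<Tile> const& tiles)
{
    std::vector<std::pair<Tile, Cluster>> out;
    for (Tile t : tiles)
    {
        out.emplace_back(t, Cluster::C1x1x1);
        bool const m = mcast_m(t);
        bool const n = mcast_n(t);
        if (m)
        {
            out.emplace_back(t, Cluster::C2x1x1);
        }
        if (n)
        {
            out.emplace_back(t, Cluster::C1x2x1);
        }
        if (m && n)
        {
            out.emplace_back(t, Cluster::C2x2x1);
        }
    }
    return out;
}
} // namespace heuristic_sketch
```

With these sets, the nine non-grouped tiles expand to 23 candidates and the four FP8/FP16 grouped tiles to 14. CtaShape256x256x128B is absent from the N-multicast set, so it contributes only the 1x1x1 and 2x1x1 shapes.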
