BYTES_PER_EL_FLOAT8 = 1
BYTES_PER_EL_BF16 = 2

-# https://www.nvidia.com/en-us/data-center/h100/, divide by 2 because no sparsity
-H100_BF16_PEAK_TOPS = 989e12
-H100_FP8_PEAK_TOPS = 1979e12
+gpu_name_to_specs = {
+    "NVIDIA H100": {
+        # https://www.nvidia.com/en-us/data-center/h100/, divide by 2 because no sparsity
+        "bf16_peak_tops": 989e12,
+        "fp8_peak_tops": 1979e12,
+        # 2.4 TB per second, custom to Meta's H100 variant
+        "peak_mem_bw_bytes_sec": 2.4e12,
+        # based on quick experimental observation with sample large inputs
+        "pct_achievable_gemm_tops": 0.6,
+        # based on previous experience looking at pointwise triton kernels with large inputs,
+        # which would hit about 2.2k GBPS on Meta's H100 variant
+        "pct_achievable_mem_bw": 0.92,
+    },
+    "NVIDIA B200": {
+        # https://resources.nvidia.com/en-us-blackwell-architecture, page 19,
+        # divide by 2 because no sparsity
+        "bf16_peak_tops": 2.25e15,
+        "fp8_peak_tops": 4.5e15,
+        "fp4_peak_tops": 9.0e15,
+        # https://resources.nvidia.com/en-us-blackwell-architecture, page 20
+        # 8.0 TB per second
+        "peak_mem_bw_bytes_sec": 8.0e12,
+        # for now, copy over from H100
+        # TODO(future): measure once we have the hardware
+        "pct_achievable_gemm_tops": 0.6,
+        # for now, copy over from H100
+        # TODO(future): measure once we have the hardware
+        "pct_achievable_mem_bw": 0.92,
+    },
+    # TODO(future): more GPU names
+}
+
+
+def get_specs():
+    gpu_name = torch.cuda.get_device_name(0)
+    return gpu_name_to_specs[gpu_name]

-# 2.4 TB per second, custom to Meta's H100 variant
-H100_PEAK_MEM_BW_BYTES_SEC = 2.4e12
-
-# based on quick experimental observation with sample large inputs
-H100_PCT_ACHIEVABLE_GEMM_TOPS = 0.6
-
-# based on previous experience looking at pointwise triton kernels with large inputs,
-# which would hit about 2.2k GBPS on Meta's H100 variant
-H100_PCT_ACHIEVABLE_MEM_BW = 0.92

# Source: run a triton kernel with a single element read/write on an H100 and
# measure GPU time from the trace
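Usage note (not part of the diff): get_specs() keys the table on the exact string returned by torch.cuda.get_device_name(0), so running on a GPU that is not listed raises a KeyError. A minimal guarded variant could look like the sketch below; the helper name is hypothetical.

```python
import torch

def try_get_specs():
    # hypothetical guarded lookup: returns None instead of raising KeyError
    # when the current device name is not a key of gpu_name_to_specs
    gpu_name = torch.cuda.get_device_name(0)
    return gpu_name_to_specs.get(gpu_name)
```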
@@ -65,12 +89,13 @@ def get_tensor_memory_traffic_bytes(


def get_gemm_time_sympy(M, K, N, dtype):
+    specs = get_specs()
    gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N
    if dtype is torch.bfloat16:
-        peak_tops = H100_BF16_PEAK_TOPS
+        peak_tops = specs["bf16_peak_tops"]
    elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        peak_tops = H100_FP8_PEAK_TOPS
-    gemm_time_s = gemm_ops / peak_tops / H100_PCT_ACHIEVABLE_GEMM_TOPS
+        peak_tops = specs["fp8_peak_tops"]
+    gemm_time_s = gemm_ops / peak_tops / specs["pct_achievable_gemm_tops"]
    return gemm_time_s


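As a sanity check of the formula above (not part of the diff), plugging the "NVIDIA H100" entry and an arbitrary 4096^3 bf16 GEMM into the same roofline arithmetic:

```python
# worked example of get_gemm_time_sympy's arithmetic with concrete numbers;
# the 4096^3 problem size is an arbitrary choice
M = K = N = 4096
gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N  # ~4.1e11 ops (fwd + 2 bwd gemms)
peak_tops = 989e12        # bf16_peak_tops for "NVIDIA H100"
pct_achievable = 0.6      # pct_achievable_gemm_tops
gemm_time_s = gemm_ops / peak_tops / pct_achievable       # ~6.9e-4 s, i.e. ~0.7 ms
```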
@@ -87,6 +112,8 @@ def get_float8_mem_sympy(
    assert scaling_type_weight in ("dynamic",), "unsupported"
    assert scaling_type_grad_output in ("dynamic",), "unsupported"

+    specs = get_specs()
+
    # there are three gemms in the fwd/bwd of a linear:
    #
    # input @ weight_t = output
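For context (standard linear-layer forward/backward math, not quoted from this file), the three GEMMs that the comment above starts to enumerate are:

```python
# output      = input         @ weight_t   # forward
# grad_input  = grad_output   @ weight     # backward w.r.t. input
# grad_weight = grad_output_t @ input      # backward w.r.t. weight
```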
@@ -148,7 +175,7 @@ def get_float8_mem_sympy(
    )
    fp8_total_mem = fwd_fp8_total_mem + bwd_fp8_total_mem
    fp8_mem_time_s = (
-        fp8_total_mem / H100_PEAK_MEM_BW_BYTES_SEC / H100_PCT_ACHIEVABLE_MEM_BW
+        fp8_total_mem / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"]
    )

    # Adjust final estimate for small kernel launches
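The bandwidth term follows the same peak-times-achievable-fraction pattern as the GEMM term. As a quick check (not part of the diff), using the "NVIDIA H100" entry with an arbitrary 1 GB of float8-related traffic:

```python
# worked example of the memory-bandwidth roofline term with concrete numbers;
# the 1e9 bytes of traffic is an arbitrary example value
fp8_total_mem = 1e9                # bytes moved by the float8 overhead kernels
peak_mem_bw_bytes_sec = 2.4e12     # peak_mem_bw_bytes_sec for "NVIDIA H100"
pct_achievable_mem_bw = 0.92       # pct_achievable_mem_bw
fp8_mem_time_s = fp8_total_mem / peak_mem_bw_bytes_sec / pct_achievable_mem_bw  # ~4.5e-4 s
```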