Commit cd69415

roofline estimation: delete axiswise scaling, for now (#1782)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent e6706ca · commit cd69415

File tree: 1 file changed (+7, −52)
benchmarks/float8/float8_roofline.py

Lines changed: 7 additions & 52 deletions
@@ -58,14 +58,12 @@
 )
 
 from torchao.float8 import (
-    Float8LinearConfig,
     convert_to_float8_training,
 )
 from torchao.testing.float8.roofline_utils import (
     get_float8_mem_sympy,
     get_gemm_time_sympy,
 )
-from torchao.utils import is_sm_at_least_90, is_sm_at_least_100
 
 
 class LNLinearSigmoid(torch.nn.Module):
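
With the rowwise recipe path removed, Float8LinearConfig and the SM-version helpers are no longer needed in this file; the benchmark keeps only the default dynamic-scaling conversion. Below is a minimal standalone sketch of that remaining path, not the file's own code: the toy model and shape are placeholders (the benchmark uses LNLinearSigmoid), and a CUDA device with float8 support is assumed.

import copy

import torch
from torchao.float8 import convert_to_float8_training

# Placeholder model standing in for the benchmark's LNLinearSigmoid;
# requires a CUDA GPU with float8 support.
m_orig = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).cuda().bfloat16()

# Default config: no Float8LinearConfig recipe name is needed anymore,
# since the rowwise ("axiswise") variant is no longer benchmarked here.
m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
m_fp8_dyn = torch.compile(m_fp8_dyn)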
@@ -155,21 +153,13 @@ def do_matmul(A, B):
 
     f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
 
-    if is_sm_at_least_90() and (not is_sm_at_least_100()):
-        scale_a = torch.ones(M, 1, device=device)
-        scale_b = torch.ones(1, N, device=device)
-        fast_accum = True  # for axiswise
-        f8_axs_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
-    else:
-        f8_axs_time_s = -1.0
-
     # save to cache if needed
     if cache_filename is not None:
-        cache[key] = [bf16_time_s, f8_time_s, f8_axs_time_s]
+        cache[key] = [bf16_time_s, f8_time_s]
         with open(cache_filename, "w") as f:
             json.dump(cache, f)
 
-    return bf16_time_s, f8_time_s, f8_axs_time_s
+    return bf16_time_s, f8_time_s
 
 
 def run(
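
After this hunk, each gemm-cache entry is a two-element list (bf16 time, tensorwise float8 time) instead of three, and the JSON round-trip is otherwise unchanged. A small sketch of that format; the key string and timing values below are placeholders, since the real key is built upstream of this hunk.

import json

# Placeholder key and timings; the real key encodes the gemm shape/layout
# upstream of this hunk.
cache = {"example_key": [1.23e-3, 5.7e-4]}  # [bf16_time_s, f8_time_s]

with open("gemm_cache.json", "w") as f:
    json.dump(cache, f)

with open("gemm_cache.json") as f:
    bf16_time_s, f8_time_s = json.load(f)["example_key"]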
@@ -229,18 +219,13 @@ def run(
         # gemm microbenchmarks
         "bf16_gemm_s",
         "fp8_gemm_s",
-        "fp8_axs_gemm_time_s",
         # roofline memory overhead estimates
-        "fp8_oh_dyn_limit",
-        "fp8_oh_dyn_nolimit",
+        "fp8_oh_estimated",
+        "fp8_oh_ideal",
         # actual e2e measurements
         "bf16_s",
         "fp8_dyn_s",
-        "fp8_dyn_axs_s",
-        # 'fp8_lw_s',
         "fp8_dyn_sp",
-        "fp8_dyn_axs_sp",
-        # 'fp8_lw_sp',
     ]
     results = []
 
@@ -251,18 +236,17 @@ def run(
             break
 
         if gemm_time_strategy == "benchmarks":
-            bf16_g1, f8_g1, f8_g1_axs = get_gemm_times(
+            bf16_g1, f8_g1 = get_gemm_times(
                 M_val, K_val, N_val, True, gemm_cache_filename
             )
-            bf16_g2, f8_g2, f8_g2_axs = get_gemm_times(
+            bf16_g2, f8_g2 = get_gemm_times(
                 M_val, N_val, K_val, False, gemm_cache_filename
            )
-            bf16_g3, f8_g3, f8_g3_axs = get_gemm_times(
+            bf16_g3, f8_g3 = get_gemm_times(
                 K_val, M_val, N_val, False, gemm_cache_filename
             )
             bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
             fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
-            fp8_axs_gemm_time_s = f8_g1_axs + f8_g2_axs + f8_g3_axs
         else:
             assert gemm_time_strategy == "roofline", "unsupported"
             bf16_time_val = (
@@ -271,8 +255,6 @@ def run(
             fp8_gemm_time_s = (
                 fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
             )
-            # for now, assume axiswise gemm is similar to tensorwise
-            fp8_axs_gemm_time_s = fp8_gemm_time_s
 
         fp8_mem_time_dyn_limit_s = (
             fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
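
In roofline mode the gemm time stays a sympy expression in M, K, N that is resolved per shape with chained .subs calls, as above. A standalone sketch of that pattern; the formula below is a made-up placeholder, not the real model, which is built by get_gemm_time_sympy and get_float8_mem_sympy.

import sympy

M, K, N = sympy.symbols("M K N")

# Placeholder roofline model: 2*M*K*N flops at an assumed 1e15 flops/sec peak.
fp8_gemm_time_sympy = 2 * M * K * N / 1e15

fp8_gemm_time_s = (
    fp8_gemm_time_sympy.subs(M, 4096).subs(K, 4096).subs(N, 4096)
)
print(float(fp8_gemm_time_s))  # estimated seconds for one 4096^3 gemm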
@@ -299,28 +281,6 @@ def run(
         m_fp8_dyn = torch.compile(m_fp8_dyn)
         fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x)
 
-        # get the float8 dynamic axiswise scaling gpu kernel time, if supported
-        # on current hardware
-        if is_sm_at_least_90() and (not is_sm_at_least_100()):
-            torch._dynamo.reset()
-            config = Float8LinearConfig.from_recipe_name("rowwise")
-            m_fp8_dyn_axs = convert_to_float8_training(
-                copy.deepcopy(m_orig), config=config
-            )
-            m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
-            fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
-        else:
-            fp8_dyn_axs_time_actual_s = -1.0
-
-        # get the lw recipe scaling gpu kernel time
-        # TODO(future PR): enable below once basic performance issues
-        # are fixed
-        # torch._dynamo.reset()
-        # config = Float8LinearConfig.from_recipe_name("rowwise_with_gw_hp")
-        # m_fp8_lw = convert_to_float8_training(m_orig, config=config)
-        # m_fp8_lw = torch.compile(m_fp8_lw)
-        # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
-
         results.append(
             [
                 M_val,
@@ -329,18 +289,13 @@ def run(
                 # gemm microbenchmarks
                 bf16_time_val,
                 fp8_gemm_time_s,
-                fp8_axs_gemm_time_s,
                 # roofline overhead estimates
                 fp8_mem_time_dyn_limit_s,
                 fp8_mem_time_dyn_nolimit_s,
                 # e2e numbers
                 bf16_time_actual_s,
                 fp8_dyn_time_actual_s,
-                fp8_dyn_axs_time_actual_s,
-                # fp8_lw_time_actual_s,
                 bf16_time_actual_s / fp8_dyn_time_actual_s,
-                bf16_time_actual_s / fp8_dyn_axs_time_actual_s,
-                # bf16_time_actual_s / fp8_lw_time_actual_s,
             ]
         )
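
For reference, the trimmed row lines up one-to-one with the trimmed header list earlier in this diff (only the columns visible here; the row also carries M_val, K_val, N_val and other fields from context not shown). A sketch of that pairing with placeholder values; the column names are the real ones, the numbers are made up.

# Placeholder values, in the same order as the new row above.
row = {
    "bf16_gemm_s": 3.1e-3,          # bf16_time_val
    "fp8_gemm_s": 1.6e-3,           # fp8_gemm_time_s
    "fp8_oh_estimated": 0.4e-3,     # fp8_mem_time_dyn_limit_s
    "fp8_oh_ideal": 0.2e-3,         # fp8_mem_time_dyn_nolimit_s
    "bf16_s": 5.0e-3,               # bf16_time_actual_s
    "fp8_dyn_s": 3.5e-3,            # fp8_dyn_time_actual_s
    "fp8_dyn_sp": 5.0e-3 / 3.5e-3,  # bf16_time_actual_s / fp8_dyn_time_actual_s
}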
