
Commit d370196

delete delayed scaling from torchao.float8 (#1753)
Update [ghstack-poisoned]
1 parent 25ddb77 · commit d370196

25 files changed: +93 / -2296 lines
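As context for the diffs below, here is a minimal sketch (not part of the commit) of how a model is converted for float8 training with the dynamic scaling that remains after this change. It only uses names that appear in the diffs (convert_to_float8_training, Float8LinearConfig, CastConfig, ScalingType); the toy model and shapes are illustrative:

import torch
import torch.nn as nn
from torchao.float8 import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    convert_to_float8_training,
)

# toy model; any module containing nn.Linear layers is handled the same way
m = nn.Sequential(nn.Linear(2048, 4096, bias=False)).cuda().bfloat16()

# dynamic scaling is the scaling type that remains after this change
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DYNAMIC),
)
m = convert_to_float8_training(m, config=config)
m = torch.compile(m)

x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
m(x).sum().backward()  # no sync_float8_amax_and_scale_history() step needed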

benchmarks/float8/bench_linear_float8.py (13 additions, 40 deletions)

@@ -23,10 +23,6 @@
     ScalingType,
 )
 from torchao.float8.float8_linear import Float8Linear
-from torchao.float8.float8_linear_utils import (
-    linear_requires_sync,
-    sync_float8_amax_and_scale_history,
-)
 from torchao.float8.float8_tensor import ScaledMMConfig
 
 # estimating TOPs for matmuls in fp32, fp16, fp8
@@ -122,39 +118,18 @@ def main(
     scaling_type_grad_output = ScalingType(scaling_type_grad_output)
     scaling_granularity = ScalingGranularity(scaling_granularity)
 
-    if scaling_type_input is ScalingType.STATIC:
-        cast_config_input = CastConfig(
-            scaling_type=scaling_type_input,
-            static_scale=torch.tensor([1.0], device="cuda"),
-            scaling_granularity=scaling_granularity,
-        )
-    else:
-        cast_config_input = CastConfig(
-            scaling_type=scaling_type_input,
-            scaling_granularity=scaling_granularity,
-        )
-    if scaling_type_weight is ScalingType.STATIC:
-        cast_config_weight = CastConfig(
-            scaling_type=scaling_type_weight,
-            static_scale=torch.tensor([1.0], device="cuda"),
-            scaling_granularity=scaling_granularity,
-        )
-    else:
-        cast_config_weight = CastConfig(
-            scaling_type=scaling_type_weight,
-            scaling_granularity=scaling_granularity,
-        )
-    if scaling_type_grad_output is ScalingType.STATIC:
-        cast_config_grad_output = CastConfig(
-            scaling_type=scaling_type_grad_output,
-            static_scale=torch.tensor([1.0], device="cuda"),
-            scaling_granularity=scaling_granularity,
-        )
-    else:
-        cast_config_grad_output = CastConfig(
-            scaling_type=scaling_type_grad_output,
-            scaling_granularity=scaling_granularity,
-        )
+    cast_config_input = CastConfig(
+        scaling_type=scaling_type_input,
+        scaling_granularity=scaling_granularity,
+    )
+    cast_config_weight = CastConfig(
+        scaling_type=scaling_type_weight,
+        scaling_granularity=scaling_granularity,
+    )
+    cast_config_grad_output = CastConfig(
+        scaling_type=scaling_type_grad_output,
+        scaling_granularity=scaling_granularity,
+    )
 
     config = Float8LinearConfig(
         cast_config_input=cast_config_input,
@@ -185,7 +160,7 @@ def main(
         copy.deepcopy(linear_ref),
         config=config,
     )
-    scaling_repr = f"{linear_float8.scaling_type_repr()},{linear_float8.scaling_granularity_repr()}"
+    scaling_repr = linear_float8.extra_repr()
 
     if fast_accum:
         linear_float8.forward_config = ScaledMMConfig(False, True, False)
@@ -196,8 +171,6 @@ def main(
     ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()
 
     def float8_forw_backward():
-        if linear_requires_sync(config):
-            sync_float8_amax_and_scale_history(linear_float8)
         linear_float8(input_tensor).sum().backward()
 
     def n_times(n, fn, *args, **kwargs):
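A hedged sketch of what this benchmark's setup and measured step reduce to after the diff above. Float8Linear.from_float as the conversion entry point and the tensor shapes are assumptions; extra_repr() and the sync-free closure come straight from the change:

import copy
import torch
import torch.nn as nn
from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.float8_linear import Float8Linear

linear_ref = nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DYNAMIC),
)
# assumed conversion entry point; the diff only shows the call arguments
linear_float8 = Float8Linear.from_float(copy.deepcopy(linear_ref), config=config)
scaling_repr = linear_float8.extra_repr()  # replaces scaling_type_repr()/scaling_granularity_repr()

input_tensor = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

def float8_forw_backward():
    # no linear_requires_sync() / sync_float8_amax_and_scale_history() call before the step
    linear_float8(input_tensor).sum().backward()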

benchmarks/float8/bench_multi_gpu.py (0 additions, 180 deletions)

This file was deleted.

benchmarks/float8/float8_roofline.py (0 additions, 47 deletions)

@@ -58,9 +58,7 @@
 )
 
 from torchao.float8 import (
-    CastConfig,
     Float8LinearConfig,
-    ScalingType,
     convert_to_float8_training,
 )
 from torchao.float8.roofline_utils import (
@@ -219,24 +217,6 @@ def run(
         scaling_type_weight="dynamic",
         scaling_type_grad_output="dynamic",
     )
-    fp8_mem_time_sympy_del_limit = get_float8_mem_sympy(
-        M,
-        K,
-        N,
-        model_torch_compile_limitations=True,
-        scaling_type_input="delayed",
-        scaling_type_weight="delayed",
-        scaling_type_grad_output="delayed",
-    )
-    fp8_mem_time_sympy_del_nolimit = get_float8_mem_sympy(
-        M,
-        K,
-        N,
-        model_torch_compile_limitations=False,
-        scaling_type_input="delayed",
-        scaling_type_weight="delayed",
-        scaling_type_grad_output="delayed",
-    )
 
     if gemm_time_strategy == "roofline":
         bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
@@ -258,16 +238,12 @@ def run(
         # roofline memory overhead estimates
         "fp8_oh_dyn_limit",
         "fp8_oh_dyn_nolimit",
-        "fp8_oh_del_limit",
-        "fp8_oh_del_nolimit",
         # actual e2e measurements
         "bf16_s",
         "fp8_dyn_s",
-        "fp8_del_s",
         "fp8_dyn_axs_s",
         # 'fp8_lw_s',
         "fp8_dyn_sp",
-        "fp8_del_sp",
         "fp8_dyn_axs_sp",
         # 'fp8_lw_sp',
     ]
@@ -309,12 +285,6 @@ def run(
     fp8_mem_time_dyn_nolimit_s = (
         fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
     )
-    fp8_mem_time_del_limit_s = (
-        fp8_mem_time_sympy_del_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-    )
-    fp8_mem_time_del_nolimit_s = (
-        fp8_mem_time_sympy_del_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-    )
 
     # create the model
     m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
@@ -333,19 +303,6 @@ def run(
     m_fp8_dyn = torch.compile(m_fp8_dyn)
     fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x)
 
-    # get the float8 delayed scaling gpu kernel time
-    torch._dynamo.reset()
-    config = Float8LinearConfig(
-        enable_amax_init=False,
-        enable_pre_and_post_forward=False,
-        cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
-        cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
-        cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
-    )
-    m_fp8_del = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
-    m_fp8_del = torch.compile(m_fp8_del)
-    fp8_del_time_actual_s = get_gpu_kernel_time(m_fp8_del, x)
-
     # get the float8 dynamic axiswise scaling gpu kernel time
     torch._dynamo.reset()
     config = Float8LinearConfig.from_recipe_name("rowwise")
@@ -374,16 +331,12 @@ def run(
         # roofline overhead estimates
         fp8_mem_time_dyn_limit_s,
         fp8_mem_time_dyn_nolimit_s,
-        fp8_mem_time_del_limit_s,
-        fp8_mem_time_del_nolimit_s,
         # e2e numbers
         bf16_time_actual_s,
         fp8_dyn_time_actual_s,
-        fp8_del_time_actual_s,
         fp8_dyn_axs_time_actual_s,
         # fp8_lw_time_actual_s,
         bf16_time_actual_s / fp8_dyn_time_actual_s,
-        bf16_time_actual_s / fp8_del_time_actual_s,
         bf16_time_actual_s / fp8_dyn_axs_time_actual_s,
         # bf16_time_actual_s / fp8_lw_time_actual_s,
     ]
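The roofline script keeps the dynamic axiswise (rowwise) measurement. A small sketch of that remaining path, reusing the script-local names visible in the diff context (m_orig, x, get_gpu_kernel_time), which are assumed to be defined as in the file:

import copy
import torch
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

torch._dynamo.reset()
config = Float8LinearConfig.from_recipe_name("rowwise")  # axiswise/rowwise scaling recipe
# m_orig, x, and get_gpu_kernel_time are defined earlier in the script
m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)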
