
Commit 7477534

mgoin and kylesayrs authored
Fix _initialize_scale_zero_point initializing on the wrong device (#295)
* Fix `_initialize_scale_zero_point` initializing on the wrong device
* update comment
* use util
* style

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 4438d08 commit 7477534

File tree

1 file changed: +3 -6 lines changed

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 3 additions & 6 deletions
@@ -31,7 +31,7 @@
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
 from compressed_tensors.utils import (
     disable_hf_hook,
-    has_offloaded_params,
+    get_execution_device,
     register_offload_parameter,
 )
 from torch.nn import Module, Parameter
@@ -148,11 +148,8 @@ def _initialize_scale_zero_point(
     if quantization_args.dynamic:
         return
 
-    # begin on the same device as other parameters or cpu if offloaded.
-    # in the offloaded case, there's no point moving tensors to the execution device
-    # if they're going to be immediately offloaded by `register_offload_parameter`
-    params_device = next(module.parameters()).device
-    device = "cpu" if has_offloaded_params(module) else params_device
+    # initialize on execution device to avoid performing quantized ops on cpu
+    device = get_execution_device(module)
 
     # infer expected scale/zero point shape
     if quantization_args.strategy == QuantizationStrategy.TOKEN:
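The change replaces an offload-aware fallback with the library's `get_execution_device` util: previously, a module with offloaded parameters got its scale/zero-point initialized on cpu, so any quantized ops touching them ran on cpu. A minimal sketch of the before/after device-selection logic; the two function names below are illustrative stand-ins for comparison, not the library's API (only `get_execution_device`, imported in the diff from `compressed_tensors.utils`, is real):

```python
# Hypothetical sketch of the device-selection change in this commit.
# The devices are represented as plain strings to keep the example
# self-contained (the real code uses torch devices).

def old_init_device(params_device: str, has_offloaded_params: bool) -> str:
    # Previous behavior: fall back to "cpu" whenever the module had
    # offloaded parameters, which meant quantized ops on the freshly
    # created scale/zero-point tensors could run on cpu.
    return "cpu" if has_offloaded_params else params_device


def new_init_device(execution_device: str) -> str:
    # Fixed behavior: always initialize on the device the module's
    # forward pass actually executes on (what get_execution_device
    # returns), even if the parameters themselves are offloaded.
    return execution_device


# An offloaded module whose forward pass executes on a GPU:
print(old_init_device("cuda:0", has_offloaded_params=True))  # cpu
print(new_init_device("cuda:0"))                             # cuda:0
```

The offloaded case is also why `register_offload_parameter` (kept in the import list) matters: it handles moving the parameter back out to offload storage after initialization, so initializing on the execution device costs nothing extra.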
