
Commit 890608d

Accelerate Utilities: Throw warning when updating with different shapes (#231)
* replace not copy
* Add warning when update shapes are different

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Parent: 038f960

File tree

2 files changed: +25 −13 lines

src/compressed_tensors/utils/offload.py

Lines changed: 7 additions & 1 deletion
@@ -26,6 +26,7 @@
 """

 import contextlib
+import warnings
 from functools import wraps
 from typing import Any, Callable, Dict, Literal, Optional, Union

@@ -200,9 +201,14 @@ def update_offload_parameter(
     """
     param = getattr(module, name)
     data = data.to(param.dtype)
+    if param.data.shape != data.shape:
+        warnings.warn(
+            f"Shape of parameter being updated {param.data.shape} does not match shape "
+            f"of update data {data.shape}"
+        )

     # copy data into onloaded parameter if applicable
-    if param.device != "meta":
+    if param.device != torch.device("meta"):
         param.data.copy_(data)

     # update offload dict
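For context, a minimal sketch of what callers now see (the scalar-parameter module and the catch_warnings harness are illustrative assumptions, not part of the commit; update_offload_parameter is assumed importable from compressed_tensors.utils):

import warnings

import torch
from compressed_tensors.utils import update_offload_parameter

# hypothetical module holding a single scalar parameter named "a"
module = torch.nn.Module()
module.a = torch.nn.Parameter(torch.tensor(1.0))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # update data has shape (1,) while the parameter is a scalar (shape ()),
    # so the new shape check emits a warning before the copy
    update_offload_parameter(module, "a", torch.tensor([3.0]))

assert any("does not match" in str(w.message) for w in caught)

The check warns rather than raises: per the new test below, updating with a differently shaped tensor is still a legal operation, and the warning only flags a likely mistake.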

tests/test_utils/test_offload.py

Lines changed: 18 additions & 12 deletions
@@ -100,31 +100,37 @@ def test_update_offload_parameter():
     from accelerate.hooks import attach_align_device_hook

     module = ExampleModule()
-    param_a = torch.nn.Parameter(torch.tensor(1.0))
-    param_b = torch.nn.Parameter(torch.tensor(2.0))
+    tensor_a = torch.tensor(1.0)
+    tensor_b = torch.tensor(2.0)

     # can update modules which are not offloaded
-    update_offload_parameter(module, "a", param_a)
-    assert module.a == param_a
+    update_offload_parameter(module, "a", tensor_a)
+    assert module.a == tensor_a

     # can update modules which are offloaded
     attach_align_device_hook(module, offload=True, weights_map=module.state_dict())
-    update_offload_parameter(module, "b", param_b)
+    update_offload_parameter(module, "b", tensor_b)
     assert module.b.device == torch.device("meta")
-    assert module._hf_hook.weights_map["b"] == param_b.data
+    assert module._hf_hook.weights_map["b"] == tensor_b

     # data persists across onloading
     with align_module_device(module, execution_device="cpu"):
-        assert module.a == param_a
-        assert module.b == param_b
-        assert module._hf_hook.weights_map["a"] == param_a.data
-        assert module._hf_hook.weights_map["b"] == param_b.data
+        assert module.a.data == tensor_a
+        assert module.b.data == tensor_b
+        assert module._hf_hook.weights_map["a"] == tensor_a
+        assert module._hf_hook.weights_map["b"] == tensor_b

     # data persists across offloading
     assert module.a.device == torch.device("meta")
     assert module.b.device == torch.device("meta")
-    assert module._hf_hook.weights_map["a"] == param_a.data
-    assert module._hf_hook.weights_map["b"] == param_b.data
+    assert module._hf_hook.weights_map["a"] == tensor_a
+    assert module._hf_hook.weights_map["b"] == tensor_b
+
+    # can update with different shape with warning
+    with pytest.warns():
+        new_data = torch.tensor([3.0])
+        update_offload_parameter(module, "a", new_data)
+        assert module._hf_hook.weights_map["a"] == new_data


 @requires_accelerate()
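For reference, a minimal sketch of the offload round-trip this test exercises (the Linear module is an illustrative stand-in for the test's ExampleModule; the attach_align_device_hook call and the _hf_hook.weights_map access mirror the test code above):

import torch
from accelerate.hooks import attach_align_device_hook

# hypothetical stand-in for the test's ExampleModule
module = torch.nn.Linear(2, 2, bias=False)
attach_align_device_hook(module, offload=True, weights_map=module.state_dict())

# the parameter tensor itself now lives on the meta device ...
assert module.weight.device == torch.device("meta")
# ... while its real values are served from the hook's weights map
offloaded = module._hf_hook.weights_map["weight"]

This is why update_offload_parameter writes through to the weights_map, and why the test asserts against module._hf_hook.weights_map rather than the (meta) parameter whenever the module is offloaded.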
