@@ -100,31 +100,37 @@ def test_update_offload_parameter():
100
100
from accelerate .hooks import attach_align_device_hook
101
101
102
102
module = ExampleModule ()
103
- param_a = torch .nn . Parameter ( torch . tensor (1.0 ) )
104
- param_b = torch .nn . Parameter ( torch . tensor (2.0 ) )
103
+ tensor_a = torch .tensor (1.0 )
104
+ tensor_b = torch .tensor (2.0 )
105
105
106
106
# can update modules which are not offloaded
107
- update_offload_parameter (module , "a" , param_a )
108
- assert module .a == param_a
107
+ update_offload_parameter (module , "a" , tensor_a )
108
+ assert module .a == tensor_a
109
109
110
110
# can update modules which are offloaded
111
111
attach_align_device_hook (module , offload = True , weights_map = module .state_dict ())
112
- update_offload_parameter (module , "b" , param_b )
112
+ update_offload_parameter (module , "b" , tensor_b )
113
113
assert module .b .device == torch .device ("meta" )
114
- assert module ._hf_hook .weights_map ["b" ] == param_b . data
114
+ assert module ._hf_hook .weights_map ["b" ] == tensor_b
115
115
116
116
# data persists across onloading
117
117
with align_module_device (module , execution_device = "cpu" ):
118
- assert module .a == param_a
119
- assert module .b == param_b
120
- assert module ._hf_hook .weights_map ["a" ] == param_a . data
121
- assert module ._hf_hook .weights_map ["b" ] == param_b . data
118
+ assert module .a . data == tensor_a
119
+ assert module .b . data == tensor_b
120
+ assert module ._hf_hook .weights_map ["a" ] == tensor_a
121
+ assert module ._hf_hook .weights_map ["b" ] == tensor_b
122
122
123
123
# data persists across offloading
124
124
assert module .a .device == torch .device ("meta" )
125
125
assert module .b .device == torch .device ("meta" )
126
- assert module ._hf_hook .weights_map ["a" ] == param_a .data
127
- assert module ._hf_hook .weights_map ["b" ] == param_b .data
126
+ assert module ._hf_hook .weights_map ["a" ] == tensor_a
127
+ assert module ._hf_hook .weights_map ["b" ] == tensor_b
128
+
129
+ # can update with differnt shape with warning
130
+ with pytest .warns ():
131
+ new_data = torch .tensor ([3.0 ])
132
+ update_offload_parameter (module , "a" , new_data )
133
+ assert module ._hf_hook .weights_map ["a" ] == new_data
128
134
129
135
130
136
@requires_accelerate ()
0 commit comments