File tree Expand file tree Collapse file tree 2 files changed +31
-10
lines changed
src/compressed_tensors/utils Expand file tree Collapse file tree 2 files changed +31
-10
lines changed Original file line number Diff line number Diff line change @@ -171,22 +171,24 @@ def update_parameter_data(
171
171
172
172
def get_execution_device (module : torch .nn .Module ) -> torch .device :
173
173
"""
174
- Get the device which inputs should be moved to before module execution
174
+ Get the device which inputs should be moved to before module execution.
175
+ Assume that modules execute in the same order as returned by `model.modules()`
175
176
176
177
:param module: module to check, may be offloaded
177
178
:return: onload device of module
178
179
"""
179
- if has_offloaded_params (module ):
180
- return module ._hf_hook .execution_device
180
+ for module in module .modules ():
181
+ if has_offloaded_params (module ):
182
+ return module ._hf_hook .execution_device
181
183
182
- first_param = next (module .parameters (), None )
183
- if first_param is None :
184
- warnings .warn (
185
- f"Unable able to infer execution device of { module } , falling back to CPU"
186
- )
187
- return torch .device ("cpu" )
184
+ param = next (module .parameters (recurse = False ), None )
185
+ if param is not None :
186
+ return param .device
188
187
189
- return first_param .device
188
+ warnings .warn (
189
+ f"Unable to get execution device of { module } , falling back to CPU"
190
+ )
191
+ return torch .device ("cpu" )
190
192
191
193
192
194
def register_offload_parameter (
Original file line number Diff line number Diff line change @@ -102,6 +102,25 @@ def test_get_execution_device():
102
102
assert get_execution_device (module ) == torch .device ("cuda:0" )
103
103
104
104
105
+ @requires_gpu
106
+ @requires_accelerate ()
107
+ def test_get_execution_device_model ():
108
+ class Model (torch .nn .Module ):
109
+ def __init__ (self ):
110
+ super ().__init__ ()
111
+ self .a = torch .nn .Linear (1 , 2 )
112
+ self .b = torch .nn .Linear (2 , 2 , device = "cuda:0" )
113
+
114
+ def forward (self , x ):
115
+ return self .b (self .a (x ).to ("cuda:0" ))
116
+
117
+ model = Model ()
118
+ assert get_execution_device (model ) == torch .device ("cpu" )
119
+
120
+ offloaded_dispatch (model .a , torch .device ("cuda:0" ))
121
+ assert get_execution_device (model ) == torch .device ("cuda:0" )
122
+
123
+
105
124
@requires_accelerate ()
106
125
def test_register_offload_parameter ():
107
126
from accelerate import init_empty_weights
You can’t perform that action at this time.
0 commit comments