@@ -61,9 +61,10 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         if self.bias is not None:
-            self.bias = self.bias.to(device=device, dtype=dtype)
+            self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)


 # TODO: find and debug lora/locon with bias
@@ -109,14 +110,15 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
-        super().to(device=device, dtype=dtype)
+        super().to(device=device, dtype=dtype, non_blocking=non_blocking)

-        self.up = self.up.to(device=device, dtype=dtype)
-        self.down = self.down.to(device=device, dtype=dtype)
+        self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)

         if self.mid is not None:
-            self.mid = self.mid.to(device=device, dtype=dtype)
+            self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)


 class LoHALayer(LoRALayerBase):
@@ -169,18 +171,19 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         super().to(device=device, dtype=dtype)

-        self.w1_a = self.w1_a.to(device=device, dtype=dtype)
-        self.w1_b = self.w1_b.to(device=device, dtype=dtype)
+        self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
         if self.t1 is not None:
-            self.t1 = self.t1.to(device=device, dtype=dtype)
+            self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)

-        self.w2_a = self.w2_a.to(device=device, dtype=dtype)
-        self.w2_b = self.w2_b.to(device=device, dtype=dtype)
+        self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
         if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype)
+            self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)


 class LoKRLayer(LoRALayerBase):
@@ -265,6 +268,7 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         super().to(device=device, dtype=dtype)

@@ -273,19 +277,19 @@ def to(
         else:
             assert self.w1_a is not None
             assert self.w1_b is not None
-            self.w1_a = self.w1_a.to(device=device, dtype=dtype)
-            self.w1_b = self.w1_b.to(device=device, dtype=dtype)
+            self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)

         if self.w2 is not None:
-            self.w2 = self.w2.to(device=device, dtype=dtype)
+            self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
         else:
             assert self.w2_a is not None
             assert self.w2_b is not None
-            self.w2_a = self.w2_a.to(device=device, dtype=dtype)
-            self.w2_b = self.w2_b.to(device=device, dtype=dtype)
+            self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)

         if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype)
+            self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)


 class FullLayer(LoRALayerBase):
@@ -319,10 +323,11 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         super().to(device=device, dtype=dtype)

-        self.weight = self.weight.to(device=device, dtype=dtype)
+        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)


 class IA3Layer(LoRALayerBase):
@@ -358,11 +363,12 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ):
         super().to(device=device, dtype=dtype)

-        self.weight = self.weight.to(device=device, dtype=dtype)
-        self.on_input = self.on_input.to(device=device, dtype=dtype)
+        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)


 AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
@@ -388,10 +394,11 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         # TODO: try revert if exception?
         for _key, layer in self.layers.items():
-            layer.to(device=device, dtype=dtype)
+            layer.to(device=device, dtype=dtype, non_blocking=non_blocking)

     def calc_size(self) -> int:
         model_size = 0
@@ -514,7 +521,7 @@ def from_checkpoint(
                 # lower memory consumption by removing already parsed layer values
                 state_dict[layer_key].clear()

-                layer.to(device=device, dtype=dtype)
+                layer.to(device=device, dtype=dtype, non_blocking=True)
                 model.layers[layer_key] = layer

         return model
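Note: the `non_blocking` flag threaded through these `to()` methods is passed straight to `torch.Tensor.to(..., non_blocking=...)`; the checkpoint-loading hunk hard-codes `non_blocking=True`, while the public methods default to `False`. As a minimal sketch (not part of this diff), the flag only produces an asynchronous host-to-device copy when the source tensor lives in pinned (page-locked) CPU memory and the target is a CUDA device; otherwise PyTorch falls back to an ordinary synchronous copy.

```python
import torch

# Sketch only: non_blocking=True overlaps the copy with other work
# only for a pinned CPU source tensor and a CUDA destination.
weight = torch.randn(320, 320).pin_memory()  # pinned (page-locked) host memory
if torch.cuda.is_available():
    weight = weight.to(device="cuda", dtype=torch.float16, non_blocking=True)
    torch.cuda.synchronize()  # ensure the async copy has finished before use
```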