 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext
 from typing import Dict, List, Optional, Set, Tuple
 
 import torch
@@ -56,23 +56,58 @@ def __init__(
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
-        cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
+        low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
         self.onload_device = onload_device
         self.offload_leader = offload_leader
         self.onload_leader = onload_leader
-        self.parameters = parameters
-        self.buffers = buffers
+        self.parameters = parameters or []
+        self.buffers = buffers or []
         self.non_blocking = non_blocking or stream is not None
         self.stream = stream
-        self.cpu_param_dict = cpu_param_dict
         self.onload_self = onload_self
+        self.low_cpu_mem_usage = low_cpu_mem_usage
 
-        if self.stream is not None and self.cpu_param_dict is None:
-            raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
+        self.cpu_param_dict = self._init_cpu_param_dict()
+
+    def _init_cpu_param_dict(self):
+        cpu_param_dict = {}
+        if self.stream is None:
+            return cpu_param_dict
+
+        for module in self.modules:
+            for param in module.parameters():
+                cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+            for buffer in module.buffers():
+                cpu_param_dict[buffer] = (
+                    buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+                )
+
+        for param in self.parameters:
+            cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+
+        for buffer in self.buffers:
+            cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+
+        return cpu_param_dict
+
+    @contextmanager
+    def _pinned_memory_tensors(self):
+        pinned_dict = {}
+        try:
+            for param, tensor in self.cpu_param_dict.items():
+                if not tensor.is_pinned():
+                    pinned_dict[param] = tensor.pin_memory()
+                else:
+                    pinned_dict[param] = tensor
+
+            yield pinned_dict
+
+        finally:
+            pinned_dict = None
 
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
@@ -82,15 +117,30 @@ def onload_(self):
             self.stream.synchronize()
 
         with context:
-            for group_module in self.modules:
-                for param in group_module.parameters():
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-                for buffer in group_module.buffers():
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    for group_module in self.modules:
+                        for param in group_module.parameters():
+                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+                        for buffer in group_module.buffers():
+                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+
+                    for param in self.parameters:
+                        param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+
+                    for buffer in self.buffers:
+                        buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+
+            else:
+                for group_module in self.modules:
+                    for param in group_module.parameters():
+                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    for buffer in group_module.buffers():
+                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+
                 for param in self.parameters:
                     param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
+
                 for buffer in self.buffers:
                     buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
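For context on why the streamed branch reads from pinned staging tensors: a `non_blocking=True` host-to-device copy only overlaps with compute when the source is page-locked and the copy is issued on a side stream. A self-contained illustration of that pattern (not this PR's code):

```python
import torch

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    staging = torch.randn(16, 1024, 1024).pin_memory()  # page-locked source

    with torch.cuda.stream(side_stream):
        # Returns immediately; the copy proceeds in the background.
        on_gpu = staging.to("cuda", non_blocking=True)

    side_stream.synchronize()  # wait for the in-flight copy, as onload_ does
    assert on_gpu.is_cuda
```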
@@ -101,21 +151,18 @@ def offload_(self):
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = self.cpu_param_dict[param]
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = self.cpu_param_dict[buffer]
+            for param in self.parameters:
+                param.data = self.cpu_param_dict[param]
+            for buffer in self.buffers:
+                buffer.data = self.cpu_param_dict[buffer]
+
         else:
             for group_module in self.modules:
                 group_module.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+            for param in self.parameters:
+                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+            for buffer in self.buffers:
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
 
 
 class GroupOffloadingHook(ModelHook):
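A detail worth calling out in the streamed `offload_` path: since `cpu_param_dict` keeps a persistent CPU copy of every tensor, offloading is a pointer rebind rather than a device-to-host copy. A small sketch of that idea (assumed semantics, illustrative names):

```python
import torch

if torch.cuda.is_available():
    param = torch.nn.Parameter(torch.randn(4, 4))
    cpu_master = param.data                 # persistent CPU copy, like cpu_param_dict
    param.data = param.data.to("cuda")      # onload: allocate and copy to GPU
    param.data = cpu_master                 # offload: O(1) rebind, no D2H traffic
    assert param.data.device.type == "cpu"  # the GPU tensor is now unreferenced
```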
@@ -284,6 +331,7 @@ def apply_group_offloading(
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -365,10 +413,12 @@ def apply_group_offloading(
         raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
 
         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
+            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
         )
     elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
+        _apply_group_offloading_leaf_level(
+            module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
+        )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
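End to end, the flag threads from this public entry point down into every `ModuleGroup`. A hedged call-site sketch, assuming `apply_group_offloading` is importable from `diffusers.hooks` as in recent releases, with a toy module standing in for a real transformer:

```python
import torch
from diffusers.hooks import apply_group_offloading

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))

apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",   # per-leaf groups; no num_blocks_per_group needed
    use_stream=True,             # prefetch on a side CUDA stream
    low_cpu_mem_usage=True,      # keep CPU copies pageable, pin per transfer
)
```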
@@ -380,6 +430,7 @@ def _apply_group_offloading_block_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -400,11 +451,6 @@ def _apply_group_offloading_block_level(
             for overlapping computation and data transfer.
     """
 
-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        cpu_param_dict = _get_pinned_cpu_param_dict(module)
-
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
     unmatched_modules = []
@@ -425,7 +471,7 @@ def _apply_group_offloading_block_level(
             onload_leader=current_modules[0],
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=stream is None,
         )
         matched_module_groups.append(group)
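Since this hunk wires the flag into the block-level groups, here is the matching `block_level` call from the outside, again hedged: the toy module exists only to give the matcher an `nn.ModuleList` to group, and exact signatures may vary by diffusers version:

```python
import torch
from diffusers.hooks import apply_group_offloading

class TinyStack(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # block_level grouping matches ModuleList/Sequential children like this.
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(32, 32) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

model = TinyStack()
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,      # two Linears per ModuleGroup
    use_stream=True,
    low_cpu_mem_usage=True,
)
```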
@@ -462,7 +508,6 @@ def _apply_group_offloading_block_level(
         buffers=buffers,
         non_blocking=False,
         stream=None,
-        cpu_param_dict=None,
         onload_self=True,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
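For reference, a hypothetical direct construction of one group with the post-PR signature (assuming the class is `ModuleGroup` in `diffusers.hooks.group_offloading`; normally only the helpers in this file instantiate it):

```python
import torch
from diffusers.hooks.group_offloading import ModuleGroup

blocks = [torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)]
stream = torch.cuda.Stream() if torch.cuda.is_available() else None

group = ModuleGroup(
    modules=blocks,
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    offload_leader=blocks[-1],   # last block triggers the group's offload
    onload_leader=blocks[0],     # first block triggers the group's onload
    non_blocking=True,
    stream=stream,
    low_cpu_mem_usage=True,      # replaces the removed cpu_param_dict argument
    onload_self=stream is None,
)
```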
@@ -475,6 +520,7 @@ def _apply_group_offloading_leaf_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -497,11 +543,6 @@ def _apply_group_offloading_leaf_level(
             for overlapping computation and data transfer.
     """
 
-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        cpu_param_dict = _get_pinned_cpu_param_dict(module)
-
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
     for name, submodule in module.named_modules():
@@ -515,7 +556,7 @@ def _apply_group_offloading_leaf_level(
             onload_leader=submodule,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_group_offloading_hook(submodule, group, None)
@@ -560,7 +601,7 @@ def _apply_group_offloading_leaf_level(
             buffers=buffers,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_group_offloading_hook(parent_module, group, None)
@@ -579,7 +620,7 @@ def _apply_group_offloading_leaf_level(
             buffers=None,
             non_blocking=False,
             stream=None,
-            cpu_param_dict=None,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_lazy_group_offloading_hook(module, unmatched_group, None)
@@ -616,17 +657,6 @@ def _apply_lazy_group_offloading_hook(
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
 
 
-def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.nn.Parameter, torch.Tensor]:
-    cpu_param_dict = {}
-    for param in module.parameters():
-        param.data = param.data.cpu().pin_memory()
-        cpu_param_dict[param] = param.data
-    for buffer in module.buffers():
-        buffer.data = buffer.data.cpu().pin_memory()
-        cpu_param_dict[buffer] = buffer.data
-    return cpu_param_dict
-
-
 def _gather_parameters_with_no_group_offloading_parent(
     module: torch.nn.Module, modules_with_group_offloading: Set[str]
 ) -> List[torch.nn.Parameter]:
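The removed `_get_pinned_cpu_param_dict` rebound every `param.data` to a pinned tensor, so the whole model occupied page-locked RAM for the life of the process. The replacement keeps copies in a side table and, under `low_cpu_mem_usage=True`, pins them only around a transfer. A sketch of the contrast (illustrative, not PR code):

```python
import torch

if torch.cuda.is_available():
    # Old behavior: the parameter itself becomes, and stays, page-locked.
    p_old = torch.nn.Parameter(torch.randn(4, 4))
    p_old.data = p_old.data.cpu().pin_memory()

    # New behavior: the parameter stays pageable; a side table owns the copy
    # and pinning happens transiently, as in _pinned_memory_tensors().
    p_new = torch.nn.Parameter(torch.randn(4, 4))
    side_table = {p_new: p_new.data.cpu()}    # pageable CPU tensor
    staging = side_table[p_new].pin_memory()  # pinned just for this transfer
    on_gpu = staging.to("cuda", non_blocking=True)
    torch.cuda.synchronize()                  # ensure the copy has completed
```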