40
40
"pack_bitmasks" ,
41
41
"unpack_bitmasks" ,
42
42
"remove_suffix" ,
43
- "module_replace_dfs " ,
43
+ "module_map_replace " ,
44
44
]
45
45
46
46
FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -339,12 +339,26 @@ def remove_suffix(value: str, suffix: str) -> str:
339
339
return value [: - len (suffix )]
340
340
341
341
342
- def module_replace_dfs (
342
+ def module_map_replace (
343
343
module : torch .nn .Module ,
344
344
func : Callable [[torch .nn .Module ], torch .nn .Module ],
345
- pre : bool = True ,
346
345
progress : Union [bool , tqdm .tqdm ] = False ,
346
+ pre : bool = True ,
347
347
) -> torch .nn .Module :
348
+ """
349
+ Replaces modules in a given `torch.nn.Module` recursively using a provided function.
350
+
351
+ This function traverses the module hierarchy and applies the `func` transformation
352
+ either before (`pre=True`) or after (`pre=False`) recursing into children modules.
353
+ Optionally displays progress using tqdm.
354
+
355
+ :param module: root module to replace
356
+ :param func: module mapping function
357
+ :param progress: if True, display a tqdm progress bar.
358
+ If a `tqdm.tqdm` instance is provided, the instance will be updated
359
+ :param pre: if True, apply with pre-order, post-order otherwise
360
+ :return: the modified module after applying the function to all submodules
361
+ """
348
362
if progress is True :
349
363
total = len (list (module .modules ()))
350
364
progress = tqdm .tqdm (total = total )
@@ -353,7 +367,7 @@ def module_replace_dfs(
353
367
module = func (module )
354
368
355
369
for name , child in list (module .named_children ()):
356
- module .add_module (name , module_replace_dfs (child , func , pre , progress ))
370
+ module .add_module (name , module_map_replace (child , func , pre , progress ))
357
371
358
372
if not pre :
359
373
module = func (module )
0 commit comments