Skip to content

Commit 43736a9

Browse files
committed
rename to module_map_replace
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent d4e96d1 commit 43736a9

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
get_safetensors_folder,
6363
has_offloaded_params,
6464
merge_names,
65-
module_replace_dfs,
65+
module_map_replace,
6666
register_offload_parameter,
6767
update_parameter_data,
6868
)
@@ -397,7 +397,7 @@ def replace_with_compressed(module: Module) -> Module:
397397
return module
398398

399399
progress = tqdm(total=len(list(model.modules())))
400-
return module_replace_dfs(model, replace_with_compressed, progress=progress)
400+
return module_map_replace(model, replace_with_compressed, progress=progress)
401401

402402
def compress(
403403
self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None

src/compressed_tensors/utils/helpers.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
"pack_bitmasks",
4141
"unpack_bitmasks",
4242
"remove_suffix",
43-
"module_replace_dfs",
43+
"module_map_replace",
4444
]
4545

4646
FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -339,12 +339,26 @@ def remove_suffix(value: str, suffix: str) -> str:
339339
return value[: -len(suffix)]
340340

341341

342-
def module_replace_dfs(
342+
def module_map_replace(
343343
module: torch.nn.Module,
344344
func: Callable[[torch.nn.Module], torch.nn.Module],
345-
pre: bool = True,
346345
progress: Union[bool, tqdm.tqdm] = False,
346+
pre: bool = True,
347347
) -> torch.nn.Module:
348+
"""
349+
Replaces modules in a given `torch.nn.Module` recursively using a provided function.
350+
351+
This function traverses the module hierarchy and applies the `func` transformation
352+
either before (`pre=True`) or after (`pre=False`) recursing into children modules.
353+
Optionally displays progress using tqdm.
354+
355+
:param module: root module to replace
356+
:param func: module mapping function
357+
:param progress: if True, display a tqdm progress bar.
358+
If a `tqdm.tqdm` instance is provided, the instance will be updated
359+
:param pre: if True, apply with pre-order, post-order otherwise
360+
:return: the modified module after applying the function to all submodules
361+
"""
348362
if progress is True:
349363
total = len(list(module.modules()))
350364
progress = tqdm.tqdm(total=total)
@@ -353,7 +367,7 @@ def module_replace_dfs(
353367
module = func(module)
354368

355369
for name, child in list(module.named_children()):
356-
module.add_module(name, module_replace_dfs(child, func, pre, progress))
370+
module.add_module(name, module_map_replace(child, func, pre, progress))
357371

358372
if not pre:
359373
module = func(module)

0 commit comments

Comments
 (0)