|
14 | 14 |
|
15 | 15 | import warnings
|
16 | 16 | from functools import wraps
|
17 |
| -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union |
| 17 | +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional |
18 | 18 |
|
19 | 19 | import numpy
|
20 | 20 | import torch
|
21 |
| -import tqdm |
22 | 21 | from transformers import AutoConfig
|
23 | 22 |
|
24 | 23 |
|
|
39 | 38 | "shard_tensor",
|
40 | 39 | "pack_bitmasks",
|
41 | 40 | "unpack_bitmasks",
|
42 |
| - "remove_suffix", |
43 |
| - "module_map_replace", |
44 | 41 | ]
|
45 | 42 |
|
46 | 43 | FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
|
@@ -331,48 +328,3 @@ def unpack_bitmasks(
|
331 | 328 | )
|
332 | 329 |
|
333 | 330 | return unpacked_bitmasks_torch
|
334 |
| - |
335 |
| - |
336 |
| -def remove_suffix(value: str, suffix: str) -> str: |
337 |
| - # can replace with str.removesuffix in python3.9+ |
338 |
| - assert value.endswith(suffix) |
339 |
| - return value[: -len(suffix)] |
340 |
| - |
341 |
| - |
342 |
| -def module_map_replace( |
343 |
| - module: torch.nn.Module, |
344 |
| - func: Callable[[torch.nn.Module], torch.nn.Module], |
345 |
| - progress: Union[bool, tqdm.tqdm] = False, |
346 |
| - pre: bool = True, |
347 |
| -) -> torch.nn.Module: |
348 |
| - """ |
349 |
| - Replaces modules in a given `torch.nn.Module` recursively using a provided function. |
350 |
| -
|
351 |
| - This function traverses the module hierarchy and applies the `func` transformation |
352 |
| - either before (`pre=True`) or after (`pre=False`) recursing into children modules. |
353 |
| - Optionally displays progress using tqdm. |
354 |
| -
|
355 |
| - :param module: root module to replace |
356 |
| - :param func: module mapping function |
357 |
| - :param progress: if True, display a tqdm progress bar. |
358 |
| - If a `tqdm.tqdm` instance is provided, the instance will be updated |
359 |
| - :param pre: if True, apply with pre-order, post-order otherwise |
360 |
| - :return: the modified module after applying the function to all submodules |
361 |
| - """ |
362 |
| - if progress is True: |
363 |
| - total = len(list(module.modules())) |
364 |
| - progress = tqdm.tqdm(total=total) |
365 |
| - |
366 |
| - if pre: |
367 |
| - module = func(module) |
368 |
| - |
369 |
| - for name, child in list(module.named_children()): |
370 |
| - module.add_module(name, module_map_replace(child, func, pre, progress)) |
371 |
| - |
372 |
| - if not pre: |
373 |
| - module = func(module) |
374 |
| - |
375 |
| - if isinstance(progress, tqdm.tqdm): |
376 |
| - progress.update(1) |
377 |
| - |
378 |
| - return module |
0 commit comments