Skip to content

Commit 38c2e78

Browse files
committed
Add docs to the requantize(...) function explaining why it was copied from optimum-quanto.
1 parent d11dc6d commit 38c2e78

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

invokeai/backend/requantize.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,21 @@
33
import torch
44
from optimum.quanto.quantize import _quantize_submodule
55

6-
# def custom_freeze(model: torch.nn.Module):
7-
# for name, m in model.named_modules():
8-
# if isinstance(m, QModuleMixin):
9-
# m.weight =
10-
# m.freeze()
11-
126

137
def requantize(
148
model: torch.nn.Module,
159
state_dict: Dict[str, Any],
1610
quantization_map: Dict[str, Dict[str, str]],
17-
device: torch.device = None,
11+
device: torch.device | None = None,
1812
):
13+
"""This function was initially copied from:
14+
https://github.com/huggingface/optimum-quanto/blob/832f7f5c3926c91fe4f923aaaf037a780ac3e6c3/optimum/quanto/quantize.py#L101
15+
16+
The function was modified to remove the `freeze()` call. The `freeze()` call is very slow and unnecessary when the
17+
weights are about to be loaded from a state_dict.
18+
19+
TODO(ryand): Unless I'm overlooking something, this should be contributed upstream to the `optimum-quanto` library.
20+
"""
1921
if device is None:
2022
device = next(model.parameters()).device
2123
if device.type == "meta":
@@ -45,6 +47,7 @@ def move_tensor(t, device):
4547
setattr(m, name, torch.nn.Parameter(move_tensor(param, "cpu")))
4648
for name, param in m.named_buffers(recurse=False):
4749
setattr(m, name, move_tensor(param, "cpu"))
50+
4851
# Freeze model and move to target device
4952
# freeze(model)
5053
# model.to(device)

0 commit comments

Comments
 (0)