Skip to content

Commit 4a9b43f

Browse files
Fix optimizer support for Python <= 3.9 (#1379)
1 parent 9c9007a commit 4a9b43f

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

bitsandbytes/optim/optimizer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def load_state_dict(self, state_dict, move_to_device=True):
177177
raise ValueError("loaded state dict has a different number of parameter groups")
178178
param_lens = (len(g["params"]) for g in groups)
179179
saved_lens = (len(g["params"]) for g in saved_groups)
180-
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens, strict=True)):
180+
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
181181
raise ValueError(
182182
"loaded state dict contains a parameter group that doesn't match the size of optimizer's group",
183183
)
@@ -188,7 +188,6 @@ def load_state_dict(self, state_dict, move_to_device=True):
188188
for old_id, p in zip(
189189
chain.from_iterable(g["params"] for g in saved_groups),
190190
chain.from_iterable(g["params"] for g in groups),
191-
strict=True,
192191
)
193192
}
194193

@@ -230,7 +229,7 @@ def update_group(group, new_group):
230229
new_group["params"] = group["params"]
231230
return new_group
232231

233-
param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups, strict=True)]
232+
param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
234233
self.__setstate__({"state": state, "param_groups": param_groups})
235234

236235
def to_gpu(self):

0 commit comments

Comments
 (0)