
Commit 01bd0be

Make optim lazily initialize global state (#2277)
stack-info: PR: #2277, branch: drisspg/stack/60
1 parent d963a88 commit 01bd0be
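
The new comments in the diff explain the motivation: the quantization maps used to be built at import time, which breaks when the module is imported while an ambient device context such as the meta device is active. A minimal sketch of the failure mode (hypothetical snippet, not part of this commit):

import torch

# Under an ambient meta device, eagerly-built module globals become meta
# tensors: they carry shape and dtype but no data.
with torch.device("meta"):
    t = torch.linspace(0, 1, 17)  # created on the meta device
print(t.device)  # meta
# t[1:].tolist() would raise: data cannot be copied out of a meta tensor

Deferring construction to first use, as the diff below does, means import no longer builds any tensors; the 4-bit unsigned map additionally pins its intermediate linspace with device="cpu".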

2 files changed (+31 lines, −6 lines)


torchao/optim/subclass_4bit.py

Lines changed: 16 additions & 3 deletions
@@ -29,8 +29,19 @@
 # https://github.com/thu-ml/low-bit-optimizers/blob/e3e2854728e498c2a606e3fdb88daa27ae94f9a6/lpmm/configs/2nd_moment_group_128.yml
 # NOTE: power-1 is linear
 # TODO: since QMAP_UNSIGNED is linear, perhaps doing affine quantize is faster?
-QMAP_SIGNED = create_dynamic_map(True, 3, 4)
-QMAP_UNSIGNED = torch.linspace(0, 1, 17)[1:].tolist()  # no zero
+
+# Lazy initialization to avoid meta device issues during import
+from functools import lru_cache
+
+
+@lru_cache(maxsize=1)
+def get_qmap_signed():
+    return create_dynamic_map(True, 3, 4)
+
+
+@lru_cache(maxsize=1)
+def get_qmap_unsigned():
+    return torch.linspace(0, 1, 17, device="cpu")[1:].tolist()  # no zero
 
 
 class OptimState4bit(TorchAOBaseTensor):
@@ -90,7 +101,9 @@ def zeros(cls, shape, signed: bool = True, block_size: int = 128, device=None):
 
         codes = torch.zeros(n_elems // 2, dtype=torch.uint8, device=device)
         scale = torch.zeros(n_elems // block_size, device=device)
-        qmap = torch.tensor(QMAP_SIGNED if signed else QMAP_UNSIGNED, device=device)
+        qmap = torch.tensor(
+            get_qmap_signed() if signed else get_qmap_unsigned(), device=device
+        )
         return cls(codes, scale, qmap, signed, shape)
 
     def __repr__(self):
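
A note on the pattern: @lru_cache(maxsize=1) turns each former module-level constant into a build-once accessor, so the Python list is computed on first call and memoized, while torch.tensor(..., device=device) inside zeros() remains the only device-dependent step. A hedged usage sketch based on the zeros() signature shown in this hunk (assumes torchao is installed; the shape is an arbitrary example divisible by block_size):

from torchao.optim.subclass_4bit import OptimState4bit

# The qmap tensor is materialized here, on the requested device,
# rather than as a side effect of importing the module.
state = OptimState4bit.zeros((4096,), signed=True, device="cpu")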

torchao/optim/subclass_8bit.py

Lines changed: 15 additions & 3 deletions
@@ -26,8 +26,18 @@
 c10d_functional = torch.ops.c10d_functional
 _c10d_functional = torch.ops._c10d_functional
 
-QMAP_SIGNED = create_dynamic_map(signed=True)
-QMAP_UNSIGNED = create_dynamic_map(signed=False)
+# Lazy initialization to avoid meta device issues during import
+from functools import lru_cache
+
+
+@lru_cache(maxsize=1)
+def get_qmap_signed():
+    return create_dynamic_map(signed=True)
+
+
+@lru_cache(maxsize=1)
+def get_qmap_unsigned():
+    return create_dynamic_map(signed=False)
 
 
 class OptimState8bit(TorchAOBaseTensor):
@@ -79,7 +89,9 @@ def dequantize(self, output_dtype=None):
     def zeros(cls, shape, signed: bool = True, block_size: int = 256, device=None):
         codes = torch.zeros(shape, dtype=torch.uint8, device=device)
         scale = torch.zeros(codes.numel() // block_size, device=device)
-        qmap = torch.tensor(QMAP_SIGNED if signed else QMAP_UNSIGNED, device=device)
+        qmap = torch.tensor(
+            get_qmap_signed() if signed else get_qmap_unsigned(), device=device
+        )
         return cls(codes, scale, qmap, signed)
 
     def __repr__(self):
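
The same transformation is applied here; only the table construction differs. The build-once behavior relied on is the documented semantics of functools.lru_cache, illustrated below with a hypothetical table builder (not torchao code):

from functools import lru_cache

@lru_cache(maxsize=1)
def get_table():
    print("building")  # executes exactly once
    return [i / 16 for i in range(1, 17)]

a = get_table()  # prints "building"
b = get_table()  # served from the cache; no print
assert a is b  # the identical list object is returned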
