|
29 | 29 | # https://github.com/thu-ml/low-bit-optimizers/blob/e3e2854728e498c2a606e3fdb88daa27ae94f9a6/lpmm/configs/2nd_moment_group_128.yml
|
30 | 30 | # NOTE: power-1 is linear
|
31 | 31 | # TODO: since QMAP_UNSIGNED is linear, perhaps doing affine quantize is faster?
|
32 |
| -QMAP_SIGNED = create_dynamic_map(True, 3, 4) |
33 |
| -QMAP_UNSIGNED = torch.linspace(0, 1, 17)[1:].tolist() # no zero |
| 32 | + |
| 33 | +# Lazy initialization to avoid meta device issues during import |
| 34 | +from functools import lru_cache |
| 35 | + |
| 36 | + |
@lru_cache(maxsize=1)
def get_qmap_signed():
    """Return the signed quantization map, built on first call and cached.

    Lazy construction (rather than a module-level constant) avoids
    meta-device issues during import, per the surrounding comments.
    """
    # NOTE(review): args are (signed=True, 3, 4) — presumably exponent/total
    # bit counts for the 4-bit dynamic map; confirm against create_dynamic_map.
    qmap = create_dynamic_map(True, 3, 4)
    return qmap
| 40 | + |
| 41 | + |
@lru_cache(maxsize=1)
def get_qmap_unsigned():
    """Return the unsigned quantization map, built on first call and cached.

    The map is 16 evenly spaced values over (0, 1]: a 17-point linspace
    on CPU with the leading zero dropped. Lazy construction avoids
    meta-device issues during import.
    """
    grid = torch.linspace(0, 1, 17, device="cpu")
    # drop the first entry so the map contains no zero
    return grid[1:].tolist()
34 | 45 |
|
35 | 46 |
|
36 | 47 | class OptimState4bit(TorchAOBaseTensor):
|
@@ -90,7 +101,9 @@ def zeros(cls, shape, signed: bool = True, block_size: int = 128, device=None):
|
90 | 101 |
|
91 | 102 | codes = torch.zeros(n_elems // 2, dtype=torch.uint8, device=device)
|
92 | 103 | scale = torch.zeros(n_elems // block_size, device=device)
|
93 |
| - qmap = torch.tensor(QMAP_SIGNED if signed else QMAP_UNSIGNED, device=device) |
| 104 | + qmap = torch.tensor( |
| 105 | + get_qmap_signed() if signed else get_qmap_unsigned(), device=device |
| 106 | + ) |
94 | 107 | return cls(codes, scale, qmap, signed, shape)
|
95 | 108 |
|
96 | 109 | def __repr__(self):
|
|
0 commit comments