#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+import torch
+
from bitsandbytes.optim.optimizer import Optimizer2State

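+# galore_torch is an optional dependency; GaLoreAdamW8bit raises a RuntimeError
+# at construction time if it is missing.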
+_galore_available = False
+try:
+    from galore_torch.galore_projector import GaLoreProjector
+
+    _galore_available = True
+except ImportError:
+    pass
+

class AdamW(Optimizer2State):
    def __init__(
@@ -127,6 +137,133 @@ def __init__(
        )


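+# 8-bit AdamW with GaLore gradient low-rank projection. Param groups that carry the
+# GaLore hyperparameters ("rank", "update_proj_gap", "scale", "proj_type") are updated
+# in the projected low-rank subspace; all other groups get the regular 8-bit update.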
+class GaLoreAdamW8bit(Optimizer2State):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=1e-2,
+        amsgrad=False,
+        optim_bits=8,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+        is_paged=False,
+    ):
+        if not _galore_available:
+            raise RuntimeError("The galore_torch package must be installed to use GaLoreAdamW8bit.")
+        super().__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            optim_bits,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=is_paged,
+        )
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        overflows = []
+
+        if not self.initialized:
+            self.check_overrides()
+            self.to_gpu()  # needed for fairseq pure fp16 training
+            self.initialized = True
+
+        # if self.is_paged: self.page_mng.prefetch_all()
+        for gindex, group in enumerate(self.param_groups):
+            for pindex, p in enumerate(group["params"]):
+                if p.grad is None:
+                    continue
+                state = self.state[p]
+
+                if "step" not in state:
+                    state["step"] = 0
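+
+                # When the param group carries GaLore settings, the gradient is projected
+                # into a low-rank subspace, the 8-bit Adam step is computed there, and the
+                # resulting update is projected back before being applied to the parameter.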
+                # GaLore Projection
+                if "rank" in group:
+                    if "projector" not in state:
+                        state["projector"] = GaLoreProjector(
+                            group["rank"],
+                            update_proj_gap=group["update_proj_gap"],
+                            scale=group["scale"],
+                            proj_type=group["proj_type"],
+                        )
+
+                    grad = state["projector"].project(p.grad, state["step"])
+
+                    # suboptimal implementation
+                    # p.saved_data = p.data.clone()
+                    # p.data = grad.clone().to(p.data.dtype).to(p.data.device)
+                    # p.data.zero_()
+                    # p.grad = grad
+                    lor_update = torch.zeros_like(
+                        grad, dtype=p.data.dtype, device=p.data.device, requires_grad=grad.requires_grad
+                    )
+                    lor_update.grad = grad
+
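+                # NOTE: lor_update has the projected (low-rank) shape and carries the
+                # projected gradient in .grad. update_step() below receives it via
+                # return_updates, so the Adam update is written into lor_update instead
+                # of being applied to p directly, and can then be projected back to full rank.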
+                if "state1" not in state:
+                    self.init_state(group, p, gindex, pindex)
+
+                self.prefetch_state(p)
+
+                if "rank" in group:
+                    self.update_step(group, p, gindex, pindex, return_updates=lor_update)
+
+                    # GaLore Projection Back
+                    p.data.add_(state["projector"].project_back(lor_update))
+
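+                    # Decoupled weight decay is applied to the full-rank parameter after the
+                    # projected-back update, matching AdamW semantics.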
+                    if "weight_decay" in group and group["weight_decay"] > 0:
+                        p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])
+                else:
+                    self.update_step(group, p, gindex, pindex)
+
+                torch.cuda.synchronize()
+
+        if self.is_paged:
+            # all paged operations are asynchronous; we need to sync to make sure
+            # all tensors are in the right state
+            torch.cuda.synchronize()
+
+        return loss
+
+
class AdamW32bit(Optimizer2State):
    def __init__(
        self,
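
A minimal usage sketch, assuming galore_torch is installed, a CUDA device is available, and GaLoreAdamW8bit is re-exported from bitsandbytes.optim (the import path is an assumption). The GaLore keys mirror what step() reads from the param group ("rank", "update_proj_gap", "scale", "proj_type"); the concrete values and the toy model are illustrative placeholders.

import torch
from bitsandbytes.optim import GaLoreAdamW8bit  # import path assumed; adjust if not re-exported

model = torch.nn.Linear(1024, 1024).cuda()

# GaLore projection operates on 2D weight matrices; biases and other tensors
# go in a plain group that takes the regular 8-bit AdamW update.
galore_params = [p for p in model.parameters() if p.ndim == 2]
other_params = [p for p in model.parameters() if p.ndim != 2]

optimizer = GaLoreAdamW8bit(
    [
        # Keys match what step() reads from the param group; values are illustrative.
        {"params": galore_params, "rank": 128, "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"},
        {"params": other_params},
    ],
    lr=1e-3,
)

loss = model(torch.randn(4, 1024, device="cuda")).pow(2).mean()
loss.backward()
optimizer.step()      # project grads to rank 128, run the 8-bit update, project back
optimizer.zero_grad()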
|