""" PyTorch MARS Optimizer

Code simplified from https://github.com/AGI-Arena/MARS

Paper: MARS: Unleashing the Power of Variance Reduction for Training Large Models - https://arxiv.org/abs/2411.10438

@article{yuan2024mars,
    title={MARS: Unleashing the Power of Variance Reduction for Training Large Models},
    author={Yuan, Huizhuo and Liu, Yifeng and Wu, Shuang and Zhou, Xun and Gu, Quanquan},
    journal={arXiv preprint arXiv:2411.10438},
    year={2024}
}
"""
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0
import math

import torch
from torch.optim.optimizer import Optimizer


def mars_single_tensor(
        p,
        grad,
        exp_avg,
        exp_avg_sq,
        lr,
        weight_decay,
        beta1,
        beta2,
        last_grad,
        eps,
        step,
        gamma,
        mars_type,
        is_grad_2d,
        optimize_1d,
        lr_1d_factor,
        betas_1d,
):
    # optimize_1d: if True, apply the MARS update to 1-D params as well; otherwise 1-D params fall back to AdamW.
    if optimize_1d or is_grad_2d:
        one_minus_beta1 = 1. - beta1
        # Variance-reduced gradient estimate: c_t = grad + gamma * beta1 / (1 - beta1) * (grad - last_grad)
        c_t = (grad - last_grad).mul_(gamma * (beta1 / one_minus_beta1)).add_(grad)
        # Rescale c_t to unit L2 norm whenever its norm exceeds 1.
        c_t_norm = torch.norm(c_t)
        if c_t_norm > 1.:
            c_t = c_t / c_t_norm
        exp_avg.mul_(beta1).add_(c_t, alpha=one_minus_beta1)
        if mars_type == "adamw":
            # AdamW-style update with bias correction, driven by the variance-reduced gradient.
            exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2)
            bias_correction1 = 1.0 - beta1 ** step
            bias_correction2 = 1.0 - beta2 ** step
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
            update = p * weight_decay + (exp_avg / bias_correction1).div_(denom)
        elif mars_type == "lion":
            # Lion-style update: sign of the first moment.
            update = p * weight_decay + exp_avg.sign()
        else:
            raise ValueError("Unsupported mars_type: {}".format(mars_type))
        p.add_(update, alpha=-lr)
    else:
        # AdamW fallback for 1-D params (e.g. biases, norm weights) with its own betas and lr scaling.
        beta1_1d, beta2_1d = betas_1d
        exp_avg.mul_(beta1_1d).add_(grad, alpha=1. - beta1_1d)
        exp_avg_sq.mul_(beta2_1d).addcmul_(grad, grad, value=1. - beta2_1d)
        bias_correction1 = 1.0 - beta1_1d ** step
        bias_correction2 = 1.0 - beta2_1d ** step
        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
        update = p * weight_decay + (exp_avg / bias_correction1).div_(denom)
        p.add_(update, alpha=-(lr * lr_1d_factor))
    return exp_avg, exp_avg_sq
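

# Illustrative only (not used by the optimizer): a minimal sketch of the
# variance-reduced gradient estimate that mars_single_tensor computes in place
# above. The helper name `_reference_c_t` is hypothetical and exists purely to
# document the update in plain, out-of-place tensor ops.
def _reference_c_t(grad, last_grad, beta1, gamma):
    """Return c_t = grad + gamma * beta1 / (1 - beta1) * (grad - last_grad), rescaled to unit L2 norm if needed."""
    c_t = grad + gamma * (beta1 / (1. - beta1)) * (grad - last_grad)
    c_t_norm = torch.norm(c_t)
    if c_t_norm > 1.:
        c_t = c_t / c_t_norm
    return c_t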


class Mars(Optimizer):
    """ MARS Optimizer

    Paper: MARS: Unleashing the Power of Variance Reduction for Training Large Models
        https://arxiv.org/abs/2411.10438

    See the usage sketch at the bottom of this file for a minimal training-loop example.
    """
    def __init__(
            self,
            params,
            lr=3e-3,
            betas=(0.9, 0.99),
            eps=1e-8,
            weight_decay=0.,
            gamma=0.025,
            mars_type="adamw",
            optimize_1d=False,
            lr_1d_factor=1.0,
            betas_1d=None,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        assert mars_type in ["adamw", "lion"], "MARS type not supported"

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            mars_type=mars_type,
            gamma=gamma,
            optimize_1d=optimize_1d,
            lr_1d_factor=lr_1d_factor,
            betas_1d=betas_1d or betas,
        )
        super(Mars, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('MARS does not support sparse gradients')

                state = self.state[p]
                # State initialization
                if len(state) <= 1:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Last gradient, used for the variance-reduced estimate
                    state['last_grad'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p)

                state['step'] += 1
                step = state['step']
                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']
                last_grad = state['last_grad']
                lr = group['lr']
                wd = group['weight_decay']
                beta1, beta2 = group['betas']
                # Tensors with ndim >= 2 are treated as '2d' and get the MARS update
                is_grad_2d = grad.ndim >= 2

                mars_single_tensor(
                    p,
                    grad,
                    exp_avg,
                    exp_avg_sq,
                    lr,
                    wd,
                    beta1,
                    beta2,
                    last_grad,
                    group['eps'],
                    step,
                    group['gamma'],
                    mars_type=group['mars_type'],
                    is_grad_2d=is_grad_2d,
                    optimize_1d=group['optimize_1d'],
                    lr_1d_factor=group['lr_1d_factor'],
                    betas_1d=group['betas_1d'],
                )

                state['last_grad'] = grad

        return loss
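

# Minimal usage sketch, assuming a small toy model (the layer sizes and squared-output
# loss below are arbitrary placeholders). Mars follows the standard torch.optim.Optimizer
# interface, so the usual zero_grad()/backward()/step() loop applies; step() also accepts
# an optional closure, as documented above.
if __name__ == "__main__":
    model = torch.nn.Sequential(
        torch.nn.Linear(16, 32),  # 2-D weights take the MARS update path
        torch.nn.ReLU(),
        torch.nn.Linear(32, 1),   # 1-D biases fall back to AdamW unless optimize_1d=True
    )
    optimizer = Mars(model.parameters(), lr=3e-3, betas=(0.9, 0.99), weight_decay=0.01)
    for _ in range(3):
        inp = torch.randn(8, 16)
        loss = model(inp).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()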