I think one of the main motives for LoRA is to reduce memory consumption; certainly that's my motive. I'm already using gradient checkpointing and AdaFactor, so the main thing I want from LoRA is to reduce the size of the gradient pytree itself. However, unless I'm quite confused, in a trivial setup like:
```python
import functools as ft
from typing import Generic
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import quax
from quax.examples.lora import loraify  # LoraArray also lives here

# Dim1/Dim2/Dim1T/Dim2T/Float and the typed `ndarray` alias are my own shape-typing helpers.
class DummyModel(eqx.Module, Generic[Dim1, Dim2, Float]):
    tmp: eqx.nn.Linear[Dim1, Dim2, Float]

    def __init__(self, dim1: Dim1, dim2: Dim2, dtype: type[Float], key: jax.Array) -> None:
        self.tmp = eqx.nn.Linear(dim1, dim2, dtype=dtype, key=key)

    def __call__(self, ndarray: ndarray[Dim1, Float]) -> ndarray[Dim2, Float]:
        return self.tmp(ndarray)

@eqx.filter_value_and_grad
@ft.partial(eqx.filter_checkpoint, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def grad_fn(m: DummyModel[Dim1, Dim2, Float], y: ndarray[Dim1, Float]) -> ndarray[Float]:
    m = quax.quaxify(m)
    return jnp.square(jnp.mean(m(y)) - 0)

def main():
    x = DummyModel[Dim1T, Dim2T, np.float32](4096, 4096, np.float32, jax.random.PRNGKey(0))
    loraed = loraify(x, rank=64, scale=0.1, key=jax.random.PRNGKey(1))
    return grad_fn(loraed, np.random.rand(4096))
```
the returned grads include a full Dim1 x Dim2 array of zeros for `_w`. Almost all of the values in the gradient pytree are zero (for typical LoRAs), and this is wasted memory.
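
For concreteness, here's how that shows up (continuing the snippet above, with `loraed` from `main()` in scope):

```python
# eqx.filter_value_and_grad returns (loss, grads), where grads mirrors the
# loraed model's pytree structure.
loss, grads = grad_fn(loraed, np.random.rand(4096))

# Shape of every array leaf in the gradient pytree: the cotangent for the
# frozen `_w` weight is a dense (4096, 4096) array of zeros, sitting next to
# the small rank-64 LoRA factors.
print(jax.tree_util.tree_map(lambda g: g.shape, grads))
```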
I thought perhaps I could get around this by replacing the `jax.lax.stop_gradient` in `LoraArray` with something like:
```python
@jax.custom_jvp
def symbolic_stop_gradient(x: A) -> A:
    return x

@symbolic_stop_gradient.defjvp
def symbolic_stop_gradient_jvp(primals: tuple[ndarray[*Shape, Float]], tangents: tuple[ndarray[*Shape, Float]]):
    return primals[0], Zero(primals[0].shape, primals[0].dtype)
```
but that produces the following error:
```
TypeError: Custom JVP rule symbolic_stop_gradient_jvp for function symbolic_stop_gradient must produce primal and tangent outputs with equal container (pytree) structures, but got PyTreeDef(*) and PyTreeDef(CustomNode(Zero[(), ('_shape', '_dtype'), ((4096, 4096), dtype('float32'))], [])) respectively.
```
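
As far as I can tell, the straightforward way to make the rule structurally valid is to return a real zero-array tangent, which just reintroduces the dense allocation I'm trying to avoid (illustrative sketch; `dense_stop_gradient` is only a name for the example):

```python
# Structurally valid custom_jvp rule: the primal and tangent outputs have the
# same pytree structure.  But jnp.zeros_like(x) allocates a dense zero array
# of x's full shape, which is exactly the memory I want to save.
@jax.custom_jvp
def dense_stop_gradient(x):
    return x

@dense_stop_gradient.defjvp
def dense_stop_gradient_jvp(primals, tangents):
    (x,) = primals
    return x, jnp.zeros_like(x)
```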
Is there a reasonable way to use quax to implement LoRA in a way that doesn't allocate tons of space for zeros?
(I guess it's mildly possible that JAX optimizes out this allocation behind the scenes if the gradient pytree is "consumed" inside the same JIT where the gradients are produced, but I assume it's not quite that clever.)
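
The best workaround I can see right now is Equinox's usual partition/combine freezing pattern, so that `_w` is never a differentiated input at all; a sketch is below (`_a`/`_b` are my guess at the `LoraArray` adapter field names, and I haven't checked how cleanly this composes with `quaxify`). It would be nicer if `LoraArray` could express this itself.

```python
# Sketch: mark only the LoRA adapter matrices as differentiable, then
# partition the model so everything else (including _w) is passed through as
# a non-differentiated constant.
filter_spec = jax.tree_util.tree_map(lambda _: False, loraed)
filter_spec = eqx.tree_at(
    lambda m: (m.tmp.weight._a, m.tmp.weight._b),  # guessed field names
    filter_spec,
    replace=(True, True),
)
diff_model, static_model = eqx.partition(loraed, filter_spec)

@eqx.filter_value_and_grad
def lora_grad_fn(diff_model, static_model, y):
    model = quax.quaxify(eqx.combine(diff_model, static_model))
    return jnp.square(jnp.mean(model(y)) - 0)

loss, grads = lora_grad_fn(diff_model, static_model, np.random.rand(4096))
# grads now only has array leaves for _a and _b; the _w slot is None instead
# of a dense (4096, 4096) array of zeros.
```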
Thanks.