LoRA that doesn't require memory for zero gradients of the underlying matrices #28

@colehaus

I think one of the main motives for LoRA is to reduce memory consumption; it's certainly mine. I'm already using gradient checkpointing and AdaFactor, so the main thing I want from LoRA is to reduce the size of the gradient pytree itself. However, unless I'm quite confused, in a trivial setup like:

import functools as ft
from typing import Generic

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import quax

# `loraify`, the typed `ndarray` alias, and the shape/dtype type variables
# (Dim1, Dim2, Float, Dim1T, Dim2T) come from my own typed wrappers and are
# omitted here.

class DummyModel(eqx.Module, Generic[Dim1, Dim2, Float]):
    tmp: eqx.nn.Linear[Dim1, Dim2, Float]

    def __init__(self, dim1: Dim1, dim2: Dim2, dtype: type[Float], key: jax.Array) -> None:
        self.tmp = eqx.nn.Linear(dim1, dim2, dtype=dtype, key=key)

    def __call__(self, ndarray: ndarray[Dim1, Float]) -> ndarray[Dim2, Float]:
        return self.tmp(ndarray)

@eqx.filter_value_and_grad
@ft.partial(eqx.filter_checkpoint, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def grad_fn(m: DummyModel[Dim1, Dim2, Float], y: ndarray[Dim1, Float]) -> ndarray[Float]:
    m = quax.quaxify(m)
    return jnp.square(jnp.mean(m(y)) - 0)

def main():
    x = DummyModel[Dim1T, Dim2T, np.float32](4096, 4096, np.float32, jax.random.PRNGKey(0))
    loraed = loraify(x, rank=64, scale=0.1, key=jax.random.PRNGKey(1))
    return grad_fn(loraed, np.random.rand(4096).astype(np.float32))  # match the model's float32 dtype

the returned grads include a full Dim1 x Dim2 array of zeros for _w. Almost all the values in the gradient pytree are zero (for typical LoRAs) and this is wasted memory.
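For concreteness, a rough way to see this is to sum the bytes across the leaves of the returned gradient pytree (pytree_nbytes below is just an ad hoc helper, not anything from quax or Equinox):

import jax

def pytree_nbytes(tree) -> int:
    # Total bytes across every array leaf of a pytree.
    return sum(leaf.nbytes for leaf in jax.tree_util.tree_leaves(tree) if hasattr(leaf, "nbytes"))

# value, grads = grad_fn(loraed, np.random.rand(4096).astype(np.float32))
# print(pytree_nbytes(grads))  # dominated by the 4096 x 4096 zero gradient for _w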

I thought perhaps I could get around this by replacing jax.lax.stop_gradient in LoraArray with something like:

import jax

# `Zero` is my own symbolic-zero pytree class; A, Shape, and Float are my
# type variables, as above.

@jax.custom_jvp
def symbolic_stop_gradient(x: A) -> A:
    return x

@symbolic_stop_gradient.defjvp
def symbolic_stop_gradient_jvp(primals: tuple[ndarray[*Shape, Float]], tangents: tuple[ndarray[*Shape, Float]]):
    # Return the primal unchanged and a symbolic Zero in place of a dense zero tangent.
    return primals[0], Zero(primals[0].shape, primals[0].dtype)

but that produces the following error:

TypeError: Custom JVP rule symbolic_stop_gradient_jvp for function symbolic_stop_gradient must produce primal and tangent outputs with equal container (pytree) structures, but got PyTreeDef(*) and PyTreeDef(CustomNode(Zero[(), ('_shape', '_dtype'), ((4096, 4096), dtype('float32'))], [])) respectively.
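As far as I can tell, the rule is that the tangent output must have exactly the same pytree structure as the primal output, so the only structurally valid version seems to be a dense zero, which is precisely the allocation I'm trying to avoid (dense_stop_gradient is just an illustrative name):

import jax
import jax.numpy as jnp

@jax.custom_jvp
def dense_stop_gradient(x):
    return x

@dense_stop_gradient.defjvp
def dense_stop_gradient_jvp(primals, tangents):
    (x,) = primals
    # The tangent must mirror the primal's pytree structure (a single array),
    # so the "zero" here has to be a materialized array of zeros.
    return x, jnp.zeros_like(x)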

Is there a reasonable way to use quax to implement LoRA in a way that doesn't allocate tons of space for zeros?

(I guess it's mildly possible that JAX optimizes out this allocation behind the scenes if the gradient pytree is "consumed" inside the same JIT where the gradients are produced, but I assume it's not quite that clever.)
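For what it's worth, the closest workaround I can picture outside of quax is Equinox's partition/combine "freezing" pattern, where only the LoRA factors are marked differentiable and the frozen base weight never gets a gradient leaf at all. This is only a sketch; ManualLoraLinear and its field names are illustrative, not quax's LoraArray:

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

class ManualLoraLinear(eqx.Module):
    # Illustrative stand-in for a LoRA-ed linear layer.
    w: jax.Array  # frozen base weight
    a: jax.Array  # LoRA down-projection factor
    b: jax.Array  # LoRA up-projection factor

    def __call__(self, x):
        return self.w @ x + self.b @ (self.a @ x)

def loss(diff_model, static_model, x):
    # Recombine the trainable and frozen halves before the forward pass.
    model = eqx.combine(diff_model, static_model)
    return jnp.mean(model(x) ** 2)

key = jr.PRNGKey(0)
wk, ak, xk = jr.split(key, 3)
layer = ManualLoraLinear(
    w=jr.normal(wk, (4096, 4096)),
    a=jr.normal(ak, (64, 4096)) * 0.1,
    b=jnp.zeros((4096, 64)),
)
# Mark only the LoRA factors as differentiable; `w` ends up in the static half.
filter_spec = jax.tree_util.tree_map(lambda _: False, layer)
filter_spec = eqx.tree_at(lambda m: (m.a, m.b), filter_spec, replace=(True, True))
diff_model, static_model = eqx.partition(layer, filter_spec)
grads = eqx.filter_grad(loss)(diff_model, static_model, jr.normal(xk, (4096,)))
# grads.w is None, so no 4096 x 4096 block of zeros is ever allocated.

That avoids the dense zero block, but it sidesteps quax entirely, which is why I'm asking whether the same effect is achievable with LoraArray.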

Thanks.
