
Parameter counting when quaxified layers are invoked #66

@neel04


In the LoRA implementation, the API is set up so that a LoraArray must report the same abstract value (shape and dtype) as its underlying weight self.w:

def aval(self):
    # Report the dense weight's shape and dtype, not those of the LoRA factors.
    return jax.core.ShapedArray(self.w.shape, self.w.dtype)
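A rough count shows how far apart the two numbers can be. It assumes standard LoRA notation, which is not spelled out above: a dense weight $W \in \mathbb{R}^{m \times n}$ and low-rank factors $A \in \mathbb{R}^{m \times r}$, $B \in \mathbb{R}^{r \times n}$. Counting by the abstract value sees only $W$, while the trainable parameters are the factors:

$$
\underbrace{mn}_{\text{size reported by aval}} \quad \text{vs.} \quad \underbrace{r(m+n)}_{\text{trainable LoRA factors}}
$$

For example, with $m = n = 4096$ and $r = 8$: $mn = 16{,}777{,}216$ while $r(m+n) = 8 \cdot 8192 = 65{,}536$, a factor of $256$.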

However, this makes parameter counting awkward later on. The PyTree presents each LoraArray with the shape of its full weight, which hides the internal parameters it actually holds. A naive count therefore gives a huge overestimate, and one has to resort to ugly workarounds.
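One possible workaround, sketched here in plain Python so that it runs without JAX or Quax, is to treat each LoRA wrapper as a unit when counting and look inside it instead of trusting its reported shape. Everything here is hypothetical: ToyLoraArray stands in for a LoraArray, arrays are represented only by their shape tuples, and plain (non-LoRA) leaves are assumed to count as parameters.

```python
import math
from dataclasses import dataclass


@dataclass
class ToyLoraArray:
    """Hypothetical stand-in for a LoRA-wrapped weight; arrays are given by shape only."""
    w_shape: tuple  # frozen dense weight, shape (m, n)
    a_shape: tuple  # low-rank factor, shape (m, r)
    b_shape: tuple  # low-rank factor, shape (r, n)

    def aval_shape(self):
        # Mirrors aval(): expose only the dense weight's shape.
        return self.w_shape


def naive_count(leaves):
    """Count using the reported (abstract) shape only, which overestimates."""
    total = 0
    for leaf in leaves:
        shape = leaf.aval_shape() if isinstance(leaf, ToyLoraArray) else leaf
        total += math.prod(shape)
    return total


def count_params(leaves, trainable_only=True):
    """Count by looking inside LoRA wrappers instead of trusting aval."""
    total = 0
    for leaf in leaves:
        if isinstance(leaf, ToyLoraArray):
            # The low-rank factors are the trainable part.
            total += math.prod(leaf.a_shape) + math.prod(leaf.b_shape)
            if not trainable_only:
                # Optionally include the frozen dense weight as well.
                total += math.prod(leaf.w_shape)
        else:
            # Assumption: plain leaves (e.g. biases) count as parameters.
            total += math.prod(leaf)
    return total


layer = ToyLoraArray(w_shape=(4096, 4096), a_shape=(4096, 8), b_shape=(8, 4096))
leaves = [layer, (4096,)]  # one LoRA weight plus a bias
print(naive_count(leaves), count_params(leaves))
```

With real JAX PyTrees, the analogous step would be to flatten with jax.tree_util.tree_leaves(model, is_leaf=lambda x: isinstance(x, LoraArray)), so each LoraArray is returned as a single leaf that the counting function can inspect.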

How do you think this can be resolved?
