Status: Open
Labels: question (User queries)
Description
In the LoRA implementation, the API is set up so that the abstract value of a LoraArray has to stay consistent with that of an ordinary Array, i.e. it reports the shape and dtype of the full weight matrix w:
quax/quax/examples/lora/_core.py, lines 89 to 90 (commit b72049d):

    def aval(self):
        return jax.core.ShapedArray(self.w.shape, self.w.dtype)
However, this becomes annoying later on when counting parameters. The PyTree obscures the actual internal parameters of the LoraArray, so a naive count arrives at a huge overestimate, and one has to resort to ugly workarounds to get the real number.
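To make the overcount concrete, here is a minimal sketch in plain JAX. It does not use quax: `LoraLike` is a hypothetical stand-in with the same three fields as the example's LoraArray (the frozen `w` plus low-rank factors, assumed here to be named `a` and `b`), and `count_params_naive` / `count_trainable_params` are hypothetical helpers. A generic counter flattens every leaf and so includes all of `w`. Passing `is_leaf` to `jax.tree_util.tree_leaves` stops the flattening at the LoRA node, so only `a` and `b` are counted. This is one possible workaround, not an API that quax provides.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical stand-in for quax's LoraArray (NOT the real class):
# same idea of a frozen base weight plus two low-rank factors,
# registered as an ordinary JAX pytree so the example runs without quax.
@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class LoraLike:
    w: jax.Array  # frozen base weight, shape (n, m)
    a: jax.Array  # trainable factor, shape (n, k)
    b: jax.Array  # trainable factor, shape (k, m)

    def tree_flatten(self):
        # All three arrays become pytree leaves, including the frozen w.
        return (self.w, self.a, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def count_params_naive(tree):
    # What a generic counter sees: every array leaf, so w is counted in full.
    return sum(x.size for x in jax.tree_util.tree_leaves(tree))


def count_trainable_params(tree):
    # Treat LoraLike nodes as leaves so we can inspect them ourselves
    # and count only the low-rank factors a and b.
    def is_lora(x):
        return isinstance(x, LoraLike)

    total = 0
    for leaf in jax.tree_util.tree_leaves(tree, is_leaf=is_lora):
        if is_lora(leaf):
            total += leaf.a.size + leaf.b.size
        else:
            total += leaf.size
    return total


n, m, k = 512, 512, 4
model = {
    "dense": LoraLike(
        w=jnp.zeros((n, m)), a=jnp.zeros((n, k)), b=jnp.zeros((k, m))
    ),
    "bias": jnp.zeros((m,)),
}
print(count_params_naive(model))      # 266752 = 512*512 + 512*4 + 4*512 + 512
print(count_trainable_params(model))  # 4608 = 512*4 + 4*512 + 512
```

With quax itself, the same `is_leaf` check would presumably target `LoraArray` instead of `LoraLike`, since quax values are Equinox modules and therefore pytrees. Whether the frozen `w` should be reported at all depends on which number one actually wants.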
How do you think this can be resolved?