README.md: 3 additions & 1 deletion
@@ -82,12 +82,14 @@ values = wrapped_to_values(
 
 ## Todo
 
-- [ ] offer a way to combine separately learned concepts from multiple `Rank1EditModule` into one for inference
 - [ ] handle rank-1 update for multiple concepts
 - [x] handle training with multiple concepts
 - [ ] handle multiple concepts in one prompt at inference - summation of the sigmoid term + outputs
+- [ ] accept multiple concept indices
 - [ ] offer a magic function that automatically tries to wire up the cross attention by looking for appropriately named `nn.Linear` and auto-inferring which ones are keys or values
 
+- [x] offer a way to combine separately learned concepts from multiple `Rank1EditModule` into one for inference
+- [x] offer function for merging `Rank1EditModule`s
 - [x] add the zero-shot masking of concept proposed in paper
 - [x] take care of the function that takes in the dataset and text encoder and precomputes the covariance matrix needed for the rank-1 update
 - [x] instead of having the researcher worry about different learning rates, offer the fractional gradient trick from other paper (to learn the concept embedding)
         return W_em_orthogonal_term + sigmoid_term * rearrange(o, 'd -> 1 1 d')
+
+# for merging trained Rank1EditModule(s) above
+
+@beartype
+def merge_rank1_edit_modules(
+    *modules: Rank1EditModule
+) -> Rank1EditModule:
+
+    assert all([m.initted.item() for m in modules]), 'all modules must be initialized and ideally trained'
+    assert len(set([m.concept_outputs.shape[-1] for m in modules])) == 1, 'concept output dimension must be the same'
+    assert len(set([m.is_key_proj for m in modules])) == 1, 'all modules must be either for keys, or values. you cannot merge rank 1 edit modules of keys and values together'
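
The three assertions at the top of `merge_rank1_edit_modules` act as preconditions: every module is initialized, all share one concept output dimension, and all are of the same kind (keys or values). The sketch below shows that logic in isolation, in plain Python. `ModuleInfo`, `check_mergeable`, and `can_merge` are hypothetical names invented for illustration; they stand in for the attributes the diff reads (`initted`, `concept_outputs.shape[-1]`, `is_key_proj`) and are not part of the library. The rest of the merge function is not shown in this diff and is not reproduced here.

```python
from dataclasses import dataclass


@dataclass
class ModuleInfo:
    # Hypothetical stand-in for the three attributes the assertions read;
    # this is NOT the real Rank1EditModule.
    initted: bool
    concept_output_dim: int  # stands in for concept_outputs.shape[-1]
    is_key_proj: bool


def check_mergeable(*modules: ModuleInfo) -> None:
    """Raise AssertionError unless the modules satisfy the three preconditions."""
    # 1. every module must have been initialized
    assert all(m.initted for m in modules), 'all modules must be initialized'
    # 2. exactly one distinct concept output dimension across all modules
    assert len({m.concept_output_dim for m in modules}) == 1, \
        'concept output dimension must be the same'
    # 3. exactly one distinct kind: all keys, or all values
    assert len({m.is_key_proj for m in modules}) == 1, \
        'cannot merge key and value modules together'


def can_merge(*modules: ModuleInfo) -> bool:
    """Boolean convenience wrapper around check_mergeable."""
    try:
        check_mergeable(*modules)
    except AssertionError:
        return False
    return True


keys_a = ModuleInfo(initted=True, concept_output_dim=768, is_key_proj=True)
keys_b = ModuleInfo(initted=True, concept_output_dim=768, is_key_proj=True)
values_c = ModuleInfo(initted=True, concept_output_dim=768, is_key_proj=False)

print(can_merge(keys_a, keys_b))    # True
print(can_merge(keys_a, values_c))  # False
```

Comparing the size of a set to 1 also rejects an empty argument list, since an empty set has size 0; the same holds for the `len(set([...])) == 1` form used in the diff.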