@@ -25,12 +25,12 @@ def exists(val):
2525@beartype
2626@torch .no_grad ()
2727def calculate_input_covariance (
28- open_clip : OpenClipAdapter ,
28+ clip : OpenClipAdapter ,
2929 texts : List [str ],
3030 batch_size = 32 ,
3131 ** cov_kwargs
3232):
33- embeds , mask = open_clip .embed_texts (texts )
33+ embeds , mask = clip .embed_texts (texts )
3434
3535 num_batches = ceil (len (texts ) / batch_size )
3636
@@ -42,13 +42,45 @@ def calculate_input_covariance(
4242 start_index = batch_ind * batch_size
4343 batch_texts = texts [start_index :(start_index + batch_size )]
4444
45- embeds , mask = open_clip .embed_texts (batch_texts )
45+ embeds , mask = clip .embed_texts (batch_texts )
4646 all_embeds .append (embeds [mask ])
4747
48- all_embeds = torch .cat (( all_embeds ) , dim = 0 )
48+ all_embeds = torch .cat (all_embeds , dim = 0 )
4949
5050 return einsum ('n d, n e -> d e' , all_embeds , all_embeds ) / length
5151
@beartype
def find_first_index(
    indices: IndicesTensor,
    concept_or_superclass_id: int
):
    """
    For deriving the concept_indices to be passed into the Rank1EditModule.

    Returns, per row of `indices`, the position of the FIRST occurrence of
    `concept_or_superclass_id` (assumes the id occurs at least once per row —
    TODO confirm with callers; if it never occurs, this returns the row length).
    """

    # mask of matches, cumsum turns everything at/after the first match truthy:
    # [1, 3, 5, 4, 1, 1], 4 -> [0, 0, 0, 1, 0, 0] -> [0, 0, 0, 1, 1, 1]
    edge = (indices == concept_or_superclass_id).cumsum(dim = -1)

    # first index = number of positions strictly BEFORE the first match.
    # (the previous `edge.sum(dim = -1)` counted positions from the first match
    # to the end, i.e. length - index, which only matched the comment's example
    # by coincidence)
    return (edge == 0).sum(dim = -1)
63+
@beartype
def return_text_enc_with_concept_and_superclass(
    text_ids: IndicesTensor,
    concept_id: int,
    superclass_id: int,
    clip: Optional[OpenClipAdapter] = None
):
    """
    Given token ids containing the concept token, build a parallel batch where
    the concept token is swapped for its superclass token.

    Returns:
        without `clip`: (text_ids, concept_indices, text_ids_with_superclass)
        with `clip`:    (concept_text_enc, concept_indices, superclass_text_enc)
    """

    batch = text_ids.shape[0]
    batch_arange = torch.arange(batch, device = text_ids.device)

    # per-row position of the first occurrence of the concept token
    concept_indices = find_first_index(text_ids, concept_id)

    # substitute the superclass id on a copy, so the caller's `text_ids` is
    # left untouched. (the previous chained assignment
    # `text_ids_with_superclass = text_ids[...] = superclass_ids`
    # referenced the undefined name `superclass_ids`, mutated `text_ids`
    # in place, and bound `text_ids_with_superclass` to a plain int)
    text_ids_with_superclass = text_ids.clone()
    text_ids_with_superclass[batch_arange, concept_indices] = superclass_id

    if not exists(clip):
        return text_ids, concept_indices, text_ids_with_superclass

    concept_text_enc = clip.embed_texts(text_ids)
    superclass_text_enc = clip.embed_texts(text_ids_with_superclass)

    return concept_text_enc, concept_indices, superclass_text_enc
83+
5284# a module that wraps the keys and values projection of the cross attentions to text encodings
5385
5486class Rank1EditModule (Module ):
0 commit comments