
Commit be4f634

add an open clip wrapper that reduces even more work
1 parent 2a6b92c commit be4f634

File tree: 4 files changed, +65 −1 lines changed


README.md

Lines changed: 23 additions & 0 deletions
@@ -103,6 +103,29 @@ embeds_with_new_concept, embeds_with_superclass, embed_mask, concept_indices = w
 # the embed_mask needs to be passed to the cross attention as key padding mask
 ```
 
+If you can identify the `CLIP` instance within the stable diffusion instance, you can also pass it directly to the `OpenClipEmbedWrapper` to gain everything you need on forward for the cross attention layers.
+
+ex.
+
+```python
+from perfusion_pytorch import OpenClipEmbedWrapper
+
+texts = [
+    'a portrait of dog',
+    'dog running through a green field',
+    'a man walking his dog'
+]
+
+wrapped_clip_with_new_concept = OpenClipEmbedWrapper(
+    text_encoder.clip,
+    superclass_string = 'dog'
+)
+
+enc, superclass_enc, mask, indices = wrapped_clip_with_new_concept(texts)
+
+# (3, 77, 512), (3, 77, 512), (3, 77), (3,)
+```
+
 ## Todo
 
 - [ ] wire up with SD 1.5, starting with xiao's dreambooth-sd
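
For orientation, here is a minimal sketch of how the wrapper's outputs could feed a cross attention layer, using torch's `nn.MultiheadAttention` as a stand-in for the diffusion model's cross attention. The tensors and the mask convention (`True` at valid token positions) are assumptions for illustration, not the library's API:

```python
import torch
from torch import nn

# stand-ins with the shapes from the example above (random values, hypothetical)
enc = torch.randn(3, 77, 512)       # text encodings from the wrapper
mask = torch.ones(3, 77).bool()     # assumed True at valid token positions

latents = torch.randn(3, 64, 512)   # hypothetical image latent tokens (batch, seq, dim)

cross_attn = nn.MultiheadAttention(embed_dim = 512, num_heads = 8, batch_first = True)

# nn.MultiheadAttention ignores key positions where key_padding_mask is True,
# so the assumed validity mask is inverted before being passed in
attended, _ = cross_attn(latents, enc, enc, key_padding_mask = ~mask)
```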

perfusion_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
 
 from perfusion_pytorch.embedding import (
     EmbeddingWrapper,
+    OpenClipEmbedWrapper,
     merge_embedding_wrappers
 )

perfusion_pytorch/embedding.py

Lines changed: 40 additions & 0 deletions
@@ -221,6 +221,46 @@ def forward(
 
     return EmbeddingReturn(embeds, superclass_embeds, embed_mask, concept_indices)
 
+# a wrapper for clip
+# that automatically wraps the token embedding with the new concept
+# and on forward, passes the concept embeddings + superclass concept embeddings through the text transformer + final layernorm
+# as well as makes the forward pass the ids and superclass_ids through the modified text encoder twice (will attempt to substitute the nn.Embedding with an nn.Identity)
+
+from open_clip import CLIP
+
+class OpenClipEmbedWrapper(Module):
+    @beartype
+    def __init__(
+        self,
+        clip: CLIP,
+        **embedding_wrapper_kwargs
+    ):
+        super().__init__()
+        self.wrapped_embed = EmbeddingWrapper(clip.token_embedding, **embedding_wrapper_kwargs)
+
+        self.text_transformer = nn.Sequential(
+            clip.transformer,
+            clip.ln_final
+        )
+
+    def forward(
+        self,
+        x,
+        **kwargs
+    ) -> EmbeddingReturn:
+        text_embeds, superclass_text_embeds, text_mask, concept_indices = self.wrapped_embed(x, **kwargs)
+
+        text_enc = self.text_transformer(text_embeds)
+
+        superclass_text_enc = None
+
+        if exists(superclass_text_embeds):
+            superclass_text_enc = self.text_transformer(superclass_text_embeds)
+
+        return EmbeddingReturn(text_enc, superclass_text_enc, text_mask, concept_indices)
+
+# merging multiple embedding wrappers (with one concept each) into a merged embedding wrapper with multiple concepts
+
 @beartype
 def merge_embedding_wrappers(
     *embeds: EmbeddingWrapper
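
The comment above mentions substituting the `nn.Embedding` with an `nn.Identity` so that precomputed (and modified) embeddings can be pushed through the rest of a text encoder that normally expects token ids. A toy sketch of that trick on a made-up encoder (not the library's actual code) looks like this:

```python
import torch
from torch import nn

# a made-up text encoder that maps token ids -> embeddings -> transformer
class ToyTextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(49408, 512)
        layer = nn.TransformerEncoderLayer(d_model = 512, nhead = 8, batch_first = True)
        self.transformer = nn.TransformerEncoder(layer, num_layers = 2)

    def forward(self, x):
        return self.transformer(self.token_embedding(x))

encoder = ToyTextEncoder()

# swap the embedding for an identity, so already-computed concept embeddings
# pass through the remaining layers untouched
encoder.token_embedding = nn.Identity()

precomputed_embeds = torch.randn(3, 77, 512)
out = encoder(precomputed_embeds)   # (3, 77, 512)
```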

setup.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
   install_requires=[
     'beartype',
     'einops>=0.6.1',
-    'open-clip-torch>=2.0.0,<3.0.0',
+    'open-clip-torch',
     'opt-einsum',
     'torch>=2.0'
   ],
