@@ -247,7 +247,7 @@ def __init__(
         super().__init__()
         self.text_emb = nn.Embedding(num_text_tokens, dim_text)
         self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
-        self.text_transformer = Transformer(causal = False, seq_len = text_seq_len, dim = dim_text, depth = text_enc_depth, heads = text_heads)
+        self.text_transformer = Transformer(causal = False, seq_len = text_seq_len, dim = dim_text, depth = text_enc_depth, heads = text_heads, rotary_emb = False)
         self.to_text_latent = nn.Linear(dim_text, dim_latent, bias = False)

         assert visual_image_size % visual_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
@@ -257,7 +257,7 @@ def __init__(
         self.visual_patch_size = visual_patch_size
         self.to_visual_embedding = nn.Linear(patch_dim, dim_image)
         self.visual_pos_emb = nn.Embedding(num_patches, dim_image)
-        self.visual_transformer = Transformer(causal = False, seq_len = num_patches, dim = dim_image, depth = visual_enc_depth, heads = visual_heads)
+        self.visual_transformer = Transformer(causal = False, seq_len = num_patches, dim = dim_image, depth = visual_enc_depth, heads = visual_heads, rotary_emb = False)
         self.to_visual_latent = nn.Linear(dim_image, dim_latent, bias = False)

         self.temperature = nn.Parameter(torch.tensor(1.))
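
For context: the change above passes rotary_emb = False into both encoder Transformers, so the text and visual encoders rely only on the learned absolute position embeddings (text_pos_emb / visual_pos_emb) added before the transformer, rather than rotary position embeddings inside it. Below is a minimal sketch of the text-encoder wiring under this change; the Transformer kwargs mirror the diff, while the import path and the hyperparameter values are assumptions for illustration, not taken from the commit.

    import torch
    from torch import nn

    # Assumed import path for the repo's own Transformer class; adjust if the
    # module layout differs.
    from dalle_pytorch.transformer import Transformer

    # Illustrative hyperparameters -- not taken from the commit.
    num_text_tokens = 10000
    text_seq_len    = 256
    dim_text        = 512
    dim_latent      = 512
    text_enc_depth  = 6
    text_heads      = 8

    # Learned absolute position embeddings, added to the token embeddings before
    # the transformer; with rotary_emb = False they are the only positional signal.
    text_emb     = nn.Embedding(num_text_tokens, dim_text)
    text_pos_emb = nn.Embedding(text_seq_len, dim_text)

    # Kwargs mirror the diff: rotary embeddings inside the encoder are disabled.
    text_transformer = Transformer(
        causal     = False,
        seq_len    = text_seq_len,
        dim        = dim_text,
        depth      = text_enc_depth,
        heads      = text_heads,
        rotary_emb = False
    )

    to_text_latent = nn.Linear(dim_text, dim_latent, bias = False)

The visual encoder follows the same pattern with num_patches, dim_image, visual_enc_depth, and visual_heads in place of the text-side values.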