The model "loses" or "degrades" its multimodality when you increase the size of the model. The largest configuration that still keeps multimodality for me is this:
```python
model = Transfusion(
    num_text_tokens=256,                # Increased for text vocabulary
    dim_latent=dim_latent,
    modality_default_shape=(16, 16),    # Adjusted for 256x256 images
    modality_encoder=encoder,
    modality_decoder=decoder,
    add_pos_emb=True,
    modality_num_dim=2,
    transformer=dict(
        dim=256,        # Increased transformer dimensions
        depth=8,        # More layers
        dim_head=64,
        heads=8,
    )
).to(device)
```
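For reference, both configs assume an `encoder`/`decoder` pair that maps 256x256 RGB images down to a 16x16 grid of `dim_latent` channels and back. A minimal stand-in for what I mean (this exact autoencoder is illustrative, not the one I trained with; a channel-first latent like this one presumably also needs `channel_first_latent=True` on `Transfusion`, if your version exposes that flag):

```python
import torch
from torch import nn

dim_latent = 32  # illustrative choice

# 3 x 256 x 256 -> dim_latent x 16 x 16 via four stride-2 convs (256 / 2^4 = 16)
encoder = nn.Sequential(
    nn.Conv2d(3, 64, 4, stride=2, padding=1),            # -> 128 x 128
    nn.SiLU(),
    nn.Conv2d(64, 128, 4, stride=2, padding=1),          # -> 64 x 64
    nn.SiLU(),
    nn.Conv2d(128, 256, 4, stride=2, padding=1),         # -> 32 x 32
    nn.SiLU(),
    nn.Conv2d(256, dim_latent, 4, stride=2, padding=1),  # -> 16 x 16
)

# mirror image: dim_latent x 16 x 16 -> 3 x 256 x 256
decoder = nn.Sequential(
    nn.ConvTranspose2d(dim_latent, 256, 4, stride=2, padding=1),
    nn.SiLU(),
    nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
    nn.SiLU(),
    nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
    nn.SiLU(),
    nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
)
```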
I noticed this because I wanted to train with the setup below, and I could never get multimodality when sampling during my training:
```python
model = Transfusion(
    num_text_tokens=256,                # Increased for text vocabulary
    dim_latent=dim_latent,
    modality_default_shape=(16, 16),    # Adjusted for 256x256 images
    modality_encoder=encoder,
    modality_decoder=decoder,
    add_pos_emb=True,
    modality_num_dim=2,
    transformer=dict(
        dim=768,        # Increased transformer dimensions
        depth=12,       # More layers
        dim_head=64,
        heads=12,
    )
).to(device)
```
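For a sense of scale, a rough parameter estimate for the transformer trunk alone (using the standard ~12·dim²·depth approximation for the attention and 4x-MLP linear layers, ignoring embeddings and norms) shows the failing configuration is over an order of magnitude larger:

```python
# rough transformer parameter count: ~12 * dim^2 per layer
# (4 * dim^2 for attention projections + 8 * dim^2 for a 4x MLP)
def approx_params(dim, depth):
    return 12 * dim ** 2 * depth

small = approx_params(256, 8)   # ~6.3M params
large = approx_params(768, 12)  # ~84.9M params
print(f"ratio: {large / small:.1f}x")  # ratio: 13.5x
```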
Maybe it's because of the `transfusion_attn_mask`; we could give more weight to the model's multimodal interactions to compensate for the larger configurations. I think it could be done like this (with the mask kept boolean, and the actual weighting moved to a `score_mod`, sketched after the function):
```python
def transfusion_attn_mask(modalities: Int['b m 3']):
    modalities = modalities.long()

    def mask_mod(b, h, q_idx, kv_idx):
        # flex attention's mask_mod must return a boolean per (q_idx, kv_idx)
        # position, so a 1.5x "attention factor" cannot be encoded here
        mask = causal(b, h, q_idx, kv_idx)

        for mod_type, offset, length in modalities[b]:
            # modality(offset, length) constrains both q_idx and kv_idx to the
            # span, so it is already symmetric (bidirectional); transposing is
            # meaningless on the per-position values mask_mod works with
            span_mask = modality(offset, length)(b, h, q_idx, kv_idx)
            mask = mask | (span_mask & (mod_type > 0))

        return mask

    return mask_mod
```
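The 1.5x weighting itself cannot live in a boolean mask, but flex attention's `score_mod` hook can rescale attention logits per position. A minimal sketch of that idea, reusing the `modality` helper from above; the names `transfusion_score_mod` and `modality_attention_factor` are my own, and whether boosting these logits actually helps is exactly what would need testing:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def transfusion_score_mod(modalities, modality_attention_factor=1.5):
    modalities = modalities.long()

    def score_mod(score, b, h, q_idx, kv_idx):
        # flag (q, kv) pairs that fall inside any modality span
        within = q_idx < 0  # always-False boolean starting point
        for _, offset, length in modalities[b]:
            within = within | modality(offset, length)(b, h, q_idx, kv_idx)

        # boost logits for intra-modality interactions, leave the rest alone
        return torch.where(within, score * modality_attention_factor, score)

    return score_mod

# wiring (shapes illustrative): the mask gates, the score_mod weights
# block_mask = create_block_mask(transfusion_attn_mask(modalities), B=batch, H=None, Q_LEN=seq_len, KV_LEN=seq_len)
# out = flex_attention(q, k, v, score_mod=transfusion_score_mod(modalities), block_mask=block_mask)
```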
Even if this isn't the correct solution, I hope the observation helps you. Have a nice day!