The model "loses" or "degrades" its multimodality when you increase the size of the model #33

@F4k3r22

Description

The model "loses" or "degrades" its multimodality when you increase the size of the model. The largest configuration that still gives me multimodal samples is this:

model = Transfusion(
    num_text_tokens=256,  # Increased for text vocabulary
    dim_latent=dim_latent,
    modality_default_shape=(16, 16),  # Adjusted for 256x256 images
    modality_encoder=encoder,
    modality_decoder=decoder,
    add_pos_emb=True,
    modality_num_dim=2,
    transformer=dict(
        dim=256,  # Increased transformer dimensions
        depth=8,  # More layers
        dim_head=64,
        heads=8,
    )
).to(device)

I noticed this because I wanted to train with the following setup, but I never got multimodal output when sampling during training:

model = Transfusion(
    num_text_tokens=256,  # Increased for text vocabulary
    dim_latent=dim_latent,
    modality_default_shape=(16, 16),  # Adjusted for 256x256 images
    modality_encoder=encoder,
    modality_decoder=decoder,
    add_pos_emb=True,
    modality_num_dim=2,
    transformer=dict(
        dim=768,  # Increased transformer dimensions
        depth=12,  # More layers
        dim_head=64,
        heads=12,
    )
).to(device)
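
For reference, this is roughly how I was checking the samples. The contains_image_modality helper below is mine, not something from the repo, and I'm assuming the sample comes back as a list of segments mixing integer text-token tensors with image outputs ((type, tensor) tuples or float tensors); adjust it to whatever your version of .sample() actually returns.

import torch

def contains_image_modality(sample):
    # assumption: text segments are integer token tensors, image segments are
    # float tensors or (modality_type, tensor) tuples
    for segment in sample:
        if isinstance(segment, tuple):
            return True
        if torch.is_tensor(segment) and segment.dtype.is_floating_point:
            return True
    return False

sample = model.sample()
print('sample contains an image modality:', contains_image_modality(sample))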

Maybe it's because of transfusion_attn_mask; we could give more weight to the model's multimodal interactions to compensate for the larger configurations.

I think something like this could work:

import torch

# causal, modality and Int are the same helpers the existing transfusion_attn_mask uses

def transfusion_attn_mask(modalities: Int['b m 3']):
    modalities = modalities.long()

    def mask_mod(b, h, q_idx, kv_idx):
        # start from the standard causal mask
        causal_mask = causal(b, h, q_idx, kv_idx)

        modality_mask = torch.zeros_like(causal_mask, dtype=torch.bool)
        modality_batch = modalities[b]

        for mod_type, offset, length in modality_batch:
            current_modality_mask = modality(offset, length)(b, h, q_idx, kv_idx)

            if mod_type > 0:
                # make attention inside the modality block bidirectional by also
                # evaluating the mask with q_idx and kv_idx swapped (the
                # per-element equivalent of OR-ing the block with its transpose)
                current_modality_mask = current_modality_mask | modality(offset, length)(b, h, kv_idx, q_idx)

            modality_mask = modality_mask | current_modality_mask

        # a mask_mod can only gate attention on or off, so the extra weighting of
        # modality interactions has to be applied to the scores instead (see the
        # score_mod sketch below)
        return causal_mask | modality_mask

    return mask_mod
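
Since the mask can only turn attention on or off, the modality_attention_factor idea would have to live in a score_mod instead. This is just a sketch under my assumptions (the function name and the modality_bias knob are made up, and I haven't checked how the repo wires score mods into flex attention); I also switched the multiplicative factor to an additive bias, since the pre-softmax scores can be negative:

import torch

def transfusion_score_mod(modalities, modality_bias=1.0):
    # hypothetical: bias within-modality attention scores upward before softmax.
    # `modalities` is the same (type, offset, length) tensor that
    # transfusion_attn_mask receives; `modality_bias` is a made-up knob.
    modalities = modalities.long()

    def score_mod(score, b, h, q_idx, kv_idx):
        inside = torch.zeros_like(q_idx, dtype=torch.bool)

        for _, offset, length in modalities[b]:
            in_block = (
                (q_idx >= offset) & (q_idx < offset + length) &
                (kv_idx >= offset) & (kv_idx < offset + length)
            )
            inside = inside | in_block

        # add a positive bias to query/key pairs that fall inside the same
        # modality block, leave everything else untouched
        return torch.where(inside, score + modality_bias, score)

    return score_mod

If the attention is run through flex_attention, this would be passed as score_mod alongside the block mask built from transfusion_attn_mask; treat it as an idea rather than a drop-in patch.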

Even if this isn't the correct solution, I hope the observation helps. Have a nice day!
