Closed
Description
diffusers/src/diffusers/models/transformers/transformer_cosmos.py
Lines 188 to 193 in 42077e6
# 4. Prepare for GQA
query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device)
key = key.repeat_interleave(query_idx // key_idx, dim=3)
value = value.repeat_interleave(query_idx // value_idx, dim=3)
Suggested change, using plain Python ints instead of tensors:
# 4. Prepare for GQA
query_idx = query.size(3)
key_idx = key.size(3)
value_idx = value.size(3)
key = key.repeat_interleave(query_idx // key_idx, dim=3)
value = value.repeat_interleave(query_idx // value_idx, dim=3)
Using plain Python ints here gives a ~10% speedup in Cosmos2TextToImagePipeline and Cosmos2VideoToWorldPipeline.
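A likely reason for the gain (an assumption, not confirmed in this issue): the PyTorch docs for `repeat_interleave` note that when `repeats` is a tensor, a stream synchronization is needed to compute the output shape, and each `torch.tensor(..., device=...)` call also issues a small host-to-device copy. The rough sketch below checks that both variants produce identical results and times them on your own hardware. The function names `gqa_tensor_repeats`/`gqa_int_repeats` and the tensor shapes are hypothetical; only the `dim=3` expansion mirrors the snippet.

```python
import time

import torch


def gqa_tensor_repeats(query, key, value):
    # Original pattern: sizes wrapped in tensors on the input device.
    query_idx = torch.tensor(query.size(3), device=query.device)
    key_idx = torch.tensor(key.size(3), device=key.device)
    value_idx = torch.tensor(value.size(3), device=value.device)
    key = key.repeat_interleave(query_idx // key_idx, dim=3)
    value = value.repeat_interleave(query_idx // value_idx, dim=3)
    return key, value


def gqa_int_repeats(query, key, value):
    # Suggested pattern: Tensor.size() already returns Python ints.
    query_idx = query.size(3)
    key_idx = key.size(3)
    value_idx = value.size(3)
    key = key.repeat_interleave(query_idx // key_idx, dim=3)
    value = value.repeat_interleave(query_idx // value_idx, dim=3)
    return key, value


def time_fn(fn, *args, iters=100):
    # Average wall-clock time per call; synchronizes on CUDA so queued kernels are counted.
    device = args[0].device
    fn(*args)  # warm-up
    if device.type == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if device.type == "cuda":
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Made-up shapes: dim 3 holds 8 query entries vs 2 key/value entries.
query = torch.randn(1, 4, 256, 8, device=device)
key = torch.randn(1, 4, 256, 2, device=device)
value = torch.randn(1, 4, 256, 2, device=device)

k_old, v_old = gqa_tensor_repeats(query, key, value)
k_new, v_new = gqa_int_repeats(query, key, value)
assert torch.equal(k_old, k_new) and torch.equal(v_old, v_new)

print(f"tensor repeats: {time_fn(gqa_tensor_repeats, query, key, value) * 1e6:.1f} us")
print(f"int repeats:    {time_fn(gqa_int_repeats, query, key, value) * 1e6:.1f} us")
```

The outputs are identical, so the change only affects how the repeat factor is computed. Any timing difference will depend on the device and shapes, and on CPU it may be negligible.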