MultiHeadAttention's use_causal_mask is broken #21284

Closed
@pfekin

Description

Forward (future) token embeddings leak when the MultiHeadAttention layer is called without an explicit attention mask, relying on `use_causal_mask=True` instead.
With this setup I get ~0.99 validation accuracy on a randomly generated dataset in Colab; on random data, anything above chance-level accuracy is only possible if the model can see the future tokens it is asked to predict.
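
A minimal sketch of the kind of reproduction described, assuming a toy next-token prediction task (the `vocab_size`, `seq_len`, and model shape are illustrative choices, not from the original report). With a working causal mask, validation accuracy on random tokens should stay near chance (`1 / vocab_size`); accuracy near 0.99 indicates leakage of future positions:

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Illustrative sizes for a toy next-token prediction task.
vocab_size, seq_len, dim = 100, 32, 64

inputs = keras.Input(shape=(seq_len,), dtype="int32")
x = layers.Embedding(vocab_size, dim)(inputs)
# No explicit attention_mask is passed; we rely on use_causal_mask alone.
x = layers.MultiHeadAttention(num_heads=4, key_dim=dim)(
    x, x, use_causal_mask=True
)
outputs = layers.Dense(vocab_size, activation="softmax")(x)
model = keras.Model(inputs, outputs)
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])

# Random tokens; targets are the inputs shifted left by one position,
# so the label at step t is exactly the input token at step t + 1.
tokens = np.random.randint(0, vocab_size, size=(2048, seq_len + 1))
model.fit(tokens[:, :-1], tokens[:, 1:], validation_split=0.2, epochs=3)
```

If the causal mask were applied correctly, validation accuracy here should hover around 0.01 (chance for a 100-token vocabulary); near-perfect accuracy means each position is attending to the very token it is supposed to predict.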
