
How does Keras MultiHeadAttention handle input feature dimension vs heads and reshape the output? #21219

Open
@SyedHasnat

Description

Hi everyone,

I am trying to better understand how the Keras MultiHeadAttention layer handles the dimensions internally.

Suppose I input a tensor of shape (32, 24, 21) meaning (batch_size, time_steps, features) into the MultiHeadAttention layer, and I set the number of heads to 8.
Keras correctly outputs a tensor of shape (32, 24, 21), matching my input dimensions, but I'm confused about the internal dimension handling.
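The observation above can be reproduced with a short self-attention call (key_dim=21 is an assumed choice here, since the original post does not state it; it is a hyperparameter independent of the 21 input features):

```python
import numpy as np
import keras

# key_dim is a required constructor argument; the value 21 below is an
# assumption for illustration — it need not equal the feature dimension.
layer = keras.layers.MultiHeadAttention(num_heads=8, key_dim=21)

x = np.random.rand(32, 24, 21).astype("float32")  # (batch, time, features)
out = layer(query=x, value=x)                     # self-attention
print(out.shape)                                  # matches the query's shape
```

The output's last dimension matches the query's last dimension because the layer ends with an output projection back to the query's feature size.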

My understanding is:

  • If the input is (batch_size=2, time_steps=3, features=4) and we use num_heads=2,
  • Then after the linear projection, the queries (Q) have shape (2, 3, 4),
  • Then split into heads by reshaping to (2, 3, 2, 2), i.e. (batch, time, heads, head_dim),
  • After transposing: (2, 3, 2, 2) → (2, 2, 3, 2), i.e. (batch, heads, time, head_dim),
  • Then the attention scores (QKᵀ) per head have shape (2, 2, 3, 3),
  • After applying softmax and multiplying by V, the output per head is (2, 2, 3, 2),
  • After merging (concatenating) the heads, we get back to (2, 3, 4).
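The split/merge steps above can be sketched in NumPy (random tensors stand in for the projected Q, K, V; this is the textbook head-splitting scheme, not Keras source code):

```python
import numpy as np

batch, time, features, num_heads = 2, 3, 4, 2
head_dim = features // num_heads  # 2, assuming divisibility

q = np.random.rand(batch, time, features)           # (2, 3, 4) after the Q projection
q = q.reshape(batch, time, num_heads, head_dim)     # (2, 3, 2, 2) split features into heads
q = q.transpose(0, 2, 1, 3)                         # (2, 2, 3, 2) bring heads forward

# K and V would be produced the same way; use random stand-ins here.
k = np.random.rand(batch, num_heads, time, head_dim)
v = np.random.rand(batch, num_heads, time, head_dim)

scores = q @ k.transpose(0, 1, 3, 2)                # (2, 2, 3, 3) per-head QK^T
weights = np.exp(scores) / np.exp(scores).sum(-1, keepdims=True)  # softmax over keys
out = weights @ v                                   # (2, 2, 3, 2) per-head output

merged = out.transpose(0, 2, 1, 3).reshape(batch, time, features)  # (2, 3, 4)
print(merged.shape)
```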

My confusion:
In my case, features=21 and num_heads=8.
Since 21 is not divisible by 8, how does Keras handle this? In the standard Transformer formulation, the feature dimension must be divisible by the number of heads (i.e., key_dim * num_heads = features).
So how does Keras map the 21 features onto multiple heads internally, and how does it correctly recover the (32, 24, 21) output shape?
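For context, Keras does not slice the feature dimension at all: each head gets its own key_dim-sized projection of the full feature vector, and a final output projection maps (num_heads × key_dim) back to the query's feature size. A minimal NumPy sketch of that scheme (key_dim=4 is an assumed value; random kernels stand in for the layer's learned EinsumDense weights):

```python
import numpy as np

batch, time, features = 32, 24, 21
num_heads, key_dim = 8, 4   # key_dim is a free hyperparameter (assumed value)

x = np.random.rand(batch, time, features)

# Per-head projection kernels of shape (features, num_heads, key_dim):
# features need not be divisible by num_heads, since every head projects
# the full 21-dim feature vector down to key_dim.
wq = np.random.rand(features, num_heads, key_dim)
wk = np.random.rand(features, num_heads, key_dim)
wv = np.random.rand(features, num_heads, key_dim)
# Output kernel maps (num_heads, key_dim) back to the query's feature dim.
wo = np.random.rand(num_heads, key_dim, features)

q = np.einsum("btf,fhk->bhtk", x, wq)   # (32, 8, 24, 4)
k = np.einsum("btf,fhk->bhtk", x, wk)
v = np.einsum("btf,fhk->bhtk", x, wv)

scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(key_dim)  # (32, 8, 24, 24)
weights = np.exp(scores) / np.exp(scores).sum(-1, keepdims=True)
heads = weights @ v                                      # (32, 8, 24, 4)

out = np.einsum("bhtk,hkf->btf", heads, wo)              # (32, 24, 21)
print(out.shape)
```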

Would love a clarification on this!
