Description
Hi everyone,
I am trying to better understand how the Keras MultiHeadAttention
layer handles the dimensions internally.
Suppose I feed a tensor of shape (32, 24, 21), meaning (batch_size, time_steps, features), into the MultiHeadAttention
layer, and I set the number of heads to 8.
Keras correctly outputs a tensor of shape (32, 24, 21), matching my input dimensions, but I'm confused about the internal dimension handling.
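For concreteness, here is a minimal sketch of the setup I'm describing (key_dim=16 is just an arbitrary value I picked for illustration; the question doesn't depend on it):

```python
import numpy as np
from tensorflow import keras

# Dummy batch: (batch_size=32, time_steps=24, features=21)
x = np.random.rand(32, 24, 21).astype("float32")

# 21 features with 8 heads is not an even split, which is what puzzles me.
mha = keras.layers.MultiHeadAttention(num_heads=8, key_dim=16)

# Self-attention: query = value = key = x
out = mha(query=x, value=x, key=x)
print(out.shape)  # (32, 24, 21) -- same shape as the input
```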
My understanding is:
- If the input is (batch_size=2, time_steps=3, features=4) and we use num_heads=2,
- Then after the linear projection, the queries (Q) will still be shaped (2, 3, 4),
- Then the heads are separated by reshaping to (2, 3, 2, 2), i.e. (batch, time_steps, num_heads, depth) with depth = 4 / 2 = 2,
- After transposing the head and time axes: (2, 3, 2, 2) → (2, 2, 3, 2),
- Then attention scores (QKᵀ) will be (2, 2, 3, 3),
- After applying softmax and multiplying by V, the output per head is (2, 2, 3, 2),
- After merging the heads, we get back to (2, 3, 4) by concatenating them along the last axis (the NumPy sketch right after this list walks through exactly these shapes).
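To check that I have the bookkeeping right, here is that walkthrough in plain NumPy, with random tensors standing in for the already-projected Q, K and V (this is just my mental model, not Keras's actual implementation):

```python
import numpy as np

batch, seq, features, heads = 2, 3, 4, 2
depth = features // heads  # 2 per head in this toy example

# Stand-ins for the projected Q, K, V: each (2, 3, 4)
q, k, v = (np.random.rand(batch, seq, features) for _ in range(3))

def split_heads(t):
    # (batch, seq, features) -> (batch, seq, heads, depth) -> (batch, heads, seq, depth)
    return t.reshape(batch, seq, heads, depth).transpose(0, 2, 1, 3)

qh, kh, vh = (split_heads(t) for t in (q, k, v))          # each (2, 2, 3, 2)

scores = qh @ kh.transpose(0, 1, 3, 2) / np.sqrt(depth)   # (2, 2, 3, 3)
weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)  # softmax over keys
per_head = weights @ vh                                    # (2, 2, 3, 2)

# Merge heads: back to (2, 3, 4)
merged = per_head.transpose(0, 2, 1, 3).reshape(batch, seq, features)
print(merged.shape)  # (2, 3, 4)
```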
My confusion:
In my case, features=21 and num_heads=8.
Since 21 is not divisible by 8, how does Keras handle this? In the standard multi-head attention formulation, the feature dimension must be divisible by the number of heads (i.e., key_dim * num_heads = features).
So how does Keras map the 21 features into multiple heads internally, and how does it correctly recover the (32, 24, 21) output shape?
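In case it's useful, this is how I've been poking at the layer's internals: printing the built weight shapes (again with my arbitrary key_dim=16) shows how the projections are sized, but I still don't see how the 21 features end up split across 8 heads:

```python
import numpy as np
from tensorflow import keras

x = np.random.rand(32, 24, 21).astype("float32")
mha = keras.layers.MultiHeadAttention(num_heads=8, key_dim=16)
_ = mha(x, x)  # call the layer once so its weights get built

# Print every trainable weight's shape: the Q/K/V projection kernels
# and the final output projection.
for w in mha.weights:
    print(w.name, w.shape)
```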
Would love a clarification on this!