How does Keras MultiHeadAttention handle input feature dimension vs heads and reshape the output? #21219
Labels
type:support
Hi everyone,
I am trying to better understand how the Keras `MultiHeadAttention` layer handles the dimensions internally. Suppose I feed a tensor of shape (32, 24, 21), meaning (batch_size, time_steps, features), into the `MultiHeadAttention` layer, and I set the number of heads to 8. Keras correctly outputs a tensor of shape (32, 24, 21), matching my input dimensions, but I'm confused about the internal dimension handling.
My understanding is that the feature dimension is normally split evenly across the heads, e.g. with `num_heads=2` each head would work on features / 2 dimensions.

My confusion:
In my case, features=21 and heads=8. Since 21 is not divisible by 8, how is Keras handling this? Normally, the feature dimension must be divisible by the number of heads (i.e., `key_dim * num_heads = features`). So how does Keras map the 21 features into multiple heads internally, and how does it correctly recover the (32, 24, 21) output shape?
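To make the question concrete, this is how I have been poking at the layer's internals (Keras 3; my reading of the weight shapes is in the comments and may well be wrong, which is exactly what I'd like clarified):

```python
import numpy as np
import keras

x = np.random.normal(size=(32, 24, 21)).astype("float32")
mha = keras.layers.MultiHeadAttention(num_heads=8, key_dim=21)
_ = mha(x, x)  # call once so the layer builds its weights

for w in mha.weights:
    print(w.path, w.shape)

# With key_dim=21 and num_heads=8 I expect roughly these kernels:
#   query/key/value: (21, 8, 21) -> 21 input features projected to 8 heads x key_dim
#   output:          (8, 21, 21) -> heads combined and projected back to 21 features
# plus the corresponding biases, so the divisibility of 21 by 8 never seems to matter.
```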
Would love a clarification on this!