@@ -5,7 +5,7 @@ const AA{N,T} = AbstractArray{T,N}
"""
    dot_product_attention(query, key, value, [bias]; [fdrop, mask, nheads])

- Multihead dot product attention used in transformer architectures.
+ Multihead dot product attention used in transformer architectures.

The input arrays must have the first two dimensions given by the number of features
and the sequence length, then an arbitrary number of batch dimensions or none.
@@ -23,15 +23,15 @@ See also [`dot_product_attention_scores`](@ref) if you only need the attention s
- `value`: Value array of size `(v_dim, kv_len, batch_size...)`.
- `bias`: Either `nothing` or an array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
  It will be added to the attention scores before applying the softmax. Default `nothing`.
- - `fdrop`: A dropout function or layer to be applied on the attention scores right after the softmax.
-   Default `identity` (no dropout).
+ - `fdrop`: A dropout function or layer to be applied on the attention scores right after the softmax.
+   Default `identity` (no dropout).
- `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
  The mask is applied to the attention scores just before the softmax.
  See [`make_causal_mask`](@ref) for creating causal masks. Default `nothing`.
- `nheads`: Number of heads to split the input arrays into. Default `1`.

# Examples
-
+
```julia
q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2)
y, α = dot_product_attention(q, k, v)
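
A hedged sketch extending the example above with the `nheads` and `mask` keywords from the argument list; the sizes are made up, and the expected shapes follow from the size conventions stated in the docstring.

```julia
using NNlib

q, k, v = rand(12, 20, 2), rand(12, 20, 2), rand(12, 20, 2)   # (features, seq_len, batch)
mask = make_causal_mask(q)                                    # 20×20 Bool, m[i, j] == i ≤ j
y, α = dot_product_attention(q, k, v; nheads=4, mask)

size(y)   # (12, 20, 2): (v_dim, q_len, batch_size)
size(α)   # (20, 20, 4, 2): (kv_len, q_len, nheads, batch_size)
```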
@@ -49,13 +49,34 @@ function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}, args...; kws...) wh
    return x, α
end

- function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing;
+ function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing;
        fdrop=identity, mask=nothing, nheads=1)

-     (size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError("Batch dimensions have to be the same."))
-     size(q, 1) == size(k, 1) || throw(ArgumentError("First dimension in query and key has to be the same."))
-     size(k, 2) == size(v, 2) || throw(ArgumentError("Second dimension in key and value has to be the same."))
-
+     (all(size.((q, k, v), 1) .% nheads .== 0)) || throw(ArgumentError("""
+         First dimension in query, key and value must be divisible by `nheads`.
+         Instead:
+         - size(q): $(size(q))
+         - size(k): $(size(k))
+         - size(v): $(size(v))
+         - nheads: $nheads
+         """))
+     (size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError("""
+         Batch dimensions have to be the same. Instead:
+         - size(q): $(size(q))
+         - size(k): $(size(k))
+         - size(v): $(size(v))
+         """))
+     size(q, 1) == size(k, 1) || throw(ArgumentError("""
+         First dimension in query and key has to be the same. Instead:
+         - size(q): $(size(q))
+         - size(k): $(size(k))
+         """))
+     size(k, 2) == size(v, 2) || throw(ArgumentError("""
+         Second dimension in key and value has to be the same. Instead:
+         - size(k): $(size(k))
+         - size(v): $(size(v))
+         """))
+
    # Multihead attention. TODO create fastpath for singlehead attention.
    q, k, v = split_heads.((q, k, v), nheads)
    x, α = _dot_product_attention(q, k, v, bias, fdrop, mask)
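
The divisibility check above fires before `split_heads` is ever reached; a hedged sketch of how the new error surfaces, with sizes and an `nheads` value made up for illustration.

```julia
using NNlib

q = k = rand(10, 8, 1)    # qk_dim = 10, q_len = kv_len = 8, batch = 1
v = rand(6, 8, 1)         # v_dim = 6

try
    dot_product_attention(q, k, v; nheads=3)    # neither 10 nor 6 is divisible by 3
catch err
    err isa ArgumentError && print(err.msg)     # reports size(q), size(k), size(v) and nheads
end
```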
@@ -69,7 +90,7 @@ function _dot_product_attention(q::AA4, k::AA4, v::AA4, bias, fdrop, mask)

    α = dot_product_attention_scores(q, k, bias; fdrop, mask)
    # [α] = [kv_len, q_len, nheads, batch_size]
-
+
    # The following permutedims and batched_mul are equivalent to
    # @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
    vt = permutedims(v, (1, 3, 2, 4))
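
The `@tullio` line in the comment is documentation only; here is a plain-Julia check of that contraction, written with loops and one matrix product per head and batch, under made-up sizes and independent of the library's `batched_mul` path.

```julia
# Loop and per-slice matrix-product versions of
# x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]  (summed over j); sizes are made up.
d, h, jlen, ilen, b = 4, 2, 5, 3, 1
v = rand(d, h, jlen, b)          # [v] = [d, h, j, b]
α = rand(jlen, ilen, h, b)       # [α] = [j, i, h, b]

x_loop = zeros(d, h, ilen, b)
for bi in 1:b, hi in 1:h, ii in 1:ilen, ji in 1:jlen, di in 1:d
    x_loop[di, hi, ii, bi] += α[ji, ii, hi, bi] * v[di, hi, ji, bi]
end

x_mat = similar(x_loop)
for bi in 1:b, hi in 1:h
    x_mat[:, hi, :, bi] = v[:, hi, :, bi] * α[:, :, hi, bi]   # (d, j) * (j, i) -> (d, i)
end

x_loop ≈ x_mat   # true
```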
@@ -83,12 +104,12 @@
    dot_product_attention_scores(query, key, [bias]; [fdrop, mask])

Return the attention scores for the [`dot_product_attention`](@ref).
- Input arrays must have dimensions
+ Input arrays must have dimensions
`(num_features ÷ nheads, nheads, sequence_length, batch_size)`.

See [`dot_product_attention`](@ref) for more details.
"""
- function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing;
+ function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing;
        fdrop=identity, mask=nothing) where T

    # The following permutedims and batched_mul are equivalent to
@@ -100,7 +121,7 @@ function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing;

    logits = apply_attn_bias(logits, bias)
    logits = apply_attn_mask(logits, mask)
-
+
    α = softmax(logits, dims=1)
    return fdrop(α)
end
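
A hedged usage sketch for `dot_product_attention_scores`: the inputs are already split into heads, with the `(num_features ÷ nheads, nheads, sequence_length, batch_size)` layout stated in the docstring; the sizes here are illustrative.

```julia
using NNlib

q = rand(Float32, 8, 2, 3, 1)    # (qk_dim ÷ nheads, nheads, q_len, batch_size)
k = rand(Float32, 8, 2, 5, 1)    # (qk_dim ÷ nheads, nheads, kv_len, batch_size)

α = dot_product_attention_scores(q, k)
size(α)                          # (5, 3, 2, 1) = (kv_len, q_len, nheads, batch_size)
all(sum(α, dims=1) .≈ 1)         # scores are a softmax over the kv dimension
```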
@@ -109,7 +130,6 @@ apply_attn_bias(logits, bias::Nothing) = logits

apply_attn_bias(logits, bias) = logits .+ bias

-
apply_attn_mask(logits, mask::Nothing) = logits

function apply_attn_mask(logits, mask)
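
The body of `apply_attn_mask` lies outside this hunk; a hedged sketch of the effect it has on the scores (masked positions pushed to `-Inf` so that the subsequent softmax assigns them zero weight; an illustration, not the actual implementation).

```julia
using NNlib: softmax

logits = randn(Float32, 4, 4)
mask = [i <= j for i in 1:4, j in 1:4]             # same pattern as make_causal_mask
masked = ifelse.(mask, logits, typemin(Float32))   # typemin(Float32) == -Inf32
α = softmax(masked, dims=1)
all(α[.!mask] .== 0)                               # masked positions get zero weight
```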
@@ -118,11 +138,11 @@ function apply_attn_mask(logits, mask)
end


- """
+ """
    make_causal_mask(x, dims=2)

Return a boolean square matrix `m` of the same type as `x` and of side `size(x, dims)`.
- Its elements are set such that `m[i, j] == i ≤ j`.
+ Its elements are set such that `m[i, j] == i ≤ j`.

Can be used to mask the attention scores in [`dot_product_attention`](@ref).
"""
@@ -141,4 +161,3 @@ join_heads(x) = reshape(x, :, size(x)[3:end]...)
@non_differentiable make_causal_mask(::Any...)
@non_differentiable trues_like(::Any...)
@non_differentiable falses_like(::Any...)
-