
Commit b6645fc

Add dropout & attention tests for AMDGPU (#472)
* Add dropout & attention tests for AMDGPU. Refactor code a bit and add more detailed error messages.
* Print-out AMDGPU versioninfo
* Print version
1 parent 09347ea commit b6645fc


6 files changed: +111 -20 lines changed


src/attention.jl

Lines changed: 36 additions & 17 deletions
@@ -5,7 +5,7 @@ const AA{N,T} = AbstractArray{T,N}
 """
     dot_product_attention(query, key, value, [bias]; [fdrop, mask, nheads])
 
-Multihead dot product attention used in transformer architectures.
+Multihead dot product attention used in transformer architectures.
 
 The input arrays must have the first two dimensions given by the number of features
 and the sequence length, then an arbitrary number of batch dimensions or none.
@@ -23,15 +23,15 @@ See also [`dot_product_attention_scores`](@ref) if you only need the attention s
 - `value`: Value array of size `(v_dim, kv_len, batch_size...)`.
 - `bias`: Either `nothing` or an array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
   It will be added to the attention scores before applying the softmax. Default `nothing`.
-- `fdrop`: A dropout function or layer to be applied on the attention scores right after the softmax.
-  Default `identity` (no dropout).
+- `fdrop`: A dropout function or layer to be applied on the attention scores right after the softmax.
+  Default `identity` (no dropout).
 - `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
   The mask is applied to the attention scores just before the softmax.
   See [`make_causal_mask`](@ref) for creating causal masks. Default `nothing`.
 - `nheads`: Number of heads to split the input arrays into. Default `1`.
 
 # Examples
-
+
 ```julia
 q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2)
 y, α = dot_product_attention(q, k, v)
@@ -49,13 +49,34 @@ function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}, args...; kws...) wh
     return x, α
 end
 
-function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing;
+function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing;
         fdrop=identity, mask=nothing, nheads=1)
 
-    (size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError("Batch dimensions have to be the same."))
-    size(q, 1) == size(k, 1) || throw(ArgumentError("First dimension in query and key has to be the same."))
-    size(k, 2) == size(v, 2) || throw(ArgumentError("Second dimension in key and value has to be the same."))
-
+    (all(size.((q, k, v), 1) .% nheads .== 0)) || throw(ArgumentError("""
+        First dimension in query, key and value must be divisible by `nheads`.
+        Instead:
+        - size(q): $(size(q))
+        - size(k): $(size(k))
+        - size(v): $(size(v))
+        - nheads: $nheads
+        """))
+    (size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError("""
+        Batch dimensions have to be the same. Instead:
+        - size(q): $(size(q))
+        - size(k): $(size(k))
+        - size(v): $(size(v))
+        """))
+    size(q, 1) == size(k, 1) || throw(ArgumentError("""
+        First dimension in query and key has to be the same. Instead:
+        - size(q): $(size(q))
+        - size(k): $(size(k))
+        """))
+    size(k, 2) == size(v, 2) || throw(ArgumentError("""
+        Second dimension in key and value has to be the same. Instead:
+        - size(k): $(size(k))
+        - size(v): $(size(v))
+        """))
+
     # Multihead attention. TODO create fastpath for singlehead attention.
     q, k, v = split_heads.((q, k, v), nheads)
     x, α = _dot_product_attention(q, k, v, bias, fdrop, mask)
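For illustration, a minimal sketch (not part of the diff, with made-up array sizes) of how the new, more detailed checks surface to a caller:

```julia
using NNlib

q = rand(Float32, 10, 5, 2)      # 10 features, q_len = 5, batch_size = 2
k = v = rand(Float32, 10, 7, 2)  # kv_len = 7

# 10 is not divisible by nheads = 3, so the first check above throws an
# ArgumentError that now reports size(q), size(k), size(v) and nheads.
dot_product_attention(q, k, v; nheads=3)
```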
@@ -69,7 +90,7 @@ function _dot_product_attention(q::AA4, k::AA4, v::AA4, bias, fdrop, mask)
 
     α = dot_product_attention_scores(q, k, bias; fdrop, mask)
     # [α] = [kv_len, q_len, nheads, batch_size]
-
+
     # The following permutedims and batched_mul are equivalent to
     # @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
     vt = permutedims(v, (1, 3, 2, 4))

@@ -83,12 +104,12 @@ end
     dot_product_attention_scores(query, key, [bias]; [fdrop, mask])
 
 Return the attention scores for the [`dot_product_attention`](@ref).
-Input arrays must have dimensions
+Input arrays must have dimensions
 `(num_features ÷ nheads, nheads, sequence_length, batch_size)`.
 
 See [`dot_product_attention`](@ref) for more details.
 """
-function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing;
+function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing;
         fdrop=identity, mask=nothing) where T
 
     # The following permutedims and batched_mul are equivalent to

@@ -100,7 +121,7 @@ function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing;
 
     logits = apply_attn_bias(logits, bias)
     logits = apply_attn_mask(logits, mask)
-
+
     α = softmax(logits, dims=1)
     return fdrop(α)
 end

@@ -109,7 +130,6 @@ apply_attn_bias(logits, bias::Nothing) = logits
 
 apply_attn_bias(logits, bias) = logits .+ bias
 
-
 apply_attn_mask(logits, mask::Nothing) = logits
 
 function apply_attn_mask(logits, mask)

@@ -118,11 +138,11 @@ function apply_attn_mask(logits, mask)
 end
 
 
-"""
+"""
     make_causal_mask(x, dims=2)
 
 Return a boolean square matrix `m` of the same type as `x` and of side `size(x, dims)`.
-Its elements are set such that `m[i, j] == i ≤ j`.
+Its elements are set such that `m[i, j] == i ≤ j`.
 
 Can be used to mask the attention scores in [`dot_product_attention`](@ref).
 """
@@ -141,4 +161,3 @@ join_heads(x) = reshape(x, :, size(x)[3:end]...)
 @non_differentiable make_causal_mask(::Any...)
 @non_differentiable trues_like(::Any...)
 @non_differentiable falses_like(::Any...)
-

test/amd/attention.jl

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+@testset "Compare CPU & GPU" begin
+    n = 15
+    lenq = 3
+    lenkv = 4
+    for batch_size in [(), 1, 2, (2, 1, 3)], nheads in [1, 3, 5]
+        q = AMDGPU.rand(Float32, n, lenq, batch_size...)
+        k = AMDGPU.rand(Float32, n, lenkv, batch_size...)
+        v = AMDGPU.rand(Float32, n, lenkv, batch_size...)
+        y, α = @inferred dot_product_attention(q, k, v; nheads)
+
+        @test y isa ROCArray{Float32}
+        @test size(y) == (n, lenq, batch_size...)
+        @test size(α) == (lenkv, lenq, nheads, batch_size...)
+        @test sum(Array(α), dims=1) ≈ ones(1, lenq, nheads, batch_size...)
+
+        qh = rand(Float32, n, lenq, batch_size...)
+        kh = rand(Float32, n, lenkv, batch_size...)
+        vh = rand(Float32, n, lenkv, batch_size...)
+        gputest(
+            (x...) -> dot_product_attention(x...; nheads)[1], qh, kh, vh;
+            atol=1f-5)
+    end
+end
+
+@testset "Mask" begin
+    x = AMDGPU.rand(Float32, 4, 2, 3, 1)
+    mask = make_causal_mask(x, dims=3)
+    @test mask isa ROCArray{Bool}
+    α = dot_product_attention_scores(x, x; mask)
+
+    α_host, mask_host = Array.((α, mask))
+    @test all((α_host[:, :, 1, 1] .> 0) .== mask_host)
+    @test all((α_host[:, :, 2, 1] .> 0) .== mask_host)
+end
+
+@testset "Dropout" begin
+    q = k = v = AMDGPU.rand(Float32, 10, 10, 10)
+    fdrop(x, p) = (rand!(similar(x)) .> p) .* x ./ (1-p)
+    y, α = dot_product_attention(
+        q, k, v; nheads=2, fdrop=x -> dropout(x, 0.5))
+    @test 0.6 > mean(>(0), α) > 0.4
+end
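The `gputest` helper above comes from the repo's test utilities; a hypothetical standalone sketch of the kind of CPU-vs-ROCm comparison it is assumed to perform (requires a functional AMDGPU setup) could look like:

```julia
using AMDGPU, NNlib, Test

qh = rand(Float32, 10, 3, 2)       # host query: (features, q_len, batch)
kh = vh = rand(Float32, 10, 4, 2)  # host key/value: (features, kv_len, batch)

y_cpu, _ = dot_product_attention(qh, kh, vh; nheads=2)
y_gpu, _ = dot_product_attention(ROCArray(qh), ROCArray(kh), ROCArray(vh); nheads=2)

# The two backends should agree up to floating-point noise.
@test Array(y_gpu) ≈ y_cpu atol=1f-5
```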

test/amd/dropout.jl

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+@testset "Test API" begin
+    x = AMDGPU.randn(Float32, 3, 4)
+    @test size(@inferred dropout(x, 0.1)) == (3, 4)
+    @test size(@inferred dropout(x, 0.2; dims=2)) == (3, 4)
+    @test size(@inferred dropout(x, 0.3; dims=(1, 2))) == (3, 4)
+
+    rng = AMDGPU.rocRAND.default_rng()
+    @test size(@inferred dropout(rng, x, 0.1)) == (3, 4)
+    @test size(@inferred dropout(rng, x, 0.1; dims=2)) == (3, 4)
+
+    # Values
+    d45 = dropout(AMDGPU.ones(100, 100, 100), 0.45)
+    @test mean(d45) ≈ 1 atol=1e-2
+    dpi2 = dropout(AMDGPU.fill(1f0 * pi, 1000), 0.2)
+    @test sort(unique(Array(dpi2))) ≈ [0, 5 * pi / 4]
+    d33 = dropout(AMDGPU.fill(3f0, 10, 1000), 0.3, dims=2)
+    @test sort(unique(vec(Array(d33)))) ≈ [0, 3 / (1 - 0.3)]
+
+    @test Zygote.gradient(x -> sum(dropout(x, 0.1)), x)[1] isa ROCArray{Float32}
+end
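A short note on why the value checks above expect exactly `5 * pi / 4` and `3 / (1 - 0.3)`: `dropout` uses inverted scaling, so entries that survive are divided by `1 - p` while the rest become zero, keeping the mean roughly unchanged. A sketch of the arithmetic:

```julia
p = 0.2
pi / (1 - p) ≈ 5 * pi / 4  # kept π entries are scaled to 1.25π ≈ 3.927

p = 0.3
3 / (1 - p)                # ≈ 4.286, the only nonzero value expected in d33
```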

test/amd/runtests.jl

Lines changed: 8 additions & 0 deletions
@@ -52,3 +52,11 @@ end
 @testset "Activations" begin
     include("activations.jl")
 end
+
+@testset "Dropout" begin
+    include("dropout.jl")
+end
+
+@testset "Attention" begin
+    include("attention.jl")
+end

test/attention.jl

Lines changed: 3 additions & 3 deletions
@@ -36,7 +36,7 @@ end
 @testset "mask" begin
     q = rand(4, 2, 3, 1)
     k = rand(4, 2, 5, 1)
-
+
     mask = rand(Bool, (5, 3))
     α = dot_product_attention_scores(q, k; mask)
     @test all((α[:,:,1,1].> 0) .== mask)

@@ -53,7 +53,7 @@ end
 
 @testset "dropout" begin
     q = k = v = rand(10, 10, 10)
-    fdrop(x, p) = (rand!(similar(x)) .> p) .* x ./ (1-p)
+    fdrop(x, p) = (rand!(similar(x)) .> p) .* x ./ (1-p)
     y, α = dot_product_attention(q, k, v; nheads=2, fdrop = x -> fdrop(x, 0.5))
     @test 0.6 > mean(>(0), α) > 0.4
 end

@@ -63,7 +63,7 @@ end
     k = v = rand(4, 3, 1)
     bias = randn(3, 5)
     y, α = dot_product_attention(q, k, v, bias; nheads=2)
-    @test size(α) == (3, 5, 2, 1)
+    @test size(α) == (3, 5, 2, 1)
     @test size(y) == (4, 5, 1)
 end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,8 @@ include("test_utils.jl")
 if get(ENV, "NNLIB_TEST_AMDGPU", "false") == "true"
     using AMDGPU
     if AMDGPU.functional() && AMDGPU.functional(:MIOpen)
+        AMDGPU.versioninfo()
+        @show AMDGPU.MIOpen.version()
         @testset "AMDGPU" begin
             include("amd/runtests.jl")
         end
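For reference, a hypothetical way to run this guarded test group locally (assumes the NNlib repo is the active project and a functional ROCm + AMDGPU.jl install; the AMDGPU suite is skipped otherwise):

```julia
# Enable the AMDGPU tests before running the suite.
ENV["NNLIB_TEST_AMDGPU"] = "true"
using Pkg
Pkg.test()
```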
