Skip to content

Commit 26c9acf

Browse files
authored
Patch Flux._isleaf for abstract arrays with bitstype elements (#2436)
* patch Flux._isleaf for Transpose and Adjoint arrays of bitstype elements
* patch `Flux._isleaf(::PermutedDimsArray)` as well
* `gradient(x -> sum(cpu(x)), ca')[1]` now returns `Adjoint{Float32, <:CuArray}` instead of plain `CuArray`
1 parent 9f580d9 commit 26c9acf

File tree

3 files changed

+30
-5
lines changed

3 files changed

+30
-5
lines changed

src/functor.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,13 @@ julia> m.bias
181181
"""
182182
cpu(x) = fmap(x -> adapt(FluxCPUAdaptor(), x), x, exclude = _isleaf)
183183

184-
_isbitsarray(::AbstractArray{<:Number}) = true
185-
_isbitsarray(::AbstractArray{T}) where T = isbitstype(T)
186-
_isbitsarray(x) = false
184+
_isleaf(x) = Functors.isleaf(x)
185+
186+
_isleaf(::AbstractArray{<:Number}) = true
187+
_isleaf(::AbstractArray{T}) where T = isbitstype(T)
188+
_isleaf(::Union{Transpose, Adjoint, PermutedDimsArray}) = false
187189

188190
_isleaf(::AbstractRNG) = true
189-
_isleaf(x) = _isbitsarray(x) || Functors.isleaf(x)
190191

191192
# the order below is important
192193
const GPU_BACKENDS = ("CUDA", "AMDGPU", "Metal", "CPU")

test/ext_cuda/cuda.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ end
109109
# This test should really not go through indirections and pull out Fills for efficiency
110110
# but we forcefully materialise. TODO: remove materialising CuArray here
111111
@test gradient(x -> sum(cpu(x)), ca)[1] isa CuArray # This involves FillArray, which should be GPU compatible
112-
@test gradient(x -> sum(cpu(x)), ca')[1] isa CuArray
112+
@test gradient(x -> sum(cpu(x)), ca')[1] isa Adjoint{Float32, <:CuArray}
113113

114114
# Even more trivial: no movement
115115
@test gradient(x -> sum(abs, cpu(x)), a)[1] isa Matrix

test/utils.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,30 @@ end
567567
@test length(Flux.params(oneadj)) == 1 # needs Functors@0.3
568568

569569
@test Flux.destructure(simple)[1] == Flux.destructure(oneadj)[1] == [1, 3, 2, 4]
570+
571+
@testset "issue 2432" begin
572+
x = rand(1)
573+
m = (; a = x, b = x')
574+
count = Ref(0)
575+
mcopy = fmap(m; exclude = Flux._isleaf) do x
576+
count[] += 1
577+
return copy(x)
578+
end
579+
@test count[] == 1
580+
@test mcopy.a === mcopy.b'
581+
582+
struct BitsType
583+
x::Int32
584+
y::Float64
585+
end
586+
587+
for x in [1.0, 'a', BitsType(1, 2.0)]
588+
@test Flux._isleaf([x])
589+
@test !Flux._isleaf([x]')
590+
@test !Flux._isleaf(transpose([x]))
591+
@test !Flux._isleaf(PermutedDimsArray([x;;], (1, 2)))
592+
end
593+
end
570594
end
571595

572596
@testset "Various destructure bugs" begin

0 commit comments

Comments
 (0)