diff --git a/Project.toml b/Project.toml index 20984db45..4c3b7433c 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "1.16.0" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" [compat] BenchmarkTools = "0.5" @@ -13,6 +14,8 @@ Compat = "2, 3, 4" FiniteDifferences = "0.10" OffsetArrays = "1" StaticArrays = "0.11, 0.12, 1" +GPUArraysCore = "0.1" +JLArrays = "0.1" julia = "1.6" [extras] @@ -20,7 +23,15 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "BenchmarkTools", "FiniteDifferences", "OffsetArrays", "StaticArrays"] +test = [ + "Test", + "BenchmarkTools", + "FiniteDifferences", + "OffsetArrays", + "StaticArrays", + "JLArrays", +] diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index b81ab4fba..864f06289 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -3,6 +3,7 @@ using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, mate using Base.Meta using LinearAlgebra using SparseArrays: SparseVector, SparseMatrixCSC +using GPUArraysCore using Compat: hasfield, hasproperty export frule, rrule # core function diff --git a/src/projection.jl b/src/projection.jl index 811802536..6c58cc0e2 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -227,7 +227,13 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M} throw(_projection_mismatch(project.axes, size(dx))) end end - reshape(dx, project.axes) + # Reshape, copying to remove the wrapper if a GPUArray, see + # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/624 + if dx isa AbstractGPUArray + copy(reshape(dx, project.axes)) + else + reshape(dx, project.axes) + end end # Then deal with the elements. One projector if AbstractArray{<:Number}, # or one per element for arrays of anything else, including arrays of arrays: @@ -385,7 +391,6 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::AbstractArray) end end - ##### ##### `LinearAlgebra` ##### @@ -613,3 +618,46 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC) invoke(project, Tuple{AbstractArray}, dx) end end + +##### +##### `GPUArrays` +##### + +# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/624 + +# Row vectors aren't acceptable as gradients for 1-row matrices: +# Nested GPUArray wrappers lead to scalar indexing, try to prevent that: +function (project::ProjectTo{AbstractArray})( + dx::Transpose{T,A} +) where {T,A<:AbstractGPUVector} + return project(copy(reshape(vec(dx), 1, :))) +end +function (project::ProjectTo{AbstractArray})( + dx::Adjoint{T,A} +) where {T,A<:AbstractGPUVector} + return project(copy(reshape(conj(adjoint(dx)), 1, :))) +end + +# Make sure wrappers either cancel out or are materialized to maintain a maximum +# wrapper depth of 1: +AdjOrTransAbsGPUVec = Union{Adjoint{T,A},Transpose{T,A}} where {T,A<:AbstractGPUVector} +function (project::ProjectTo{Adjoint})(dx::AdjOrTransAbsGPUVec) + return adjoint(project.parent(conj(transpose(dx)))) +end +function (project::ProjectTo{Adjoint})(dx::AbstractGPUArray) + if size(dx, 1) != 1 || size(dx, 2) != length(project.parent.axes[1]) + throw(_projection_mismatch((1:1, project.parent.axes...), size(dx))) + end + dy = eltype(dx) <: Real ? copy(vec(dx)) : copy(adjoint(dx)) + return adjoint(project.parent(dy)) +end +function (project::ProjectTo{Transpose})(dx::AdjOrTransAbsGPUVec) + return transpose(project.parent(conj(adjoint(dx)))) +end +function (project::ProjectTo{Transpose})(dx::AbstractGPUArray) + if size(dx, 1) != 1 || size(dx, 2) != length(project.parent.axes[1]) + throw(_projection_mismatch((1:1, project.parent.axes...), size(dx))) + end + dy = eltype(dx) <: Number ? copy(vec(dx)) : copy(transpose(dx)) + return transpose(project.parent(dy)) +end diff --git a/test/projection.jl b/test/projection.jl index d364631fc..91511797b 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -1,6 +1,7 @@ using ChainRulesCore, Test using LinearAlgebra, SparseArrays using OffsetArrays, StaticArrays, BenchmarkTools +using JLArrays # Like ForwardDiff.jl's Dual struct Dual{T<:Real} <: Real @@ -50,7 +51,7 @@ struct NoSuperType end # real & complex @test ProjectTo(1.0 + 1im)(Dual(1.0, 2.0)) isa Complex{<:Dual} @test ProjectTo(1.0 + 1im)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa - Complex{<:Dual} + Complex{<:Dual} @test ProjectTo(1.0)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa Dual # Tangent @@ -143,7 +144,7 @@ struct NoSuperType end @test ProjectTo(Ref(true)) isa ProjectTo{NoTangent} @test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent} - + @test ProjectTo(Ref(1.0))(Ref(NoTangent())) === NoTangent() # collapse all-zero end @@ -154,7 +155,7 @@ struct NoSuperType end @test @inferred(pt1(pt1((1,)))) == pt1(pt1((1,))) # accepts correct Tangent @test @inferred(pt1(Tangent{Any}(1))) == pt1((1,)) # accepts Tangent{Any} end - @test pt1([1,]) == Tangent{Tuple{Float64}}(1.0,) # accepts Vector + @test pt1([1]) == Tangent{Tuple{Float64}}(1.0) # accepts Vector @test @inferred(pt1(NoTangent())) === NoTangent() @test @inferred(pt1(ZeroTangent())) === ZeroTangent() @test @inferred(pt1((NoTangent(),))) === NoTangent() # collapse all-zero @@ -163,7 +164,9 @@ struct NoSuperType end @test_throws Exception pt1([]) pt3 = ProjectTo(([1, 2, 3], false, :gamma)) # partly non-differentiable - @test pt3((1:3, 4, 5)) == Tangent{Tuple{Vector{Int}, Bool, Symbol}}([1.0, 2.0, 3.0], NoTangent(), NoTangent()) + @test pt3((1:3, 4, 5)) == Tangent{Tuple{Vector{Int},Bool,Symbol}}( + [1.0, 2.0, 3.0], NoTangent(), NoTangent() + ) @test ProjectTo((true, [false])) isa ProjectTo{NoTangent} end @@ -216,7 +219,7 @@ struct NoSuperType end @testset "UniformScaling" begin @test ProjectTo(I)(123) === NoTangent() @test ProjectTo(2 * I)(I * 3im) === 0.0 * I - @test ProjectTo((4 + 5im) * I)(Tangent{typeof(im * I)}(; λ = 6)) === (6.0 + 0.0im) * I + @test ProjectTo((4 + 5im) * I)(Tangent{typeof(im * I)}(; λ=6)) === (6.0 + 0.0im) * I @test ProjectTo(7 * I)(Tangent{typeof(2I)}()) == ZeroTangent() end @@ -375,29 +378,93 @@ struct NoSuperType end pvec3 = ProjectTo([1, 2, 3]) @test axes(pvec3(OffsetArray(rand(3), 0:2))) == (1:3,) @test pvec3(OffsetArray(rand(3), 0:2)) isa Vector # relies on axes === axes test - @test pvec3(OffsetArray(rand(3,1), 0:2, 0:0)) isa Vector + @test pvec3(OffsetArray(rand(3, 1), 0:2, 0:0)) isa Vector end ##### ##### `StaticArrays` ##### - @testset "StaticArrays" begin - # There is no code for this, but when argument isa StaticArray, axes(x) === axes(dx) - # implies a check, and reshape will wrap a Vector into a static SizedVector: - pstat = ProjectTo(SA[1, 2, 3]) - @test axes(pstat(rand(3))) === (SOneTo(3),) - - # This recurses into structured arrays: - pst = ProjectTo(transpose(SA[1, 2, 3])) - @test axes(pst(rand(1,3))) === (SOneTo(1), SOneTo(3)) - @test pst(rand(1,3)) isa Transpose - - # When the argument is an ordinary Array, static gradients are allowed to pass, - # like FillArrays. Collecting to an Array would cost a copy. - pvec3 = ProjectTo([1, 2, 3]) - @test pvec3(SA[1, 2, 3]) isa StaticArray + @testset "StaticArrays" begin + # There is no code for this, but when argument isa StaticArray, axes(x) === axes(dx) + # implies a check, and reshape will wrap a Vector into a static SizedVector: + pstat = ProjectTo(SA[1, 2, 3]) + @test axes(pstat(rand(3))) === (SOneTo(3),) + + # This recurses into structured arrays: + pst = ProjectTo(transpose(SA[1, 2, 3])) + @test axes(pst(rand(1, 3))) === (SOneTo(1), SOneTo(3)) + @test pst(rand(1, 3)) isa Transpose + + # When the argument is an ordinary Array, static gradients are allowed to pass, + # like FillArrays. Collecting to an Array would cost a copy. + pvec3 = ProjectTo([1, 2, 3]) + @test pvec3(SA[1, 2, 3]) isa StaticArray + end + + ##### + ##### `GPU arrays` + ##### + + # issue #624 + @testset "GPUArrays" begin + JLVector = JLArray{T,1} where {T} + JLMatrix = JLArray{T,2} where {T} + + pvec3 = ProjectTo(JLArray([1, 2, 3])) + @test pvec3(JLArray(1.0:3.0)) == JLArray(1.0:3.0) + @test pvec3(JLArray(1:3)) == JLArray(1.0:3.0) # would prefer ===, map(Float64, dx) would do that, not important + @test pvec3(JLArray([1, 2, 3 + 4im])) == JLArray(1:3) + @test eltype(pvec3(JLArray([1, 2, 3.0f0]))) === Float64 + + # reshape + @test pvec3(reshape(JLArray([1, 2, 3]), 3, 1)) isa JLVector + @test_throws DimensionMismatch pvec3(reshape(JLArray([1, 2, 3]), 1, 3)) + @test_throws DimensionMismatch pvec3(JLArray([1, 2, 3, 4])) + + pmat = ProjectTo(JLArray(rand(2, 2) .+ im)) + @test pmat(JLArray([1 2; 3 4.0+5im])') isa Adjoint # pass-through + @test pmat(JLArray([1 2; 3 4])') isa JLMatrix # broadcast type change + + pmat2 = ProjectTo(JLArray(rand(2, 2))') + @test pmat2(JLArray([1 2; 3 4.0+5im])) isa JLMatrix # adjoint matrices are not re-created + + prow = ProjectTo(JLArray([1im 2 3im])) + @test prow(transpose(JLArray([1, 2, 3 + 4.0im]))) == JLArray([1 2 3 + 4im]) + @test prow(transpose(JLArray([1, 2, 3 + 4.0im]))) isa JLMatrix # row vectors may not pass through + @test prow(adjoint(JLArray([1, 2, 3 + 5im]))) == JLArray([1 2 3 - 5im]) + @test prow(adjoint(JLArray([1, 2, 3]))) isa JLMatrix + + # some bugs + @test pvec3(JLArray(fill(NoTangent(), 3))) === NoTangent() #410, was an array of such + @test ProjectTo(JLArray([pi]))(JLArray([1])) isa JLVector{Int} #423, was Irrational -> Bool -> NoTangent + + # adjoint vectors + @testset "GPUArrays: $adj vectors" for adj in [transpose, adjoint] + padj = ProjectTo(adj(JLArray([1, 2, 3]))) + adjT = typeof(adj(JLArray([1, 2, 3.0]))) + @test padj(transpose(JLArray(1:3))) isa adjT + @test padj(JLArray([4 5 6 + 7im])) isa adjT + @test padj(JLArray([4.0 5.0 6.0])) isa adjT + + @test_throws DimensionMismatch padj(JLArray([1, 2, 3])) + @test_throws DimensionMismatch padj(JLArray([1 2 3]')) + @test_throws DimensionMismatch padj(JLArray([1 2 3 4])) + + padj_complex = ProjectTo(adj(JLArray([1, 2, 3 + 4im]))) + @test padj_complex(JLArray([4 5 6 + 7im])) == JLArray([4 5 6 + 7im]) + @test padj_complex(transpose(JLArray([4, 5, 6 + 7im]))) == + JLArray([4 5 6 + 7im]) + @test padj_complex(adjoint(JLArray([4, 5, 6 + 7im]))) == JLArray([4 5 6 - 7im]) + + # issue #410 + @test padj(JLArray([NoTangent() NoTangent() NoTangent()])) === NoTangent() + + @test ProjectTo(adj(JLArray([true, false])))(JLArray([1 2])) isa AbstractZero + @test ProjectTo(adj([JLArray([true]), JLArray([false])])) isa + ProjectTo{<:AbstractZero} end + end ##### ##### `ChainRulesCore`