diff --git a/ext/ChainRulesCoreExt.jl b/ext/ChainRulesCoreExt.jl index ba5b46c5ca..347648fc9c 100644 --- a/ext/ChainRulesCoreExt.jl +++ b/ext/ChainRulesCoreExt.jl @@ -2,7 +2,7 @@ module ChainRulesCoreExt -using CUDA: CuArray +using CUDA: CuArray, CUDA isdefined(Base, :get_extension) ? (import ChainRulesCore) : (import ..ChainRulesCore) @@ -10,4 +10,7 @@ isdefined(Base, :get_extension) ? (import ChainRulesCore) : (import ..ChainRules ChainRulesCore.is_inplaceable_destination(::CuArray) = true +# allow usage of rand with Zygote +ChainRulesCore.@non_differentiable CUDA.randn(::Any...) + end diff --git a/test/Project.toml b/test/Project.toml index 5d6ea83e88..81ef845dfd 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -10,6 +10,7 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" @@ -25,3 +26,4 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/extensions/zygote.jl b/test/extensions/zygote.jl new file mode 100644 index 0000000000..8387051a28 --- /dev/null +++ b/test/extensions/zygote.jl @@ -0,0 +1,19 @@ +using GPUArraysCore: GPUArraysCore +using CUDA +using Zygote + +function call_rand(v::AbstractVector{T}) where {T} + randn(T, 4,4) * v[1:4] +end +function call_rand(v::GPUArraysCore.AbstractGPUVector{T}) where {T} + CUDA.randn(T, 4,4) * v[1:4] +end + +@testset "randn" begin + v_orig = collect(1.0f0:10.0f0) + mb = call_rand(v_orig) + v = CuArray(v_orig) + m = call_rand(v) + gr = Zygote.gradient(v -> sum(call_rand(v)), v) + @test gr[1:4] .!= 0 +end