Skip to content

Commit e702e32

Browse files
authored
Merge pull request #254 from nmheim/tests
Start testing reverse mode
2 parents 5b03343 + 6c9048c commit e702e32

File tree

5 files changed

+647
-20
lines changed

5 files changed

+647
-20
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,13 @@ StructArrays = "0.6"
2828
julia = "1.10"
2929

3030
[extras]
31+
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
32+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
3133
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3234
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
3335
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3436
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
3537
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3638

3739
[targets]
38-
test = ["ForwardDiff", "LinearAlgebra", "Random", "Symbolics", "Test"]
40+
test = ["Distributed", "FiniteDifferences", "ForwardDiff", "LinearAlgebra", "Random", "Symbolics", "Test"]

src/stage1/generated.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,12 @@ function (::∂⃖{N})(::typeof(Base.getindex), a::Array{<:Number}, inds...) whe
320320
(@Base.constprop :aggressive Δ->begin
321321
Δ isa AbstractZero && return (NoTangent(), Δ, map(Returns(Δ), inds)...)
322322
BB = zero(a)
323-
BB[inds...] = unthunk(Δ)
323+
324+
# view is needed to cover cases with duplicated indices like
325+
# gradient(sum ∘ x -> x[:,[1,1,2]], rand(3,4))
326+
# https://github.com/JuliaDiff/Diffractor.jl/pull/254#discussion_r1480791644
327+
view(BB, inds...) .+= unthunk(Δ)
328+
324329
(NoTangent(), BB, map(x->NoTangent(), inds)...)
325330
end),
326331
(@Base.constprop :aggressive (_, Δ, _)->begin

0 commit comments

Comments
 (0)