Skip to content

Commit 7b44f9e

Browse files
Merge pull request #453 from JuliaArrays/restructure_tracker
Fix Tracker with restructure
2 parents 0042e23 + 00cc8c6 commit 7b44f9e

File tree

3 files changed

+29
-0
lines changed

3 files changed

+29
-0
lines changed

ext/ArrayInterfaceReverseDiffExt.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,8 @@ function ArrayInterface.aos_to_soa(x::AbstractArray{<:ReverseDiff.TrackedReal, N
1515
end
1616
end
1717

18+
function ArrayInterface.restructure(x::Array, y::ReverseDiff.TrackedArray)
19+
reshape(y, Base.size(x)...)
20+
end
21+
1822
end # module

ext/ArrayInterfaceTrackerExt.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,11 @@ ArrayInterface.can_setindex(::Type{<:Tracker.TrackedArray}) = false
99
ArrayInterface.fast_scalar_indexing(::Type{<:Tracker.TrackedArray}) = false
1010
ArrayInterface.aos_to_soa(x::AbstractArray{<:Tracker.TrackedReal,N}) where {N} = Tracker.collect(x)
1111

12+
function ArrayInterface.restructure(x::Array, y::Tracker.TrackedArray)
13+
reshape(y, Base.size(x)...)
14+
end
15+
function ArrayInterface.restructure(x::Array, y::Array{<:Tracker.TrackedReal})
16+
reshape(y, Base.size(x)...)
17+
end
18+
1219
end # module

test/ad.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,21 @@ x = Tracker.TrackedArray([4.0,4.0])
1818
x = reduce(vcat, Tracker.TrackedArray([4.0,4.0]))
1919
x = [x[1],x[2]]
2020
@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray
21+
22+
x = rand(4)
23+
y = Tracker.TrackedReal.(rand(2,2))
24+
@test ArrayInterface.restructure(x, y) isa Array
25+
@test eltype(ArrayInterface.restructure(x, y)) <: Tracker.TrackedReal
26+
@test size(ArrayInterface.restructure(x, y)) == (4,)
27+
y = Tracker.TrackedArray(rand(2,2))
28+
@test ArrayInterface.restructure(x, y) isa Tracker.TrackedArray
29+
@test size(ArrayInterface.restructure(x, y)) == (4,)
30+
31+
x = rand(4)
32+
y = ReverseDiff.track(rand(2,2))
33+
@test ArrayInterface.restructure(x, y) isa ReverseDiff.TrackedArray
34+
@test size(ArrayInterface.restructure(x, y)) == (4,)
35+
y = ReverseDiff.track.(rand(2,2))
36+
@test ArrayInterface.restructure(x, y) isa Array
37+
@test eltype(ArrayInterface.restructure(x, y)) <: ReverseDiff.TrackedReal
38+
@test size(ArrayInterface.restructure(x, y)) == (4,)

0 commit comments

Comments
 (0)