Skip to content

Commit 00cc8c6

Browse files
Add ReverseDiff
1 parent c3acb74 commit 00cc8c6

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

ext/ArrayInterfaceReverseDiffExt.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,8 @@ function ArrayInterface.aos_to_soa(x::AbstractArray{<:ReverseDiff.TrackedReal, N
1515
end
1616
end
1717

18+
function ArrayInterface.restructure(x::Array, y::ReverseDiff.TrackedArray)
19+
reshape(y, Base.size(x)...)
20+
end
21+
1822
end # module

test/ad.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,13 @@ y = Tracker.TrackedReal.(rand(2,2))
2626
@test size(ArrayInterface.restructure(x, y)) == (4,)
2727
y = Tracker.TrackedArray(rand(2,2))
2828
@test ArrayInterface.restructure(x, y) isa Tracker.TrackedArray
29-
@test size(ArrayInterface.restructure(x, y)) == (4,)
29+
@test size(ArrayInterface.restructure(x, y)) == (4,)
30+
31+
x = rand(4)
32+
y = ReverseDiff.track(rand(2,2))
33+
@test ArrayInterface.restructure(x, y) isa ReverseDiff.TrackedArray
34+
@test size(ArrayInterface.restructure(x, y)) == (4,)
35+
y = ReverseDiff.track.(rand(2,2))
36+
@test ArrayInterface.restructure(x, y) isa Array
37+
@test eltype(ArrayInterface.restructure(x, y)) <: ReverseDiff.TrackedReal
38+
@test size(ArrayInterface.restructure(x, y)) == (4,)

0 commit comments

Comments
 (0)