Skip to content

Commit 0acd847

Browse files
Fix Tracker with restructure
1 parent 0042e23 commit 0acd847

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

ext/ArrayInterfaceTrackerExt.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,11 @@ ArrayInterface.can_setindex(::Type{<:Tracker.TrackedArray}) = false
99
ArrayInterface.fast_scalar_indexing(::Type{<:Tracker.TrackedArray}) = false
1010
ArrayInterface.aos_to_soa(x::AbstractArray{<:Tracker.TrackedReal,N}) where {N} = Tracker.collect(x)
1111

12+
function ArrayInterface.restructure(x::Array, y::TrackedArray)
13+
reshape(y, Base.size(x)...)
14+
end
15+
function ArrayInterface.restructure(x::Array, y::Array{<:Tracker.TrackedReal})
16+
reshape(y, Base.size(x)...)
17+
end
18+
1219
end # module

test/ad.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,12 @@ x = Tracker.TrackedArray([4.0,4.0])
1818
x = reduce(vcat, Tracker.TrackedArray([4.0,4.0]))
1919
x = [x[1],x[2]]
2020
@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray
21+
22+
x = rand(4)
23+
y = Tracker.TrackedReal.(rand(2,2))
24+
@test ArrayInterface.restructure(x, y) isa Array
25+
@test eltype(ArrayInterface.restructure(x, y)) <: Tracker.TrackedReal
26+
@test size(ArrayInterface.restructure(x, y)) == (4,)
27+
y = Tracker.TrackedArray(rand(2,2))
28+
@test ArrayInterface.restructure(x, y) isa Tracker.TrackedArray
29+
@test size(ArrayInterface.restructure(x, y)) == (4,)

0 commit comments

Comments
 (0)