Skip to content

Commit 25b0be8

Browse files
Merge pull request #457 from avik-pal/patch-1
fix: aos_to_soa for all singleton dims
2 parents 2643104 + 1d93114 commit 25b0be8

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

ext/ArrayInterfaceReverseDiffExt.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,8 @@ ArrayInterface.ismutable(T::Type{<:ReverseDiff.TrackedReal}) = false
88
ArrayInterface.can_setindex(::Type{<:ReverseDiff.TrackedArray}) = false
99
ArrayInterface.fast_scalar_indexing(::Type{<:ReverseDiff.TrackedArray}) = false
1010
function ArrayInterface.aos_to_soa(x::AbstractArray{<:ReverseDiff.TrackedReal, N}) where {N}
11-
if length(x) > 1
12-
return reshape(reduce(vcat, x), size(x))
13-
else
14-
return reduce(vcat,[x[1], x[1]])[1:1]
15-
end
11+
y = length(x) > 1 ? reduce(vcat, x) : reduce(vcat, [x[1], x[1]])[1:1]
12+
return reshape(y, size(x))
1613
end
1714

1815
function ArrayInterface.restructure(x::Array, y::ReverseDiff.TrackedArray)

test/ad.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
using ArrayInterface, ReverseDiff, Tracker, Test
22
x = ReverseDiff.track([4.0])
33
@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray
4+
x = reshape([ReverseDiff.track(rand(1, 1, 1))[1]], 1, 1, 1)
5+
@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray
6+
@test ndims(ArrayInterface.aos_to_soa(x)) == 3
47
x = reduce(vcat, ReverseDiff.track([4.0,4.0]))
58
@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray
69
x = [ReverseDiff.track([4.0])[1]]

0 commit comments

Comments
 (0)