Skip to content

Commit d6f380a

Browse files
authored
fix: aos_to_soa for all singleton dims
1 parent 2643104 commit d6f380a

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

ext/ArrayInterfaceReverseDiffExt.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,8 @@ ArrayInterface.ismutable(T::Type{<:ReverseDiff.TrackedReal}) = false
88
ArrayInterface.can_setindex(::Type{<:ReverseDiff.TrackedArray}) = false
99
ArrayInterface.fast_scalar_indexing(::Type{<:ReverseDiff.TrackedArray}) = false
1010
function ArrayInterface.aos_to_soa(x::AbstractArray{<:ReverseDiff.TrackedReal, N}) where {N}
11-
if length(x) > 1
12-
return reshape(reduce(vcat, x), size(x))
13-
else
14-
return reduce(vcat,[x[1], x[1]])[1:1]
15-
end
11+
y = length(x) > 1 ? reduce(vcat, x) : reduce(vcat, [x[1], x[1]])[1:1]
12+
return reshape(y, size(x))
1613
end
1714

1815
function ArrayInterface.restructure(x::Array, y::ReverseDiff.TrackedArray)

0 commit comments

Comments
 (0)