Skip to content

Commit e86af58

Browse files
committed
Fix StaticArrays' X computation
1 parent 58c3f88 commit e86af58

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

src/dualarray.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
using StaticArrays: SVector
1+
using StaticArrays: SVector, StaticArray
22

33
partial_type(::Dual{T,V,P}) where {T,V,P} = P
44

55
struct DualArray{T,E,M,V<:AbstractArray,D<:AbstractArray} <: AbstractArray{E,M}
66
data::V
77
partials::D
88
function DualArray{T}(v::AbstractArray{E,N}, p::P) where {T,E,N,P<:AbstractArray}
9-
X = typeof(vec(p))
9+
X = p isa StaticArray ? typeof(vec(p)[axes(p, ndims(p))]) : typeof(vec(p))
1010
# we need the eltype of `DualArray` to be `Dual{T,E,X}` as opposed to
1111
# some kind of `view`, because we can convert `SubArray` to `Array` but
1212
# not vise a versa.
@@ -36,7 +36,7 @@ function Base.print_array(io::IO, da::DualArray)
3636
end
3737

3838
DualArray(a::AbstractArray, b::AbstractArray) = DualArray{typeof(dualtag())}(a, b)
39-
npartials(d::DualArray) = size(d.partials, ndims(d.partials))
39+
npartials(d::DualArray) = (ps = allpartials(d); size(ps, ndims(ps)))
4040
data(d::DualArray) = d.data
4141
allpartials(d::DualArray) = d.partials
4242

@@ -50,8 +50,6 @@ Base.IndexStyle(d::DualArray) = Base.IndexStyle(data(d))
5050
Base.similar(d::DualArray{T}, ::Type{S}, dims::Dims) where {T, S} = DualArray{T}(similar(data(d)), similar(allpartials(d)))
5151
Base.eachindex(d::DualArray) = eachindex(data(d))
5252

53-
using StaticArrays
54-
5553
Base.@propagate_inbounds _slice(A, i...) = @view A[i..., :]
5654
Base.@propagate_inbounds _slice(A::StaticArray, i...) = A[i..., :]
5755

0 commit comments

Comments
 (0)