1
- using StaticArrays: SVector
1
+ using StaticArrays: SVector, StaticArray
2
2
3
3
partial_type (:: Dual{T,V,P} ) where {T,V,P} = P
4
4
5
5
struct DualArray{T,E,M,V<: AbstractArray ,D<: AbstractArray } <: AbstractArray{E,M}
6
6
data:: V
7
7
partials:: D
8
8
function DualArray {T} (v:: AbstractArray{E,N} , p:: P ) where {T,E,N,P<: AbstractArray }
9
- X = typeof (vec (p))
9
+ X = p isa StaticArray ? typeof ( vec (p)[ axes (p, ndims (p))]) : typeof (vec (p))
10
10
# we need the eltype of `DualArray` to be `Dual{T,E,X}` as opposed to
11
11
# some kind of `view`, because we can convert `SubArray` to `Array` but
12
12
# not vise a versa.
@@ -36,7 +36,7 @@ function Base.print_array(io::IO, da::DualArray)
36
36
end
37
37
38
38
DualArray (a:: AbstractArray , b:: AbstractArray ) = DualArray {typeof(dualtag())} (a, b)
39
- npartials (d:: DualArray ) = size (d . partials , ndims (d . partials ))
39
+ npartials (d:: DualArray ) = (ps = allpartials (d); size (ps , ndims (ps) ))
40
40
data (d:: DualArray ) = d. data
41
41
allpartials (d:: DualArray ) = d. partials
42
42
@@ -50,8 +50,6 @@ Base.IndexStyle(d::DualArray) = Base.IndexStyle(data(d))
50
50
Base. similar (d:: DualArray{T} , :: Type{S} , dims:: Dims ) where {T, S} = DualArray {T} (similar (data (d)), similar (allpartials (d)))
51
51
Base. eachindex (d:: DualArray ) = eachindex (data (d))
52
52
53
- using StaticArrays
54
-
55
53
Base. @propagate_inbounds _slice (A, i... ) = @view A[i... , :]
56
54
Base. @propagate_inbounds _slice (A:: StaticArray , i... ) = A[i... , :]
57
55
0 commit comments