@@ -7,10 +7,16 @@ partial_type(::Dual{T,V,P}) where {T,V,P} = P
7
7
struct DualArray{T,E,M,V<: AbstractArray ,D<: AbstractArray } <: AbstractArray{E,M}
8
8
data:: V
9
9
partials:: D
10
- function DualArray {T} (v:: AbstractArray{E,N} , p:: AbstractArray{P,M} ) where {T,E,N,P,M}
11
- # TODO : Fix the empty array case
12
- VT = typeof (_slice (p, Base. tail (ntuple (one, Val {M} ()))... ))
13
- return new {T,Dual{T,E,VT},N,typeof(v),typeof(p)} (v, p)
10
+ function DualArray {T} (v:: AbstractArray{E,N} , p:: P ) where {T,E,N,P<: AbstractArray }
11
+ # TODO : non-allocating X?
12
+ X = typeof (similar (p, Base. tail (ntuple (_-> 0 , Val (ndims (P))))))
13
+ # we need the eltype of `DualArray` to be `Dual{T,E,X}` as opposed to
14
+ # some kind of `view`, because we can convert `SubArray` to `Array` but
15
+ # not vise a versa.
16
+ #
17
+ # We need that to differentiate through the following code
18
+ # `(foo(x::AbstractArray{T})::T) where {T} = x[1]`
19
+ return new {T,Dual{T,E,X},N,typeof(v),typeof(p)} (v, p)
14
20
end
15
21
end
16
22
@@ -41,7 +47,6 @@ allpartials(d::DualArray) = d.partials
41
47
# ## Array interface
42
48
# ##
43
49
44
- Base. eltype (d:: DualArray{T,E} ) where {T,E} = Dual{T,E,npartials (d)}
45
50
# droplast(d::Tuple) = d |> reverse |> Base.tail |> reverse
46
51
Base. size (d:: DualArray ) = size (data (d))
47
52
Base. IndexStyle (d:: DualArray ) = Base. IndexStyle (data (d))
0 commit comments