1
1
using StaticArrays: SVector
2
+
3
+ partial_type (:: Dual{T,V,P} ) where {T,V,P} = P
2
4
# TODO : Tagging?
3
5
# TODO : Integrate better with SVector. Maybe even use SIMD.jl?
4
6
5
- struct DualArray{T,E,M,D<: AbstractArray ,I} <: AbstractArray{E,M}
6
- data:: D
7
- DualArray {T,I} (a:: AbstractArray{E,N} ) where {T,I,E,N} = new {T,E,N-1,typeof(a),I} (a)
7
+ struct DualArray{T,E,M,V<: AbstractArray ,D<: AbstractArray } <: AbstractArray{E,M}
8
+ data:: V
9
+ partials:: D
10
+ function DualArray {T} (v:: AbstractArray{E,N} , p:: P ) where {T,E,N,P<: AbstractArray }
11
+ # TODO : non-allocating X?
12
+ X = typeof (similar (p, Base. tail (ntuple (_-> 0 , Val (ndims (P))))))
13
+ # we need the eltype of `DualArray` to be `Dual{T,E,X}` as opposed to
14
+ # some kind of `view`, because we can convert `SubArray` to `Array` but
15
+ # not vise a versa.
16
+ #
17
+ # We need that to differentiate through the following code
18
+ # `(foo(x::AbstractArray{T})::T) where {T} = x[1]`
19
+ return new {T,Dual{T,E,X},N,typeof(v),typeof(p)} (v, p)
20
+ end
8
21
end
9
22
10
23
# ##
@@ -18,91 +31,39 @@ function Base.print_array(io::IO, da::DualArray)
18
31
Base. printstyled (io, " Primals:\n " , bold= false , color= 2 )
19
32
prev_params = io isa IOContext ? io. dict : ()
20
33
ioc = IOContext (io, prev_params... , sz)
21
- Base. print_array (ioc, value (da))
34
+ Base. print_array (ioc, data (da))
22
35
Base. println (io)
23
- for i= 1 : npartials (da)
24
- Base. printstyled (io," Partials($i ):\n " , bold= false , color= 3 )
25
- Base. print_array (ioc, getindex .(partials .(da), i))
26
- i != = npartials (da) && Base. println (io)
27
- end
36
+ Base. printstyled (io," Partials:\n " , bold= false , color= 3 )
37
+ Base. print_array (ioc, allpartials (da))
28
38
return nothing
29
39
end
30
40
31
- DualArray (a:: AbstractArray ) = DualArray {Nothing,size(a, ndims(a))-1} (a)
32
- npartials (d:: DualArray{T,E,M,D,I} ) where {T,E,M,D,I} = I
33
- tagtype (:: Type{<:DualArray{T}} ) where {T} = T
34
- tagtype (:: T ) where {T<: DualArray } = tagtype (T)
41
+ DualArray (a:: AbstractArray , b:: AbstractArray ) = DualArray {typeof(dualtag())} (a, b)
42
+ npartials (d:: DualArray ) = size (d. partials, ndims (d. partials))
35
43
data (d:: DualArray ) = d. data
44
+ allpartials (d:: DualArray ) = d. partials
36
45
37
46
# ##
38
47
# ## Array interface
39
48
# ##
40
49
41
- Base. eltype (d:: DualArray{T,E} ) where {T,E} = Dual{T,E,npartials (d)}
42
- droplast (d:: Tuple ) = d |> reverse |> Base. tail |> reverse
43
- Base. size (d:: DualArray ) = size (data (d)) |> droplast
44
- Base. size (d:: DualArray , i) = i <= ndims (d) ? size (data (d), i) : 1
50
+ # droplast(d::Tuple) = d |> reverse |> Base.tail |> reverse
51
+ Base. size (d:: DualArray ) = size (data (d))
45
52
Base. IndexStyle (d:: DualArray ) = Base. IndexStyle (data (d))
46
- Base. strides (d:: DualArray ) = strides ( data (d)) |> droplast
47
- Base. similar (d:: DualArray{T} , :: Type{S} , dims :: Dims ) where {T,S} = DualArray {T,npartials(d)} ( similar ( data (d), S, (dims ... , npartials (d) + 1 ) ))
53
+ Base. similar (d:: DualArray{T} , :: Type{S} , dims :: Dims ) where {T, S} = DualArray {T} ( similar ( data (d)), similar ( allpartials (d)))
54
+ Base. eachindex (d:: DualArray ) = eachindex ( data (d))
48
55
49
- # ##
50
- # ## Broadcast interface
51
- # ##
56
+ using StaticArrays
52
57
53
- using Base. Broadcast: Broadcasted, BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle
58
+ Base. @propagate_inbounds _slice (A, i... ) = @view A[i... , :]
59
+ Base. @propagate_inbounds _slice (A:: StaticArray , i... ) = A[i... , :]
54
60
55
- struct DualStyle{M,T,I,D} <: AbstractArrayStyle{M}
61
+ Base. @propagate_inbounds function Base. getindex (d:: DualArray{T} , i:: Int... ) where {T}
62
+ return Dual {T} (data (d)[i... ], _slice (allpartials (d), i... ))
56
63
end
57
64
58
- Base. Broadcast. BroadcastStyle (:: Type{<:DualArray{T,E,M,D,I}} ) where {T,E,M,D,I} = DualStyle {M,T,I,typeof(Base.Broadcast.BroadcastStyle(D))} ()
59
- function Base. similar (bc:: Broadcasted{<:DualStyle{M,T,I,D}} , :: Type{E} ) where {M,T,I,D,E}
60
- if E <: Dual
61
- V = valtype (E)
62
- arr = DualArray {T,I} (similar (Array{V}, (axes (bc)... , Base. OneTo (I+ 1 )))) # TODO : work with arbitrary array types. Maybe use `ArrayInterface.jl`?
63
- else
64
- bc′ = convert (Broadcasted{D}, bc)
65
- arr = Base. similar (bc′, E)
66
- end
67
- return arr
68
- end
69
- Base. BroadcastStyle (:: DualStyle{M,T,I,D} , :: DualStyle{M,T,I,V} ) where {M,T,I,D,V} = DualStyle {M,T,I,typeof(Base.BroadcastStyle(D(), V()))} ()
70
- Base. BroadcastStyle (:: DualStyle{M,T,I,D} , B:: BroadcastStyle ) where {M,T,I,D} = DualStyle {M,T,I,typeof(Base.BroadcastStyle(D(), B))} ()
71
- Base. BroadcastStyle (:: DualStyle{M,T,I,D} , B:: DefaultArrayStyle ) where {M,T,I,D} = DualStyle {M,T,I,typeof(Base.BroadcastStyle(D(), B))} ()
72
-
73
- function value (d:: DualArray )
74
- n = ndims (d)
75
- dd = data (d)
76
- return @view dd[ntuple (_ -> Colon (), Val (n))... , 1 ]
77
- end
78
-
79
- function partials (d:: DualArray )
80
- n = ndims (d)
81
- dd = data (d)
82
- return @view dd[ntuple (_ -> Colon (), Val (n))... , 2 : end ]
83
- end
84
-
85
- Base. eachindex (d:: DualArray ) = eachindex (@view data (d)[:, 1 ])
86
-
87
- Base. @propagate_inbounds function Base. getindex (d:: DualArray , i:: Int... )
88
- dd = data (d)
89
- # TODO : do something different if dd is not Linear index style
90
- ii = LinearIndices (size (d))[i... ]
91
- val = dd[ii]
92
- slice_len = length (d)
93
- parts = ntuple (j-> dd[j * slice_len + ii], Val (npartials (d)))
94
- return Dual (val, SVector (parts))
95
- end
96
-
97
- Base. @propagate_inbounds function Base. setindex! (d:: DualArray , dual, i:: Int... )
98
- dd = data (d)
99
- ii = LinearIndices (size (d))[i... ]
100
- dd[ii] = value (dual)
101
-
102
- slice_len = length (d)
103
- ps = partials (dual)
104
- for j = 1 : npartials (d)
105
- dd[j * slice_len + ii] = ps[j]
106
- end
65
+ Base. @propagate_inbounds function Base. setindex! (d:: DualArray{T} , dual:: Dual{T} , i:: Int... ) where {T}
66
+ data (d)[i... ] = value (dual)
67
+ allpartials (d)[i... , :] .= partials (dual)
107
68
return dual
108
69
end
0 commit comments