Skip to content

Commit b1a26e0

Browse files
committed
Fix error with cumsum
1 parent 3b6f509 commit b1a26e0

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

src/dualarray.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,16 @@ partial_type(::Dual{T,V,P}) where {T,V,P} = P
77
struct DualArray{T,E,M,V<:AbstractArray,D<:AbstractArray} <: AbstractArray{E,M}
88
data::V
99
partials::D
10-
function DualArray{T}(v::AbstractArray{E,N}, p::AbstractArray{P,M}) where {T,E,N,P,M}
11-
# TODO: Fix the empty array case
12-
VT = typeof(_slice(p, Base.tail(ntuple(one, Val{M}()))...))
13-
return new{T,Dual{T,E,VT},N,typeof(v),typeof(p)}(v, p)
10+
function DualArray{T}(v::AbstractArray{E,N}, p::P) where {T,E,N,P<:AbstractArray}
11+
# TODO: non-allocating X?
12+
X = typeof(similar(p, Base.tail(ntuple(_->0, Val(ndims(P))))))
13+
# we need the eltype of `DualArray` to be `Dual{T,E,X}` as opposed to
14+
# some kind of `view`, because we can convert `SubArray` to `Array` but
15+
# not vise a versa.
16+
#
17+
# We need that to differentiate through the following code
18+
# `(foo(x::AbstractArray{T})::T) where {T} = x[1]`
19+
return new{T,Dual{T,E,X},N,typeof(v),typeof(p)}(v, p)
1420
end
1521
end
1622

@@ -41,7 +47,6 @@ allpartials(d::DualArray) = d.partials
4147
### Array interface
4248
###
4349

44-
Base.eltype(d::DualArray{T,E}) where {T,E} = Dual{T,E,npartials(d)}
4550
#droplast(d::Tuple) = d |> reverse |> Base.tail |> reverse
4651
Base.size(d::DualArray) = size(data(d))
4752
Base.IndexStyle(d::DualArray) = Base.IndexStyle(data(d))

test/api.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,6 @@ using StaticArrays
77
@test D(x->[x, x^2])(3) == [1, 6]
88
@test D(sum)([1,2,3]) == ones(3)'
99
@test D(x->@SVector([x[1]^x[2], x[3]^3, x[3]*x[2]*x[1]]))(@SVector[1,2,3.]) === @SMatrix [2.0 0 6; 0 0 3; 0 27 2]
10+
@test D(cumsum)(@SVector([1,2,3])) == @SMatrix [1 0 0; 1 1 0; 1 1 1]
11+
@test D(cumsum)([1,2,3]) == [1 0 0; 1 1 0; 1 1 1]
1012
end

0 commit comments

Comments
 (0)