Skip to content

Commit 8d12336

Browse files
devmotionjrevels
authored andcommitted
Dispatch on StaticArray instead of SArray (#7)
1 parent d49bd6d commit 8d12336

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

src/DiffResults.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,17 @@ Return `r::DiffResult`, with output value storage provided by `value` and output
3939
storage provided by `derivs`.
4040
4141
In reality, `DiffResult` is an abstract supertype of two concrete types, `MutableDiffResult`
42-
and `ImmutableDiffResult`. If all `value`/`derivs` are all `Number`s or `SArray`s, then `r`
43-
will be immutable (i.e. `r::ImmutableDiffResult`). Otherwise, `r` will be mutable
42+
and `ImmutableDiffResult`. If all `value`/`derivs` are all `Number`s or `StaticArray`s,
43+
then `r` will be immutable (i.e. `r::ImmutableDiffResult`). Otherwise, `r` will be mutable
4444
(i.e. `r::MutableDiffResult`).
4545
4646
Note that `derivs` can be provide in splatted form, i.e. `DiffResult(value, derivs...)`.
4747
"""
4848
DiffResult
4949

5050
DiffResult(value::Number, derivs::Tuple{Vararg{Number}}) = ImmutableDiffResult(value, derivs)
51-
DiffResult(value::Number, derivs::Tuple{Vararg{SArray}}) = ImmutableDiffResult(value, derivs)
52-
DiffResult(value::SArray, derivs::Tuple{Vararg{SArray}}) = ImmutableDiffResult(value, derivs)
51+
DiffResult(value::Number, derivs::Tuple{Vararg{StaticArray}}) = ImmutableDiffResult(value, derivs)
52+
DiffResult(value::StaticArray, derivs::Tuple{Vararg{StaticArray}}) = ImmutableDiffResult(value, derivs)
5353
DiffResult(value::Number, derivs::Tuple{Vararg{AbstractArray}}) = MutableDiffResult(value, derivs)
5454
DiffResult(value::AbstractArray, derivs::Tuple{Vararg{AbstractArray}}) = MutableDiffResult(value, derivs)
5555
DiffResult(value::Union{Number,AbstractArray}, derivs::Union{Number,AbstractArray}...) = DiffResult(value, derivs)
@@ -65,7 +65,7 @@ shape information. If you want to allocate storage yourself, use the `DiffResult
6565
constructor instead.
6666
"""
6767
GradientResult(x::AbstractArray) = DiffResult(first(x), similar(x))
68-
GradientResult(x::SArray) = DiffResult(first(x), x)
68+
GradientResult(x::StaticArray) = DiffResult(first(x), x)
6969

7070
"""
7171
JacobianResult(x::AbstractArray)
@@ -79,7 +79,7 @@ shape information. If you want to allocate storage yourself, use the `DiffResult
7979
constructor instead.
8080
"""
8181
JacobianResult(x::AbstractArray) = DiffResult(similar(x), similar(x, length(x), length(x)))
82-
JacobianResult(x::SArray{<:Any,T,<:Any,L}) where {T,L} = DiffResult(x, zeros(SMatrix{L,L,T}))
82+
JacobianResult(x::StaticArray) = DiffResult(x, zeros(StaticArrays.similar_type(typeof(x), Size(length(x),length(x)))))
8383

8484
"""
8585
JacobianResult(y::AbstractArray, x::AbstractArray)
@@ -92,7 +92,7 @@ Like the single argument version, `y` and `x` are only used for type and
9292
shape information and are not stored in the returned `DiffResult`.
9393
"""
9494
JacobianResult(y::AbstractArray, x::AbstractArray) = DiffResult(similar(y), similar(y, length(y), length(x)))
95-
JacobianResult(y::SArray{<:Any,<:Any,<:Any,Y}, x::SArray{<:Any,T,<:Any,X}) where {T,Y,X} = DiffResult(y, zeros(SMatrix{Y,X,T}))
95+
JacobianResult(y::StaticArray, x::StaticArray) = DiffResult(y, zeros(StaticArrays.similar_type(typeof(x), Size(length(y),length(x)))))
9696

9797
"""
9898
HessianResult(x::AbstractArray)
@@ -105,7 +105,7 @@ shape information. If you want to allocate storage yourself, use the `DiffResult
105105
constructor instead.
106106
"""
107107
HessianResult(x::AbstractArray) = DiffResult(first(x), similar(x), similar(x, length(x), length(x)))
108-
HessianResult(x::SArray{<:Any,T,<:Any,L}) where {T,L} = DiffResult(first(x), x, zeros(SMatrix{L,L,T}))
108+
HessianResult(x::StaticArray) = DiffResult(first(x), x, zeros(StaticArrays.similar_type(typeof(x), Size(length(x),length(x)))))
109109

110110
#############
111111
# Interface #
@@ -203,7 +203,7 @@ function derivative!(r::MutableDiffResult, x::AbstractArray, ::Type{Val{i}} = Va
203203
return r
204204
end
205205

206-
function derivative!(r::ImmutableDiffResult, x::Union{Number,SArray}, ::Type{Val{i}} = Val{1}) where {i}
206+
function derivative!(r::ImmutableDiffResult, x::Union{Number,StaticArray}, ::Type{Val{i}} = Val{1}) where {i}
207207
return ImmutableDiffResult(value(r), tuple_setindex(r.derivs, x, Val{i}))
208208
end
209209

@@ -232,7 +232,7 @@ function derivative!(f, r::ImmutableDiffResult, x::Number, ::Type{Val{i}} = Val{
232232
return derivative!(r, f(x), Val{i})
233233
end
234234

235-
function derivative!(f, r::ImmutableDiffResult, x::SArray, ::Type{Val{i}} = Val{1}) where {i}
235+
function derivative!(f, r::ImmutableDiffResult, x::StaticArray, ::Type{Val{i}} = Val{1}) where {i}
236236
return derivative!(r, map(f, x), Val{i})
237237
end
238238

0 commit comments

Comments
 (0)