diff --git a/src/derivatives/linalg/arithmetic.jl b/src/derivatives/linalg/arithmetic.jl index 271af22..f3532a7 100644 --- a/src/derivatives/linalg/arithmetic.jl +++ b/src/derivatives/linalg/arithmetic.jl @@ -33,6 +33,9 @@ for A in ARRAY_TYPES @eval @inline Base.:+(x::$(A), y::TrackedArray{V,D}) where {V,D} = record_plus(x, y, D) end +@inline Base.:+(x::TrackedArray{V,D}, y::StaticArray) where {V,D} = record_plus(x, Array(y), D) +@inline Base.:+(x::StaticArray, y::TrackedArray{V,D}) where {V,D} = record_plus(Array(x), y, D) + function record_plus(x, y, ::Type{D}) where D tp = tape(x, y) out = track(value(x) + value(y), D, tp) @@ -108,6 +111,9 @@ for A in ARRAY_TYPES @eval Base.:-(x::$(A), y::TrackedArray{V,D}) where {V,D} = record_minus(x, y, D) end +@inline Base.:-(x::TrackedArray{V,D}, y::StaticArray) where {V,D} = record_minus(x, Array(y), D) +@inline Base.:-(x::StaticArray, y::TrackedArray{V,D}) where {V,D} = record_minus(Array(x), y, D) + function Base.:-(x::TrackedArray{V,D}) where {V,D} tp = tape(x) out = track(-(value(x)), D, tp) diff --git a/src/tracked.jl b/src/tracked.jl index 049c1a9..08cf9bd 100644 --- a/src/tracked.jl +++ b/src/tracked.jl @@ -172,6 +172,8 @@ function deriv!(t::NTuple{N,Any}, v::NTuple{N,Any}) where N return nothing end +deriv!(t::StaticArray, v::AbstractArray) = deriv!(Tuple(t), Tuple(v)) + # pulling values from origin # #----------------------------# @@ -223,6 +225,8 @@ unseed!(x::AbstractArray, i) = unseed!(x[i]) capture(t::TrackedReal) = ifelse(hastape(t), t, value(t)) capture(t::TrackedArray) = t capture(t::AbstractArray) = istracked(t) ? map!(capture, similar(t), t) : copy(t) +# `StaticArray`s don't support mutation unless the eltype is a bits type (`isbitstype`). +capture(t::SA) where SA <: StaticArray = istracked(t) ? SA(map(capture, t)) : copy(t) ######################## # Conversion/Promotion #