Skip to content

Commit 3fd3205

Browse files
committed
Better seeding for AbstractArrays
1 parent 0fdd3b4 commit 3fd3205

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

src/api.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using StaticArrays: StaticArray, SMatrix, SVector
2-
using LinearAlgebra: Diagonal, I
2+
using LinearAlgebra: I
33

44
extract_diffresult(xs::AbstractArray{<:Number}) = xs
55
# need to optimize
@@ -17,15 +17,24 @@ function seed(v::SVector{N}) where N
1717
SMatrix{N,N,eltype(v)}(I)
1818
end
1919

20+
function _seed(v, ij)
21+
i, j = Tuple(ij)
22+
vi = v[i]
23+
return (i==j) ? one(vi) : zero(vi)
24+
end
25+
2026
function seed(v)
21-
Matrix(Diagonal(map(one, v)))
27+
vv = vec(v)
28+
ax = axes(vv, 1)
29+
return _seed.(Ref(vv), CartesianIndices((ax, ax)))
2230
end
2331

2432
function D(f)
2533
# grad
2634
function deriv(arg::AbstractArray)
2735
# always chunk
28-
darr = dualrun(()->DualArray(arg, seed(arg)))
36+
arg_partial = seed(arg)
37+
darr = dualrun(()->DualArray(arg, arg_partial))
2938
res = dualrun(()->f(darr))
3039
diffres = extract_diffresult(allpartials(res))
3140
return diffres

0 commit comments

Comments
 (0)