File tree Expand file tree Collapse file tree 1 file changed +12
-3
lines changed Expand file tree Collapse file tree 1 file changed +12
-3
lines changed Original file line number Diff line number Diff line change 1
1
using StaticArrays: StaticArray, SMatrix, SVector
2
- using LinearAlgebra: Diagonal, I
2
+ using LinearAlgebra: I
3
3
4
4
extract_diffresult (xs:: AbstractArray{<:Number} ) = xs
5
5
# need to optimize
@@ -17,15 +17,24 @@ function seed(v::SVector{N}) where N
17
17
SMatrix {N,N,eltype(v)} (I)
18
18
end
19
19
20
+ function _seed (v, ij)
21
+ i, j = Tuple (ij)
22
+ vi = v[i]
23
+ return (i== j) ? one (vi) : zero (vi)
24
+ end
25
+
20
26
function seed (v)
21
- Matrix (Diagonal (map (one, v)))
27
+ vv = vec (v)
28
+ ax = axes (vv, 1 )
29
+ return _seed .(Ref (vv), CartesianIndices ((ax, ax)))
22
30
end
23
31
24
32
function D (f)
25
33
# grad
26
34
function deriv (arg:: AbstractArray )
27
35
# always chunk
28
- darr = dualrun (()-> DualArray (arg, seed (arg)))
36
+ arg_partial = seed (arg)
37
+ darr = dualrun (()-> DualArray (arg, arg_partial))
29
38
res = dualrun (()-> f (darr))
30
39
diffres = extract_diffresult (allpartials (res))
31
40
return diffres
You can’t perform that action at this time.
0 commit comments