Skip to content

Commit 72bb760

Browse files
committed
Fix extract_diffresult
1 parent b1a26e0 commit 72bb760

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

src/api.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
using StaticArrays: StaticArray, SMatrix, SVector
22
using LinearAlgebra: Diagonal, I
33

4-
tilt(xs::AbstractArray{<:Number}) = xs'
5-
tilt(xs) = hcat(xs...)
6-
tilt(xs::StaticArray{<:Number}) = xs'
7-
function tilt(xs::StaticArray)
8-
# Jacobian
4+
extract_diffresult(xs::AbstractArray{<:Number}) = xs
5+
# need to optimize
6+
extract_diffresult(xs) = hcat(xs...)'
7+
function extract_diffresult(xs::StaticVector{<:StaticArray})
98
tup = reduce((x,y)->tuple(x..., y...), map(x->x.data, xs.data))
10-
SMatrix{length(xs), length(xs[1])}(tup)'
9+
SMatrix{length(xs), length(xs[1])}(tup)
1110
end
11+
extract_diffresult(xs::AbstractMatrix{<:Number}) = xs
12+
extract_diffresult(xs::AbstractVector{<:Number}) = xs'
1213

1314
allpartials(xs) = map(partials, xs)
1415

@@ -24,11 +25,10 @@ function D(f)
2425
# grad
2526
function deriv(arg::AbstractArray)
2627
# always chunk
27-
res = dualrun() do
28-
darr = DualArray(arg, seed(arg))
29-
f(darr)
30-
end
31-
tilt(allpartials(res))
28+
darr = dualrun(()->DualArray(arg, seed(arg)))
29+
res = dualrun(()->f(darr))
30+
diffres = extract_diffresult(allpartials(res))
31+
return diffres
3232
end
3333
# scalar
3434
function deriv(arg)

0 commit comments

Comments
 (0)