Skip to content

Commit a2ea087

Browse files
authored
always unthunk results (#79)
* always unthunk results * add unthunk for gradient
1 parent 82096ee commit a2ea087

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/interface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ end
127127
# N.B: This means the gradient is not available for zero-arg function, but such
128128
# a gradient would be guaranteed to be `()`, which is a bit of a useless thing
129129
function (::Type{∇})(f, x1, args...)
130-
(f)(x1, args...)
130+
unthunk.((f)(x1, args...))
131131
end
132132

133133
const gradient =
@@ -159,7 +159,7 @@ function (f::PrimeDerivativeBack)(x)
159159
z = ∂⃖¹(lower_pd(f), x)
160160
y = getfield(z, 1)
161161
f☆ = getfield(z, 2)
162-
return getfield(f☆(dx(y)), 2)
162+
return unthunk(getfield(f☆(dx(y)), 2))
163163
end
164164

165165
# Forwards primal derivative

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ let var"'" = Diffractor.PrimeDerivativeBack
108108
# Control flow cases
109109
@test @inferred((x->simple_control_flow(true, x))'(1.0)) == sin'(1.0)
110110
@test @inferred((x->simple_control_flow(false, x))'(1.0)) == cos'(1.0)
111-
@test_broken (x->sum(isa_control_flow(Matrix{Float64}, x)))'(Float32[1 2;]) == [1.0 1.0;]
111+
@test (x->sum(isa_control_flow(Matrix{Float64}, x)))'(Float32[1 2;]) == [1.0 1.0;]
112112
@test times_three_while'(1.0) == 3.0
113113

114114
pow5p(x) = (x->mypow(x, 5))'(x)

0 commit comments

Comments
 (0)