Skip to content

Commit dbe7765

Browse files
committed
Eagerly evaluate scalers rules
Master behavior ```julia julia> @scalar_rule(one(x), Zero()) julia> frule(one, 1, Zero(), [1, 2]) (1, Zero()) julia> frule(one, 1, Zero(), One()) (1, Zero()) ``` Desirable behavior ```julia julia> @scalar_rule(one(x), Zero()) julia> frule(one, 1, Zero(), [1, 2]) (1, [0, 0]) julia> frule(one, 1, Zero(), One()) (1, Thunk(var"#8#10"()) ) ```
1 parent 7704d22 commit dbe7765

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

src/differential_arithmetic.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,12 @@ end
5757
Base.:+(a::AbstractThunk, b::AbstractThunk) = unthunk(a) + unthunk(b)
5858
Base.:*(a::AbstractThunk, b::AbstractThunk) = unthunk(a) * unthunk(b)
5959
for T in (:Any,)
60-
@eval Base.:+(a::AbstractThunk, b::$T) = unthunk(a) + b
61-
@eval Base.:+(a::$T, b::AbstractThunk) = a + unthunk(b)
60+
# we want to eagerly compute the result when thunk meets other types
61+
@eval Base.:+(a::AbstractThunk, b::$T) = extern(a) + b
62+
@eval Base.:+(a::$T, b::AbstractThunk) = a + extern(b)
6263

63-
@eval Base.:*(a::AbstractThunk, b::$T) = unthunk(a) * b
64-
@eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b)
64+
@eval Base.:*(a::AbstractThunk, b::$T) = extern(a) * b
65+
@eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b)
6566
end
6667

6768
################## Composite ##############################################################

src/rule_definition_tools.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@ function propagation_expr(Δs, ∂s)
209209
# This is basically Δs ⋅ ∂s
210210
∂s = map(esc, ∂s)
211211

212-
∂_mul_Δs = ntuple(i->:($(∂s[i]) * $(Δs[i])), length(∂s))
212+
# this is neccssary since we want to eagerly evaluate the result
213+
∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)]
213214
return :(+($(∂_mul_Δs...)))
214215
end
215216

0 commit comments

Comments
 (0)