@@ -96,14 +96,15 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
96
96
@testset " Correct definitions" begin
97
97
local inplace_used
98
98
function ChainRulesCore. frule ((_, ẋ), :: typeof (identity), x:: Array )
99
- ẏ = InplaceableThunk (@thunk (ẋ), ȧ -> (inplace_used = true ; ȧ .+ = ẋ))
99
+ ẏ = InplaceableThunk (ȧ -> (inplace_used = true ; ȧ .+ = ẋ), @thunk ( ẋ))
100
100
return identity (x), ẏ
101
101
end
102
102
function ChainRulesCore. rrule (:: typeof (identity), x:: Array )
103
103
function identity_pullback (ȳ)
104
- x̄_ret = InplaceableThunk (
105
- @thunk (ȳ), ā -> (inplace_used = true ; ā .+ = ȳ)
106
- )
104
+ x̄_ret = InplaceableThunk (@thunk (ȳ)) do ā
105
+ inplace_used = true
106
+ ā .+ = ȳ
107
+ end
107
108
return (NoTangent (), x̄_ret)
108
109
end
109
110
return identity (x), identity_pullback
@@ -122,14 +123,14 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
122
123
my_identity (value) = value # we will define bad rules on this
123
124
function ChainRulesCore. frule ((_, ẋ), :: typeof (my_identity), x:: Array )
124
125
# only the in-place part is incorrect
125
- ẏ = InplaceableThunk (@thunk (ẋ), ȧ -> ȧ .+ = 200 .* ẋ)
126
+ ẏ = InplaceableThunk (ȧ -> ȧ .+ = 200 .* ẋ, @thunk (ẋ) )
126
127
return my_identity (x), ẏ
127
128
end
128
129
function ChainRulesCore. rrule (:: typeof (my_identity), x:: Array )
129
130
x_dims = size (x)
130
131
function my_identity_pullback (ȳ)
131
132
# only the in-place part is incorrect
132
- x̄_ret = InplaceableThunk (@thunk (ȳ), ā -> ā .+ = 200 .* ȳ)
133
+ x̄_ret = InplaceableThunk (ā -> ā .+ = 200 .* ȳ, @thunk (ȳ) )
133
134
return (NoTangent (), x̄_ret)
134
135
end
135
136
return my_identity (x), my_identity_pullback
0 commit comments