Skip to content

Commit ba0eb6b

Browse files
committed
reorder argument order on InplaceableThink
1 parent 8b6527a commit ba0eb6b

File tree

3 files changed

+11
-10
lines changed

3 files changed

+11
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1212

1313
[compat]
14-
ChainRulesCore = "0.10.4"
14+
ChainRulesCore = "0.10.12"
1515
Compat = "3"
1616
FiniteDifferences = "0.12.12"
1717
julia = "1"

test/check_result.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ end
1717
check(11.0, ZeroTangent())
1818
check([10.0, 20.0], @thunk([2.0, 0.0]))
1919

20-
check(12.0, InplaceableThunk(@thunk(2.0), -> error("Should not have in-placed")))
20+
check(12.0, InplaceableThunk(X̄ -> error("Should not have in-placed"), @thunk(2.0)))
2121

22-
check([10.0, 20.0], InplaceableThunk(@thunk([2.0, 0.0]), -> (X̄[1] += 2.0; X̄)))
22+
check([10.0, 20.0], InplaceableThunk(X̄ -> (X̄[1] += 2.0; X̄), @thunk([2.0, 0.0])))
2323
@test fails() do
24-
check([10.0, 20.0], InplaceableThunk(@thunk([2.0, 0.0]), -> (X̄[1] += 3.0; X̄)))
24+
check([10.0, 20.0], InplaceableThunk(X̄ -> (X̄[1] += 3.0; X̄), @thunk([2.0, 0.0])))
2525
end
2626
end
2727

test/testers.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,15 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
9696
@testset "Correct definitions" begin
9797
local inplace_used
9898
function ChainRulesCore.frule((_, ẋ), ::typeof(identity), x::Array)
99-
= InplaceableThunk(@thunk(ẋ), ȧ -> (inplace_used = true; ȧ .+= ẋ))
99+
= InplaceableThunk-> (inplace_used = true; ȧ .+= ẋ), @thunk(ẋ))
100100
return identity(x), ẏ
101101
end
102102
function ChainRulesCore.rrule(::typeof(identity), x::Array)
103103
function identity_pullback(ȳ)
104-
x̄_ret = InplaceableThunk(
105-
@thunk(ȳ), ā -> (inplace_used = true; ā .+= ȳ)
106-
)
104+
x̄_ret = InplaceableThunk(@thunk(ȳ)) do ā
105+
inplace_used = true
106+
ā .+= ȳ
107+
end
107108
return (NoTangent(), x̄_ret)
108109
end
109110
return identity(x), identity_pullback
@@ -122,14 +123,14 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
122123
my_identity(value) = value # we will define bad rules on this
123124
function ChainRulesCore.frule((_, ẋ), ::typeof(my_identity), x::Array)
124125
# only the in-place part is incorrect
125-
= InplaceableThunk(@thunk(ẋ), ȧ -> ȧ .+= 200 .* ẋ)
126+
= InplaceableThunk-> ȧ .+= 200 .*, @thunk(ẋ))
126127
return my_identity(x), ẏ
127128
end
128129
function ChainRulesCore.rrule(::typeof(my_identity), x::Array)
129130
x_dims = size(x)
130131
function my_identity_pullback(ȳ)
131132
# only the in-place part is incorrect
132-
x̄_ret = InplaceableThunk(@thunk(ȳ), ā -> ā .+= 200 .* ȳ)
133+
x̄_ret = InplaceableThunk-> ā .+= 200 .* ȳ, @thunk(ȳ))
133134
return (NoTangent(), x̄_ret)
134135
end
135136
return my_identity(x), my_identity_pullback

0 commit comments

Comments
 (0)