Skip to content

Commit 170048b

Browse files
authored
Merge pull request #188 from JuliaDiff/ox/inplaceorder
reorder argument order on InplaceableThunk
2 parents 8b6527a + 1594206 commit 170048b

File tree

4 files changed

+20
-19
lines changed

4 files changed

+20
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1212

1313
[compat]
14-
ChainRulesCore = "0.10.4"
14+
ChainRulesCore = "0.10.12"
1515
Compat = "3"
1616
FiniteDifferences = "0.12.12"
1717
julia = "1"

docs/Manifest.toml

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1111

1212
[[ChainRulesCore]]
1313
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
14-
git-tree-sha1 = "be770c08881f7bb928dfd86d1ba83798f76cf62a"
14+
git-tree-sha1 = "0b0aa9d61456940511416b59a0e902c57b154956"
1515
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16-
version = "0.10.9"
16+
version = "0.10.12"
1717

1818
[[ChainRulesTestUtils]]
1919
deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"]
2020
path = ".."
2121
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
22-
version = "0.7.12"
22+
version = "0.7.13"
2323

2424
[[Compat]]
2525
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
@@ -47,19 +47,19 @@ version = "0.8.5"
4747

4848
[[Documenter]]
4949
deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
50-
git-tree-sha1 = "621850838b3e74dd6dd047b5432d2e976877104e"
50+
git-tree-sha1 = "47f13b6305ab195edb73c86815962d84e31b0f48"
5151
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
52-
version = "0.27.2"
52+
version = "0.27.3"
5353

5454
[[Downloads]]
5555
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
5656
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
5757

5858
[[FiniteDifferences]]
5959
deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"]
60-
git-tree-sha1 = "bdc9fb1d27a1ccecd2fe8f39c6211524cbe642cb"
60+
git-tree-sha1 = "12417e4754486a547d98d65293dc0fafdfcc0736"
6161
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
62-
version = "0.12.13"
62+
version = "0.12.14"
6363

6464
[[IOCapture]]
6565
deps = ["Logging", "Random"]
@@ -167,9 +167,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
167167

168168
[[StaticArrays]]
169169
deps = ["LinearAlgebra", "Random", "Statistics"]
170-
git-tree-sha1 = "745914ebcd610da69f3cb6bf76cb7bb83dcb8c9a"
170+
git-tree-sha1 = "a43a7b58a6e7dc933b2fa2e0ca653ccf8bb8fd0e"
171171
uuid = "90137ffa-7385-5640-81b9-e52037218182"
172-
version = "1.2.4"
172+
version = "1.2.6"
173173

174174
[[Statistics]]
175175
deps = ["LinearAlgebra", "SparseArrays"]

test/check_result.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ end
1717
check(11.0, ZeroTangent())
1818
check([10.0, 20.0], @thunk([2.0, 0.0]))
1919

20-
check(12.0, InplaceableThunk(@thunk(2.0), -> error("Should not have in-placed")))
20+
check(12.0, InplaceableThunk(X̄ -> error("Should not have in-placed"), @thunk(2.0)))
2121

22-
check([10.0, 20.0], InplaceableThunk(@thunk([2.0, 0.0]), -> (X̄[1] += 2.0; X̄)))
22+
check([10.0, 20.0], InplaceableThunk(X̄ -> (X̄[1] += 2.0; X̄), @thunk([2.0, 0.0])))
2323
@test fails() do
24-
check([10.0, 20.0], InplaceableThunk(@thunk([2.0, 0.0]), -> (X̄[1] += 3.0; X̄)))
24+
check([10.0, 20.0], InplaceableThunk(X̄ -> (X̄[1] += 3.0; X̄), @thunk([2.0, 0.0])))
2525
end
2626
end
2727

test/testers.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,15 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
9696
@testset "Correct definitions" begin
9797
local inplace_used
9898
function ChainRulesCore.frule((_, ẋ), ::typeof(identity), x::Array)
99-
= InplaceableThunk(@thunk(ẋ), ȧ -> (inplace_used = true; ȧ .+= ẋ))
99+
= InplaceableThunk-> (inplace_used = true; ȧ .+= ẋ), @thunk(ẋ))
100100
return identity(x), ẏ
101101
end
102102
function ChainRulesCore.rrule(::typeof(identity), x::Array)
103103
function identity_pullback(ȳ)
104-
x̄_ret = InplaceableThunk(
105-
@thunk(ȳ), ā -> (inplace_used = true; ā .+= ȳ)
106-
)
104+
x̄_ret = InplaceableThunk(@thunk(ȳ)) do ā
105+
inplace_used = true
106+
ā .+= ȳ
107+
end
107108
return (NoTangent(), x̄_ret)
108109
end
109110
return identity(x), identity_pullback
@@ -122,14 +123,14 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
122123
my_identity(value) = value # we will define bad rules on this
123124
function ChainRulesCore.frule((_, ẋ), ::typeof(my_identity), x::Array)
124125
# only the in-place part is incorrect
125-
= InplaceableThunk(@thunk(ẋ), ȧ -> ȧ .+= 200 .* ẋ)
126+
= InplaceableThunk-> ȧ .+= 200 .*, @thunk(ẋ))
126127
return my_identity(x), ẏ
127128
end
128129
function ChainRulesCore.rrule(::typeof(my_identity), x::Array)
129130
x_dims = size(x)
130131
function my_identity_pullback(ȳ)
131132
# only the in-place part is incorrect
132-
x̄_ret = InplaceableThunk(@thunk(ȳ), ā -> ā .+= 200 .* ȳ)
133+
x̄_ret = InplaceableThunk-> ā .+= 200 .* ȳ, @thunk(ȳ))
133134
return (NoTangent(), x̄_ret)
134135
end
135136
return my_identity(x), my_identity_pullback

0 commit comments

Comments
 (0)