Skip to content

Commit 0095fb2

Browse files
authored
Merge pull request #154 from JuliaDiff/ox/noprimalvalue
Print primal type, not primal value
2 parents 72dad7f + fc8d524 commit 0095fb2

File tree

6 files changed

+222
-23
lines changed

6 files changed

+222
-23
lines changed

.github/workflows/Documenter.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
- uses: actions/checkout@v2
1414
- uses: julia-actions/setup-julia@v1
1515
with:
16-
version: '1.5'
16+
version: '1.6'
1717
- uses: julia-actions/julia-docdeploy@latest
1818
env:
1919
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ docs/site/
2323
# It records a fixed state of all packages used by the project. As such, it should not be
2424
# committed for packages, but should be committed for applications that require a static
2525
# environment.
26-
Manifest.toml
26+
/Manifest.toml
2727

2828
# JetBrains meta files
2929
.idea/*

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.6.12"
3+
version = "0.6.13"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/Manifest.toml

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
# This file is machine-generated - editing it directly is not advised
2+
3+
[[ArgTools]]
4+
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
5+
6+
[[Artifacts]]
7+
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
8+
9+
[[Base64]]
10+
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
11+
12+
[[ChainRulesCore]]
13+
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
14+
git-tree-sha1 = "b391f22252b8754f4440de1f37ece49d8a7314bb"
15+
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16+
version = "0.9.44"
17+
18+
[[ChainRulesTestUtils]]
19+
deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"]
20+
path = ".."
21+
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
22+
version = "0.6.13"
23+
24+
[[Compat]]
25+
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
26+
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
27+
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
28+
version = "3.30.0"
29+
30+
[[Dates]]
31+
deps = ["Printf"]
32+
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
33+
34+
[[DelimitedFiles]]
35+
deps = ["Mmap"]
36+
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
37+
38+
[[Distributed]]
39+
deps = ["Random", "Serialization", "Sockets"]
40+
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
41+
42+
[[DocStringExtensions]]
43+
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
44+
git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32"
45+
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
46+
version = "0.8.4"
47+
48+
[[Documenter]]
49+
deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
50+
git-tree-sha1 = "3ebb967819b284dc1e3c0422229b58a40a255649"
51+
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
52+
version = "0.26.3"
53+
54+
[[Downloads]]
55+
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
56+
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
57+
58+
[[FiniteDifferences]]
59+
deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"]
60+
git-tree-sha1 = "8662836e29702fdfdb1b90cbe4162e31b94f1e51"
61+
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
62+
version = "0.12.7"
63+
64+
[[IOCapture]]
65+
deps = ["Logging"]
66+
git-tree-sha1 = "377252859f740c217b936cebcd918a44f9b53b59"
67+
uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
68+
version = "0.1.1"
69+
70+
[[InteractiveUtils]]
71+
deps = ["Markdown"]
72+
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
73+
74+
[[JSON]]
75+
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
76+
git-tree-sha1 = "81690084b6198a2e1da36fcfda16eeca9f9f24e4"
77+
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
78+
version = "0.21.1"
79+
80+
[[LibCURL]]
81+
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
82+
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
83+
84+
[[LibCURL_jll]]
85+
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
86+
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
87+
88+
[[LibGit2]]
89+
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
90+
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
91+
92+
[[LibSSH2_jll]]
93+
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
94+
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
95+
96+
[[Libdl]]
97+
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
98+
99+
[[LinearAlgebra]]
100+
deps = ["Libdl"]
101+
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
102+
103+
[[Logging]]
104+
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
105+
106+
[[Markdown]]
107+
deps = ["Base64"]
108+
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
109+
110+
[[MbedTLS_jll]]
111+
deps = ["Artifacts", "Libdl"]
112+
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
113+
114+
[[Mmap]]
115+
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
116+
117+
[[MozillaCACerts_jll]]
118+
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
119+
120+
[[NetworkOptions]]
121+
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
122+
123+
[[Parsers]]
124+
deps = ["Dates"]
125+
git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc"
126+
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
127+
version = "1.1.0"
128+
129+
[[Pkg]]
130+
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
131+
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
132+
133+
[[Printf]]
134+
deps = ["Unicode"]
135+
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
136+
137+
[[REPL]]
138+
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
139+
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
140+
141+
[[Random]]
142+
deps = ["Serialization"]
143+
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
144+
145+
[[Richardson]]
146+
deps = ["LinearAlgebra"]
147+
git-tree-sha1 = "e03ca566bec93f8a3aeb059c8ef102f268a38949"
148+
uuid = "708f8203-808e-40c0-ba2d-98a6953ed40d"
149+
version = "1.4.0"
150+
151+
[[SHA]]
152+
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
153+
154+
[[Serialization]]
155+
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
156+
157+
[[SharedArrays]]
158+
deps = ["Distributed", "Mmap", "Random", "Serialization"]
159+
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
160+
161+
[[Sockets]]
162+
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
163+
164+
[[SparseArrays]]
165+
deps = ["LinearAlgebra", "Random"]
166+
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
167+
168+
[[StaticArrays]]
169+
deps = ["LinearAlgebra", "Random", "Statistics"]
170+
git-tree-sha1 = "a1f226ebe197578c25fcf948bfff3d0d12f2ff20"
171+
uuid = "90137ffa-7385-5640-81b9-e52037218182"
172+
version = "1.2.1"
173+
174+
[[Statistics]]
175+
deps = ["LinearAlgebra", "SparseArrays"]
176+
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
177+
178+
[[TOML]]
179+
deps = ["Dates"]
180+
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
181+
182+
[[Tar]]
183+
deps = ["ArgTools", "SHA"]
184+
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
185+
186+
[[Test]]
187+
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
188+
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
189+
190+
[[UUIDs]]
191+
deps = ["Random", "SHA"]
192+
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
193+
194+
[[Unicode]]
195+
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
196+
197+
[[Zlib_jll]]
198+
deps = ["Libdl"]
199+
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
200+
201+
[[nghttp2_jll]]
202+
deps = ["Artifacts", "Libdl"]
203+
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
204+
205+
[[p7zip_jll]]
206+
deps = ["Artifacts", "Libdl"]
207+
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"

docs/src/index.md

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,11 @@ The call will test the `frule` for function `f` at the point `x` in the domain.
5656
Keep this in mind when testing discontinuous rules for functions like [ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)), which should ideally be tested at both `x` being above and below zero.
5757

5858
```jldoctest ex; output = false
59-
using ChainRulesTestUtils
59+
julia> using ChainRulesTestUtils;
6060
61-
test_frule(two2three, 3.33, -7.77);
62-
63-
# output
64-
Test Summary: | Pass Total
65-
test_frule: two2three at (3.33, -7.77) | 5 5
66-
Test.DefaultTestSet("test_frule: two2three at (3.33, -7.77)", Any[Test.DefaultTestSet("Tuple{Float64,Float64,Float64}.1", Any[], 1, false), Test.DefaultTestSet("Tuple{Float64,Float64,Float64}.2", Any[], 1, false), Test.DefaultTestSet("Tuple{Float64,Float64,Float64}.3", Any[], 1, false)], 2, false)
61+
julia> test_frule(two2three, 3.33, -7.77);
62+
Test Summary: | Pass Total
63+
test_frule: two2three on Float64,Float64 | 5 5
6764
```
6865

6966
### Testing the `rrule`
@@ -72,12 +69,9 @@ Test.DefaultTestSet("test_frule: two2three at (3.33, -7.77)", Any[Test.DefaultTe
7269
The call will test the `rrule` for function `f` at the point `x`, and similarly to `frule` some rules should be tested at multiple points in the domain.
7370

7471
```jldoctest ex; output = false
75-
test_rrule(two2three, 3.33, -7.77);
76-
77-
# output
78-
Test Summary: | Pass Total
79-
test_rrule: two2three at (3.33, -7.77) | 6 6
80-
Test.DefaultTestSet("test_rrule: two2three at (3.33, -7.77)", Any[Test.DefaultTestSet("Don't thunk only non_zero argument", Any[], 0, false)], 6, false)
72+
julia> test_rrule(two2three, 3.33, -7.77);
73+
Test Summary: | Pass Total
74+
test_rrule: two2three on Float64,Float64 | 6 6
8175
```
8276

8377
## Scalar example
@@ -102,15 +96,13 @@ with the `frule` and `rrule` defined with the help of `@scalar_rule` macro
10296
`test_scalar` function is provided to test both the `frule` and the `rrule` with a single
10397
call.
10498
```jldoctest ex; output = false
105-
test_scalar(relu, 0.5);
106-
test_scalar(relu, -0.5);
107-
108-
# output
99+
julia> test_scalar(relu, 0.5);
109100
Test Summary: | Pass Total
110101
test_scalar: relu at 0.5 | 7 7
102+
103+
julia> test_scalar(relu, -0.5);
111104
Test Summary: | Pass Total
112105
test_scalar: relu at -0.5 | 7 7
113-
Test.DefaultTestSet("test_scalar: relu at -0.5", Any[Test.DefaultTestSet("with tangent 1.0", Any[Test.DefaultTestSet("test_frule: relu at (ChainRulesTestUtils.PrimalAndTangent{Float64,Float64}(-0.5, 1.0),)", Any[], 3, false)], 0, false), Test.DefaultTestSet("with cotangent 1.0", Any[Test.DefaultTestSet("test_rrule: relu at (ChainRulesTestUtils.PrimalAndTangent{Float64,Float64}(-0.5, 1.0),)", Any[Test.DefaultTestSet("Don't thunk only non_zero argument", Any[], 0, false)], 4, false)], 0, false)], 0, false)
114106
```
115107

116108
## Specifying Tangents

src/testers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ function test_frule(
9797
# To simplify some of the calls we make later lets group the kwargs for reuse
9898
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)
9999

100-
@testset "test_frule: $f at $inputs" begin
100+
@testset "test_frule: $f on $(join(typeof.(inputs), ","))" begin
101101
_ensure_not_running_on_functor(f, "test_frule")
102102

103103
xẋs = auto_primal_and_tangent.(inputs)
@@ -164,7 +164,7 @@ function test_rrule(
164164
# To simplify some of the calls we make later lets group the kwargs for reuse
165165
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)
166166

167-
@testset "test_rrule: $f at $inputs" begin
167+
@testset "test_rrule: $f on $(join(typeof.(inputs), ","))" begin
168168
_ensure_not_running_on_functor(f, "test_rrule")
169169

170170
# Check correctness of evaluation.

0 commit comments

Comments
 (0)