Skip to content

Commit dffcfff

Browse files
authored
Merge pull request #177 from JuliaDiff/ox/typo
fix mistake in non-tuple output message
2 parents e21bf34 + 68fc041 commit dffcfff

File tree

8 files changed

+72
-25
lines changed

8 files changed

+72
-25
lines changed

.github/workflows/fix_doctests.yml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
name: fix_doctests
2+
on:
3+
pull_request:
4+
jobs:
5+
doctests:
6+
name: Fix doctests (Julia ${{ matrix.julia-version }} - ${{ github.event_name }})
7+
runs-on: ubuntu-latest
8+
strategy:
9+
matrix:
10+
julia-version: [1.6]
11+
steps:
12+
- uses: julia-actions/setup-julia@latest
13+
with:
14+
version: ${{ matrix.julia-version }}
15+
- uses: actions/checkout@v1
16+
- name: Fix doctests
17+
shell: julia --project=docs/ {0}
18+
run: |
19+
using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()
20+
using Documenter
21+
using ChainRulesTestUtils
22+
doctest(ChainRulesTestUtils, fix=true)
23+
# don't push changes to Manifest in suggestions, as it removes `path=..`
24+
run(`git restore docs/Manifest.toml`)
25+
- uses: reviewdog/action-suggester@v1
26+
if: github.event_name == 'pull_request'
27+
with:
28+
tool_name: Documenter (fix doctests)
29+
fail_on_error: true

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.7.11"
3+
version = "0.7.12"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/Manifest.toml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1111

1212
[[ChainRulesCore]]
1313
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
14-
git-tree-sha1 = "d659e42240c2162300b321f05173cab5cc40a5ba"
14+
git-tree-sha1 = "dbc9aae1227cfddaa9d2552f3ecba5b641f6cce9"
1515
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16-
version = "0.10.4"
16+
version = "0.10.5"
1717

1818
[[ChainRulesTestUtils]]
1919
deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"]
2020
path = ".."
2121
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
22-
version = "0.7.9"
22+
version = "0.7.12"
2323

2424
[[Compat]]
2525
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
@@ -47,9 +47,9 @@ version = "0.8.5"
4747

4848
[[Documenter]]
4949
deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
50-
git-tree-sha1 = "3ebb967819b284dc1e3c0422229b58a40a255649"
50+
git-tree-sha1 = "5acbebf1be22db43589bc5aa1bb5fcc378b17780"
5151
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
52-
version = "0.26.3"
52+
version = "0.27.0"
5353

5454
[[Downloads]]
5555
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
@@ -62,10 +62,10 @@ uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
6262
version = "0.12.13"
6363

6464
[[IOCapture]]
65-
deps = ["Logging"]
66-
git-tree-sha1 = "377252859f740c217b936cebcd918a44f9b53b59"
65+
deps = ["Logging", "Random"]
66+
git-tree-sha1 = "f7be53659ab06ddc986428d3a9dcc95f6fa6705a"
6767
uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
68-
version = "0.1.1"
68+
version = "0.2.2"
6969

7070
[[InteractiveUtils]]
7171
deps = ["Markdown"]

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
44
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
55

66
[compat]
7-
Documenter = "0.26"
7+
Documenter = "0.27"
88
julia = "1"

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ makedocs(;
1212
],
1313
strict=true,
1414
checkdocs=:exports,
15-
)
15+
)
1616

1717
const repo = "github.com/JuliaDiff/ChainRulesTestUtils.jl.git"
1818
deploydocs(; repo=repo, push_preview=true)

docs/src/index.md

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ For information about ChainRules, including how to write rules, refer to the gen
1212
## Canonical example
1313

1414
Let's suppose a custom transformation has been defined
15-
```jldoctest ex; output = false
15+
```jldoctest ex
1616
function two2three(x1::Float64, x2::Float64)
1717
return 1.0, 2.0*x1, 3.0*x2
1818
end
@@ -21,7 +21,7 @@ end
2121
two2three (generic function with 1 method)
2222
```
2323
along with the `frule`
24-
```jldoctest ex; output = false
24+
```jldoctest ex
2525
using ChainRulesCore
2626
2727
function ChainRulesCore.frule((Δf, Δx1, Δx2), ::typeof(two2three), x1, x2)
@@ -33,7 +33,7 @@ end
3333
3434
```
3535
and `rrule`
36-
```jldoctest ex; output = false
36+
```jldoctest ex
3737
function ChainRulesCore.rrule(::typeof(two2three), x1, x2)
3838
y = two2three(x1, x2)
3939
function two2three_pullback(Ȳ)
@@ -55,29 +55,31 @@ They can be used for any type and number of inputs and outputs.
5555
The call will test the `frule` for function `f` at the point `x` in the domain.
5656
Keep this in mind when testing discontinuous rules for functions like [ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)), which should ideally be tested at both `x` being above and below zero.
5757

58-
```jldoctest ex; output = false
58+
```jldoctest ex
5959
julia> using ChainRulesTestUtils;
6060
6161
julia> test_frule(two2three, 3.33, -7.77);
6262
Test Summary: | Pass Total
6363
test_frule: two2three on Float64,Float64 | 6 6
64+
6465
```
6566

6667
### Testing the `rrule`
6768

6869
[`test_rrule`](@ref) takes in the function `f`, and primal inputsr `x`.
6970
The call will test the `rrule` for function `f` at the point `x`, and similarly to `frule` some rules should be tested at multiple points in the domain.
7071

71-
```jldoctest ex; output = false
72+
```jldoctest ex
7273
julia> test_rrule(two2three, 3.33, -7.77);
7374
Test Summary: | Pass Total
74-
test_rrule: two2three on Float64,Float64 | 7 7
75+
test_rrule: two2three on Float64,Float64 | 8 8
76+
7577
```
7678

7779
## Scalar example
7880

7981
For functions with a single argument and a single output, such as e.g. ReLU,
80-
```jldoctest ex; output = false
82+
```jldoctest ex
8183
function relu(x::Real)
8284
return max(0, x)
8385
end
@@ -86,7 +88,7 @@ end
8688
relu (generic function with 1 method)
8789
```
8890
with the `frule` and `rrule` defined with the help of `@scalar_rule` macro
89-
```jldoctest ex; output = false
91+
```jldoctest ex
9092
@scalar_rule relu(x::Real) x <= 0 ? zero(x) : one(x)
9193
9294
# output
@@ -95,14 +97,16 @@ with the `frule` and `rrule` defined with the help of `@scalar_rule` macro
9597

9698
`test_scalar` function is provided to test both the `frule` and the `rrule` with a single
9799
call.
98-
```jldoctest ex; output = false
100+
```jldoctest ex
99101
julia> test_scalar(relu, 0.5);
100102
Test Summary: | Pass Total
101-
test_scalar: relu at 0.5 | 9 9
103+
test_scalar: relu at 0.5 | 10 10
104+
102105
103106
julia> test_scalar(relu, -0.5);
104107
Test Summary: | Pass Total
105-
test_scalar: relu at -0.5 | 9 9
108+
test_scalar: relu at -0.5 | 10 10
109+
106110
```
107111

108112
## Testing constructors and functors (callable objects)

src/testers.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,9 +244,14 @@ function test_rrule(
244244

245245
check_inferred && _test_inferred(pullback, ȳ)
246246
ad_cotangents = pullback(ȳ)
247-
ad_cotangents isa Tuple || error("The pullback must return (∂self, ∂args...), not $∂s.")
248-
msg = "The pullback should return 1 cotangent for the primal and each primal input."
249-
@test_msg msg length(ad_cotangents) == 1 + length(args)
247+
@test_msg(
248+
"The pullback must return a Tuple (∂self, ∂args...)",
249+
ad_cotangents isa Tuple
250+
)
251+
@test_msg(
252+
"The pullback should return 1 cotangent for the primal and each primal input.",
253+
length(ad_cotangents) == length(primals)
254+
)
250255

251256
# Correctness testing via finite differencing.
252257
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113

test/testers.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,15 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
548548
@test fails(() -> test_frule(foo, 2.1, 2.1))
549549
@test fails(() -> test_rrule(foo, 21.0, 32.0))
550550
end
551+
552+
@testset "rrule not returning a tuple" begin
553+
bar(x, y) = x + 3y
554+
function ChainRulesCore.rrule(::typeof(bar), x, y)
555+
bar_pullback(dy) = dy
556+
return bar(x,y), bar_pullback
557+
end
558+
@test fails(() -> test_rrule(bar, 21.0, 32.0))
559+
end
551560
end
552561

553562
@testset "structs" begin

0 commit comments

Comments
 (0)