Skip to content

Commit 488710b

Browse files
committed
fix tests
1 parent 509a353 commit 488710b

File tree

6 files changed

+18
-12
lines changed

6 files changed

+18
-12
lines changed

src/rulesets/Base/broadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ end
240240
rrule(::RCR, ::typeof(broadcasted), ::typeof(+), args::Number...) = rrule(+, args...) |> _prepend_zero
241241
rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::Number, y::Number) = rrule(-, x, y) |> _prepend_zero
242242
rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::Number) = rrule(-, x) |> _prepend_zero
243-
rrule(::RCR, ::typeof(broadcasted), ::typeof(*), x::Number, y::Number) = rrule(*, x, y) |> _prepend_zero
243+
rrule(::RCR, ::typeof(broadcasted), ::typeof(*), args::Number...) = rrule(*, args...) |> _prepend_zero
244244
rrule(::RCR, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{2}) =
245245
rrule(Base.literal_pow, ^, x, Val(2)) |> _prepend_zero
246246
rrule(::RCR, ::typeof(broadcasted), ::typeof(/), x::Number, y::Number) = rrule(/, x, y) |> _prepend_zero

src/tuplecast.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
tuplecast(f, args...)
44
55
For a function `f` which returns a tuple, this is `== unzip(broadcast(f, args...))`,
6-
but performed using `StructArrays` for efficiency.
6+
but performed using `StructArrays` for efficiency. Used in the gradient of broadcasting.
77
88
# Examples
99
```
@@ -52,7 +52,8 @@ function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(tuplec
5252
return z, untuplecast
5353
end
5454

55-
function rrule(cfg::RCR, ::typeof(collecttuplecast), f, args...) # for testing, but doesn't work?
55+
# This is for testing, but the tests using it don't work.
56+
function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collecttuplecast), f, args...)
5657
y, back = rrule(cfg, tuplecast, f, args...)
5758
return collect(y), back
5859
end
@@ -62,6 +63,8 @@ end
6263
6364
For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`,
6465
but performed using `StructArrays` for efficiency.
66+
67+
Not in use at present, but see `tuplecast`.
6568
"""
6669
function tuplemap(f::F, args...) where {F}
6770
T = Broadcast.combine_eltypes(f, args)

test/rulesets/Base/broadcast.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ using Base.Broadcast: broadcasted
6060
test_rrule(copybroadcasted, /, (rand(2),), rand(3), check_inferred=false)
6161
end
6262

63-
@testset "lazy rules" begin
63+
@testset "fused rules" begin
6464
@testset "arithmetic" begin
6565
test_rrule(copybroadcasted, +, rand(3), rand(3))
6666
test_rrule(copybroadcasted, +, rand(3), rand(4)')
@@ -88,8 +88,9 @@ using Base.Broadcast: broadcasted
8888
@test y4 == [im, 2im, 3im]
8989
@test unthunk(bk4([4, 5im, 6+7im])[4]) == [0,5,7]
9090

91-
test_rrule(copybroadcasted, *, rand(3), rand(3), rand(3), rand(3), rand(3))
92-
test_rrule(copybroadcasted, *, rand(), rand(), rand(3), rand(3) .+ im, rand(4)')
91+
test_rrule(copybroadcasted, *, rand(3), rand(3), rand(3), rand(3), rand(3), check_inferred=false) # Union{NoTangent, ZeroTangent}
92+
test_rrule(copybroadcasted, *, rand(), rand(), rand(3), rand(3) .+ im, rand(4)', check_inferred=false) # Union{NoTangent, ZeroTangent}
93+
# (These two may infer with vararg rrule)
9394

9495
test_rrule(copybroadcasted, Base.literal_pow, ^, rand(3), Val(2))
9596
test_rrule(copybroadcasted, Base.literal_pow, ^, rand(3) .+ im, Val(2))
@@ -134,6 +135,7 @@ using Base.Broadcast: broadcasted
134135
test_rrule(copybroadcasted, -, rand(), rand())
135136
test_rrule(copybroadcasted, -, rand())
136137
test_rrule(copybroadcasted, *, rand(), rand())
138+
test_rrule(copybroadcasted, *, rand(), rand(), rand(), rand())
137139
test_rrule(copybroadcasted, Base.literal_pow, ^, rand(), Val(2))
138140
test_rrule(copybroadcasted, /, rand(), rand())
139141
end

test/rulesets/Base/mapreduce.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
Base.sum(xs::AbstractArray, weights::AbstractArray) = dot(xs, weights)
33
struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end
44

5-
const CFG = ChainRulesTestUtils.ADviaRuleConfig()
6-
75
@testset "Reductions" begin
86
@testset "sum(::Tuple)" begin
97
test_frule(sum, Tuple(rand(5)))

test/runtests.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ else
4343
end
4444

4545
@testset "ChainRules" begin # One overall @testset ensures it keeps going after failures
46-
include("test_helpers.jl")
47-
include("tuplecast.jl")
46+
include("test_helpers.jl") # This can't be skipped
4847
println()
4948

5049
test_method_tables() # Check the global method tables are consistent
@@ -60,6 +59,8 @@ end
6059
include_test("rulesets/Base/sort.jl")
6160
include_test("rulesets/Base/broadcast.jl")
6261

62+
include_test("tuplecast.jl") # used primarily for broadcast
63+
6364
println()
6465

6566
include_test("rulesets/Statistics/statistics.jl")

test/test_helpers.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
const CFG = ChainRulesTestUtils.TestConfig()
2+
13
"""
24
Multiplier(x)
35
@@ -80,8 +82,8 @@ fstar(A, B) = A * B
8082
ChainRulesCore.frule((_, ΔA, ΔB), ::typeof(fstar), A, B) = A * B, muladd(ΔA, B, A * ΔB)
8183

8284
"A version of `log` with only an `frule` defined"
83-
flog(x:::Number) = log(x)
84-
ChainRulesCore.frule((_, xdot), ::typeof(flog), x::Number) = log(x), inv(x) * xdot
85+
flog(x::Number) = log(x)
86+
ChainRulesCore.frule((_, Δx), ::typeof(flog), x::Number) = log(x), inv(x) * Δx
8587

8688
@testset "test_helpers.jl" begin
8789

0 commit comments

Comments
 (0)