Skip to content

Commit e245d50

Browse files
authored
Fix computing of array element type for ∇eachslice (#808)
1 parent 30f9b12 commit e245d50

File tree

4 files changed

+38
-12
lines changed

4 files changed

+38
-12
lines changed

.github/workflows/format.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ concurrency:
1212
jobs:
1313
format:
1414
runs-on: ubuntu-latest
15+
permissions:
16+
contents: read
17+
checks: write
18+
pull-requests: write
1519
steps:
1620
- uses: actions/checkout@v4
1721
- uses: julia-actions/setup-julia@latest
@@ -22,6 +26,7 @@ jobs:
2226
julia -e 'using JuliaFormatter; format("."; verbose=true)'
2327
- uses: reviewdog/action-suggester@v1
2428
with:
29+
github_token: ${{ secrets.GITHUB_TOKEN }}
2530
tool_name: JuliaFormatter
2631
fail_on_error: true
2732
filter_mode: added

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.70.0"
3+
version = "1.71.0"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -20,7 +20,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
2020

2121
[compat]
2222
Adapt = "3.4.0, 4"
23-
ChainRulesCore = "1.20"
23+
ChainRulesCore = "1.25"
2424
ChainRulesTestUtils = "1.5"
2525
Compat = "3.46, 4.2"
2626
Distributed = "1"

src/rulesets/Base/indexing.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ function ∇eachslice(dys_raw, x::AbstractArray, vd::Val{dim}) where {dim}
267267
if i1 === nothing # all slices are Zero!
268268
return _zero_fill!(similar(x, float(eltype(x)), axes(x)))
269269
end
270-
T = promote_type(eltype(dys[i1]), eltype(x))
270+
T = Base.promote_eltype(dys...)
271271
# The whole point of this gradient is that we can allocate one `dx` array:
272272
dx = similar(x, T, axes(x))
273273
for i in axes(x, dim)
@@ -282,8 +282,7 @@ function ∇eachslice(dys_raw, x::AbstractArray, vd::Val{dim}) where {dim}
282282
end
283283
∇eachslice(dys::AbstractZero, x::AbstractArray, vd::Val{dim}) where {dim} = dys
284284

285-
_zero_fill!(dx::AbstractArray{<:Number}) = fill!(dx, zero(eltype(dx)))
286-
_zero_fill!(dx::AbstractArray) = map!(zero, dx, dx)
285+
_zero_fill!(dx::AbstractArray) = fill!(dx, zero(eltype(dx)))
287286

288287
function rrule(::typeof(∇eachslice), dys, x, vd::Val)
289288
function ∇∇eachslice(dz_raw)

test/rulesets/Base/indexing.jl

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -217,16 +217,24 @@ end
217217
# DimensionMismatch("second dimension of A, 6, does not match length of x, 5")
218218
# Probably similar to https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/234 (about Broadcasted not Generator)
219219

220-
test_rrule(collecteachrow, rand(5))
221-
test_rrule(collecteachrow, rand(3, 4))
220+
# Inference on 1.6 sometimes fails, so don't enforce there.
221+
test_rrule(collect eachrow, rand(5); check_inferred=(VERSION >= v"1.7"))
222+
test_rrule(collect eachrow, rand(3, 4); check_inferred=(VERSION >= v"1.7"))
222223

223-
test_rrule(collecteachcol, rand(3, 4))
224-
@test_skip test_rrule(collecteachcol, Diagonal(rand(5))) # works locally!
224+
test_rrule(collect eachcol, rand(3, 4); check_inferred=(VERSION >= v"1.7"))
225+
@test_skip test_rrule(collect eachcol, Diagonal(rand(5))) # works locally!
225226

226227
if VERSION >= v"1.7"
227228
# On 1.6, ComposedFunction doesn't take keywords. Only affects this testing strategy, not real use.
228-
test_rrule(collecteachslice, rand(3, 4, 5); fkwargs = (; dims = 3))
229-
test_rrule(collecteachslice, rand(3, 4, 5); fkwargs = (; dims = (2,)))
229+
test_rrule(collect eachslice, rand(3, 4, 5); fkwargs=(; dims=3))
230+
test_rrule(collect eachslice, rand(3, 4, 5); fkwargs=(; dims=(2,)))
231+
232+
test_rrule(
233+
collect eachslice,
234+
FooTwoField.(rand(3, 4, 5), rand(3, 4, 5));
235+
check_inferred=false,
236+
fkwargs=(; dims=3),
237+
)
230238
end
231239

232240
# Make sure pulling back an array that mixes some AbstractZeros in works right
@@ -235,8 +243,22 @@ end
235243
@test back([1:3, ZeroTangent(), 7:9, NoTangent()])[2] isa Matrix{Float64}
236244
@test back([ZeroTangent(), ZeroTangent(), NoTangent(), NoTangent()]) == (NoTangent(), [0 0 0 0; 0 0 0 0; 0 0 0 0])
237245

246+
_, back = ChainRules.rrule(
247+
eachslice, FooTwoField.(rand(2, 3, 2), rand(2, 3, 2)); dims=3
248+
)
249+
@test back([fill(Tangent{Any}(; x=0.0, y=1.0), 2, 3), fill(ZeroTangent(), 2, 3)]) == (
250+
NoTangent(),
251+
cat(fill(Tangent{Any}(; x=0.0, y=1.0), 2, 3), fill(ZeroTangent(), 2, 3); dims=3),
252+
)
253+
238254
# Second derivative rule
239255
test_rrule(ChainRules.∇eachslice, [rand(4) for _ in 1:3], rand(3, 4), Val(1))
240256
test_rrule(ChainRules.∇eachslice, [rand(3) for _ in 1:4], rand(3, 4), Val(2))
241-
test_rrule(ChainRules.∇eachslice, [rand(2, 3) for _ in 1:4], rand(2, 3, 4), Val(3), check_inferred=false)
257+
test_rrule(
258+
ChainRules.∇eachslice,
259+
[rand(2, 3) for _ in 1:4],
260+
rand(2, 3, 4),
261+
Val(3);
262+
check_inferred=(VERSION >= v"1.7"),
263+
)
242264
end

0 commit comments

Comments
 (0)