Skip to content

Commit 2c0ed36

Browse files
committed
Merge branch 'main' of https://github.com/JuliaDiff/ChainRules.jl into sparsedet
2 parents 3577abd + df672c3 commit 2c0ed36

File tree

12 files changed

+107
-91
lines changed

12 files changed

+107
-91
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
- x86
2222
- x64
2323
steps:
24-
- uses: actions/checkout@v3.5.2
24+
- uses: actions/checkout@v3.5.3
2525
- uses: julia-actions/setup-julia@v1
2626
with:
2727
version: ${{ matrix.version }}

.github/workflows/IntegrationTest.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ jobs:
2525
# package: {user: JuliaDiff, repo: Diffractor.jl}
2626

2727
steps:
28-
- uses: actions/checkout@v3.5.2
28+
- uses: actions/checkout@v3.5.3
2929
- uses: julia-actions/setup-julia@v1
3030
with:
3131
version: ${{ matrix.julia-version }}
3232
arch: x64
3333
- uses: julia-actions/julia-buildpkg@latest
3434
- name: Clone Downstream
35-
uses: actions/checkout@v3.5.2
35+
uses: actions/checkout@v3.5.3
3636
with:
3737
repository: ${{ matrix.package.user }}/${{ matrix.package.repo }}
3838
path: downstream

.github/workflows/JuliaNightly.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
- x86
2424
- x64
2525
steps:
26-
- uses: actions/checkout@v3.5.2
26+
- uses: actions/checkout@v3.5.3
2727
- uses: julia-actions/setup-julia@v1
2828
with:
2929
version: ${{ matrix.version }}

.github/workflows/VersionVigilante_pull_request.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ jobs:
66
VersionVigilante:
77
runs-on: ubuntu-latest
88
steps:
9-
- uses: actions/checkout@v3.5.2
9+
- uses: actions/checkout@v3.5.3
1010
- uses: julia-actions/setup-julia@latest
1111
- name: VersionVigilante.main
1212
id: versionvigilante_main

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.49.0"
3+
version = "1.53.0"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/rulesets/Base/arraymath.jl

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,13 @@ end # VERSION
210210
##### `muladd`
211211
#####
212212

213-
function frule((_, ΔA, ΔB, Δz), ::typeof(muladd), A, B, z)
213+
function frule(
214+
(_, ΔA, ΔB, Δz),
215+
::typeof(muladd),
216+
A::AbstractVecOrMat{<:CommutativeMulNumber},
217+
B::AbstractVecOrMat{<:CommutativeMulNumber},
218+
z::Union{CommutativeMulNumber, AbstractVecOrMat{<:CommutativeMulNumber}}
219+
)
214220
Ω = muladd(A, B, z)
215221
return Ω, ΔA * B .+ A * ΔB .+ Δz
216222
end
@@ -342,20 +348,43 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
342348
project_B = ProjectTo(B)
343349

344350
Y = A \ B
351+
345352
function backslash_pullback(ȳ)
346353
= unthunk(ȳ)
354+
355+
Ȳf =
356+
@static if VERSION >= v"1.9"
357+
# Need to ensure Ȳ is an array since since https://github.com/JuliaLang/julia/pull/44358
358+
if !isa(Ȳ, AbstractArray)
359+
Ȳf = [Ȳ]
360+
end
361+
end
362+
Yf = Y
363+
@static if VERSION >= v"1.9"
364+
# Need to ensure Yf is an array since since https://github.com/JuliaLang/julia/pull/44358
365+
if !isa(Y, AbstractArray)
366+
Yf = [Y]
367+
end
368+
end
369+
#@info "vars" typeof(Ȳ) typeof(Y) typeof(Yf) typeof(A) typeof(B)
347370
∂A = @thunk begin
348-
= A' \
371+
= A' \ Ȳf
349372
= -* Y'
350-
= add!!(Ā, (B - A * Y) *' / A')
351-
= add!!(Ā, A' \ Y * (Ȳ' -'A))
373+
t = (B - A * Y) *'
374+
@static if VERSION >= v"1.9"
375+
# Need to ensure t is an array since since https://github.com/JuliaLang/julia/pull/44358
376+
if !isa(t, AbstractArray)
377+
t = [t]
378+
end
379+
end
380+
= add!!(Ā, t / A')
381+
= add!!(Ā, A' \ Yf * (Ȳ' -'A))
352382
project_A(Ā)
353383
end
354-
∂B = @thunk project_B(A' \ )
384+
∂B = @thunk project_B(A' \ Ȳf)
355385
return NoTangent(), ∂A, ∂B
356386
end
357387
return Y, backslash_pullback
358-
359388
end
360389

361390
#####

src/rulesets/Base/base.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
@scalar_rule copysign(y, x) (ifelse(signbit(x)!=signbit(y), -one(y), +one(y)), NoTangent())
55

6-
@scalar_rule one(x) zero(x)
7-
@scalar_rule zero(x) zero(x)
6+
@scalar_rule one(x) ZeroTangent()
7+
@scalar_rule zero(x) ZeroTangent()
88
@scalar_rule transpose(x) true
99

1010
# `adjoint`

src/rulesets/Base/fastmath_able.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ let
169169
@scalar_rule x / y (one(x) / y, -/ y))
170170

171171
## many-arg +
172-
function frule((_, Δx, Δy...), ::typeof(+), x::Number, ys::Number...)
173-
+(x, ys...), +(Δx, Δy...)
172+
function frule(Δs, ::typeof(+), x::Number, ys::Number...)
173+
+(x, ys...), +(Base.tail(Δs)...)
174174
end
175175

176176
function rrule(::typeof(+), x::Number, ys::Number...)

src/rulesets/Base/indexing.jl

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,6 @@
1-
#####
2-
##### getindex(::Tuple)
3-
#####
4-
5-
function frule((_, ẋ), ::typeof(getindex), x::Tuple, i::Integer)
6-
return x[i], ẋ[i]
7-
end
8-
9-
function frule((_, ẋ), ::typeof(getindex), x::Tuple, i)
10-
y = x[i]
11-
return y, Tangent{typeof(y)}(ẋ[i]...)
1+
# Int rather than Int64/Integer is intentional
2+
function frule((_, ẋ), ::typeof(getfield), x::Tuple, i::Int)
3+
return x.i, ẋ.i
124
end
135

146
"for a given tuple type, returns a Val{N} where N is the length of the tuple"
@@ -77,20 +69,52 @@ end
7769
"""
7870
∇getindex(x, dy, inds...)
7971
80-
For the `rrule` of `y = x[inds...]`, this function is roughly
72+
For the `rrule` of `y = x[inds...]`, this function is roughly
8173
`setindex(zero(x), dy, inds...)`, returning the array `dx`.
8274
Differentiable. Includes `ProjectTo(x)(dx)`.
8375
"""
84-
function ∇getindex(x::AbstractArray, dy, inds...)
76+
function ∇getindex(x::AbstractArray{T,N}, dy, inds...) where {T,N}
8577
# `to_indices` removes any logical indexing, colons, CartesianIndex etc,
8678
# leaving just Int / AbstractVector of Int
8779
plain_inds = Base.to_indices(x, inds)
88-
dx = _setindex_zero(x, dy, plain_inds...)
89-
∇getindex!(dx, dy, plain_inds...)
80+
dx = if plain_inds isa NTuple{N, Int} && T<:Number
81+
# scalar indexing
82+
OneElement(dy, plain_inds, axes(x))
83+
else # some from slicing (potentially noncontigous)
84+
dx = _setindex_zero(x, dy, plain_inds...)
85+
∇getindex!(dx, dy, plain_inds...)
86+
end
9087
return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules
9188
end
9289
∇getindex(x::AbstractArray, z::AbstractZero, inds...) = z
9390

91+
"""
92+
OneElement(val, ind, axes) <: AbstractArray
93+
94+
Extremely simple `struct` used for the gradient of scalar `getindex`.
95+
"""
96+
struct OneElement{T,N,I,A} <: AbstractArray{T,N}
97+
val::T
98+
ind::I
99+
axes::A
100+
OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes)
101+
end
102+
Base.size(A::OneElement) = map(length, A.axes)
103+
Base.axes(A::OneElement) = A.axes
104+
Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T))
105+
106+
function ChainRulesCore.add!!(xs::AbstractArray{<:Any,N}, oe::OneElement{<:Any,N}) where {N}
107+
if !ChainRulesCore.is_inplaceable_destination(xs)
108+
xs = collect(xs)
109+
end
110+
xs[oe.ind...] += oe.val
111+
return xs
112+
end
113+
114+
Base.:(+)(xs::AbstractArray, oe::OneElement) = add!!(copy(xs), oe)
115+
Base.:(+)(oe::OneElement, xs::AbstractArray) = +(xs, oe)
116+
Base.:(+)(oe1::OneElement, oe2::OneElement) = +(collect(oe1), oe2)
117+
94118
"""
95119
_setindex_zero(x, dy, inds...)
96120
@@ -159,29 +183,6 @@ function ∇getindex!(dx::AbstractGPUArray, dy, inds...)
159183
return dx
160184
end
161185

162-
#####
163-
##### first, tail
164-
#####
165-
166-
function frule((_, ẋ), ::typeof(first), x::Tuple)
167-
return first(x), first(ẋ)
168-
end
169-
170-
function rrule(::typeof(first), x::T) where {T<:Tuple}
171-
first_back(dy) = (NoTangent(), Tangent{T}(ntuple(j -> j == 1 ? dy : NoTangent(), _tuple_N(T))...))
172-
return first(x), first_back
173-
end
174-
175-
function frule((_, ẋ), ::typeof(Base.tail), x::Tuple)
176-
y = Base.tail(x)
177-
return y, Tangent{typeof(y)}(Base.tail(ẋ)...)
178-
end
179-
180-
function rrule(::typeof(Base.tail), x::T) where {T<:Tuple}
181-
tail_pullback(dy) = (NoTangent(), Tangent{T}(NoTangent(), dy...))
182-
return Base.tail(x), tail_pullback
183-
end
184-
185186
#####
186187
##### view
187188
#####

test/rulesets/Base/array.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -358,14 +358,15 @@ end
358358
@test_skip test_frule(findmin, rand(3,4), output_tangent = (rand(), NoTangent()))
359359
@test_skip test_frule(findmin, rand(3,4), fkwargs=(dims=1,))
360360
# These skipped tests might be fixed by https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188
361+
# or by https://github.com/JuliaLang/julia/pull/48404
361362

362363
# Reverse
363364
test_rrule(findmin, rand(10), output_tangent = (rand(), false))
364365
test_rrule(findmax, rand(10), output_tangent = (rand(), false))
365-
test_rrule(findmin, rand(5,3))
366-
test_rrule(findmax, rand(5,3))
367-
@test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, nothing))[2])
368-
@test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, NoTangent()))[2])
366+
test_rrule(findmin, rand(5,3); check_inferred=false)
367+
test_rrule(findmax, rand(5,3); check_inferred=false)
368+
@test [0 0; 0 5] == unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, nothing))[2])
369+
@test [0 0; 0 5] == unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, NoTangent()))[2])
369370

370371
# Reverse with dims:
371372
@test [0 0; 5 6] == @inferred unthunk(rrule(findmax, [1 2; 3 4], dims=1)[2](([5 6], nothing))[2])
@@ -385,7 +386,7 @@ end
385386

386387
# Reverse
387388
test_rrule(imum, rand(10))
388-
test_rrule(imum, rand(3,4))
389+
test_rrule(imum, rand(3,4); check_inferred=false)
389390
@gpu test_rrule(imum, rand(3,4), fkwargs=(dims=1,))
390391
test_rrule(imum, rand(3,4,5), fkwargs=(dims=(1,3),))
391392

0 commit comments

Comments
 (0)