Skip to content

Commit e07a369

Browse files
authored
Merge pull request #395 from JuliaDiff/mz/sort
rrule for sort and partialsort
2 parents 237d5f6 + 238b8b6 commit e07a369

File tree

5 files changed

+52
-1
lines changed

5 files changed

+52
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.7.57"
3+
version = "0.7.58"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/ChainRules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ include("rulesets/Base/evalpoly.jl")
3636
include("rulesets/Base/array.jl")
3737
include("rulesets/Base/arraymath.jl")
3838
include("rulesets/Base/indexing.jl")
39+
include("rulesets/Base/sort.jl")
3940
include("rulesets/Base/mapreduce.jl")
4041

4142
include("rulesets/Statistics/statistics.jl")

src/rulesets/Base/sort.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,OrdinalRange}; kwargs...)
2+
inds = partialsortperm(xs, k; kwargs...)
3+
ys = xs[inds]
4+
5+
function partialsort_pullback(Δys)
6+
function partialsort_add!(Δxs)
7+
Δxs[inds] += Δys
8+
return Δxs
9+
end
10+
11+
Δxs = InplaceableThunk(@thunk(partialsort_add!(zero(xs))), partialsort_add!)
12+
13+
return NO_FIELDS, Δxs, DoesNotExist()
14+
end
15+
16+
return ys, partialsort_pullback
17+
end
18+
19+
function rrule(::typeof(sort), xs::AbstractVector; kwargs...)
20+
inds = sortperm(xs; kwargs...)
21+
ys = xs[inds]
22+
23+
function sort_pullback(Δys)
24+
function sort_add!(Δxs)
25+
Δxs[inds] += Δys
26+
return Δxs
27+
end
28+
29+
Δxs = InplaceableThunk(@thunk(sort_add!(zero(Δys))), sort_add!)
30+
31+
return NO_FIELDS, Δxs
32+
end
33+
return ys, sort_pullback
34+
end

test/rulesets/Base/sort.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
@testset "sort.jl" begin
2+
@testset "sort" begin
3+
a = rand(10)
4+
test_rrule(sort, a)
5+
test_rrule(sort, a; fkwargs=(;rev=true))
6+
end
7+
@testset "partialsort" begin
8+
a = rand(10)
9+
test_rrule(partialsort, a, 4 nothing)
10+
test_rrule(partialsort, a, 3:5 nothing)
11+
test_rrule(partialsort, a, 1:2:6 nothing)
12+
13+
test_rrule(partialsort, a, 4 nothing, fkwargs=(;rev=true))
14+
end
15+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ println("Testing ChainRules.jl")
3232
include_test("rulesets/Base/arraymath.jl")
3333
include_test("rulesets/Base/indexing.jl")
3434
include_test("rulesets/Base/mapreduce.jl")
35+
include_test("rulesets/Base/sort.jl")
3536
end
3637
println()
3738

0 commit comments

Comments
 (0)