Skip to content

Commit ce78d3d

Browse files
mcabbottoxinabox
andauthored
Rules for sortslices, unique (#546)
* unique, take 1 * add shortcut, and tests * sortslices, too * fixup * Apply 3 suggestions Co-authored-by: Lyndon White <oxinabox@ucc.asn.au> * comment Co-authored-by: Lyndon White <oxinabox@ucc.asn.au> Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
1 parent 605354c commit ce78d3d

File tree

2 files changed

+85
-0
lines changed

2 files changed

+85
-0
lines changed

src/rulesets/Base/sort.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
#####
2+
##### `sort`
3+
#####
4+
15
function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,OrdinalRange}; kwargs...)
26
inds = partialsortperm(xs, k; kwargs...)
37
ys = xs[inds]
@@ -33,3 +37,55 @@ function rrule(::typeof(sort), xs::AbstractVector; kwargs...)
3337
end
3438
return ys, sort_pullback
3539
end
40+
41+
#####
42+
##### `sortslices`
43+
#####
44+
45+
function rrule(::typeof(sortslices), x::AbstractArray; dims::Integer, kw...)
46+
p = sortperm(collect(eachslice(x; dims=dims)); kw...)
47+
inds = ntuple(d -> d == dims ? p : (:), ndims(x))
48+
function sortslices_pullback(dy)
49+
# No actual need to zero this, and if you didn't, then you could widen eltype
50+
# Also, you could use similar(dy) here not x, same size?
51+
dx = _zerolike_writeat(x, unthunk(dy), (), inds...)
52+
return (NoTangent(), ProjectTo(x)(dx))
53+
end
54+
return x[inds...], sortslices_pullback
55+
end
56+
57+
#####
58+
##### `unique`
59+
#####
60+
61+
function rrule(::typeof(unique), x::AbstractArray{<:Number}; dims=:)
62+
axes_x = axes(x)
63+
y = unique(x; dims=dims) # accepts only dims=: or dims::Integer
64+
function unique_pullback(dy_raw)
65+
dy = unthunk(dy_raw)
66+
if length(x) == length(y)
67+
# Short-circuit for the case of all unique, since `mask` is fairly expensive:
68+
dx = reshape(dy, axes_x)
69+
return (NoTangent(), ProjectTo(x)(dx))
70+
end
71+
72+
if dims isa Colon
73+
xs, ys = vec(x), y
74+
else
75+
xs, ys = collect(eachslice(x; dims=dims)), collect(eachslice(y; dims=dims))
76+
end
77+
mask = isequal.(permutedims(ys), xs) # unique([0.0, -0.0, NaN, NaN])
78+
mask .= (mask .== cumsum(mask, dims=1) .== true) # this implements findfirst(mask; dims=1)
79+
keep = map(I -> I[1], findall(mask))
80+
if dims isa Colon
81+
# The function `_zerolike_writeat` allows second derivatives.
82+
# Should perhaps eventually be shared with `getindex`.
83+
dx = reshape(_zerolike_writeat(vec(x), vec(dy), (), keep), axes_x)
84+
else
85+
inds = ntuple(d -> d==dims ? keep : (:), length(axes_x))
86+
dx = _zerolike_writeat(x, dy, (), inds...)
87+
end
88+
return (NoTangent(), ProjectTo(x)(dx))
89+
end
90+
return y, unique_pullback
91+
end

test/rulesets/Base/sort.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,33 @@
1212

1313
test_rrule(partialsort, a, 4, fkwargs=(;rev=true))
1414
end
15+
16+
@testset "sortslices" begin
17+
test_rrule(sortslices, rand(3,4); fkwargs=(; dims=2))
18+
test_rrule(sortslices, rand(5,4); fkwargs=(; dims=1, rev=true, by=last))
19+
test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum), check_inferred=false)
20+
21+
@test_throws Exception sortslices(Diagonal(1:3), dims=1)
22+
end
23+
24+
@testset "unique" begin
25+
# Trivial case, all unique:
26+
test_rrule(unique, rand(5))
27+
test_rrule(unique, rand(3,4))
28+
test_rrule(unique, rand(3,4); fkwargs=(; dims=2))
29+
30+
# Not all unique:
31+
@test rrule(unique, [1,1,2,3])[1] == [1,2,3]
32+
@test rrule(unique, [1,1,2,3])[2]([10,20,30]) == (NoTangent(), [10, 0, 20, 30])
33+
34+
@test rrule(unique, [1 2; 1 4])[1] == [1,2,4]
35+
@test rrule(unique, [1 2; 1 4])[2]([10,20,30]) == (NoTangent(), [10 20; 0 30])
36+
37+
@test rrule(unique, [1 2 1 2; 1 2 1 4], dims=2)[1] == [1 2 2; 1 2 4]
38+
@test rrule(unique, [1 2 1 2; 1 2 1 4], dims=2)[2]([10 20 30; 40 50 60])[2] == [10 20 0 30; 40 50 0 60]
39+
40+
@test rrule(unique, Diagonal([1,2,3]))[1] == [1,0,2,3]
41+
@test rrule(unique, Diagonal([1,2,3]))[2]([10 20 30 40])[2] == [10.0 0.0 0.0; 0.0 30.0 0.0; 0.0 0.0 40.0]
42+
@test rrule(unique, Diagonal([1,2,3]))[2]([10 20 30 40])[2] isa Diagonal
43+
end
1544
end

0 commit comments

Comments
 (0)