|
| 1 | +##### |
| 2 | +##### `sort` |
| 3 | +##### |
| 4 | + |
1 | 5 | function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,OrdinalRange}; kwargs...)
|
2 | 6 | inds = partialsortperm(xs, k; kwargs...)
|
3 | 7 | ys = xs[inds]
|
@@ -33,3 +37,55 @@ function rrule(::typeof(sort), xs::AbstractVector; kwargs...)
|
33 | 37 | end
|
34 | 38 | return ys, sort_pullback
|
35 | 39 | end
|
| 40 | + |
| 41 | +##### |
| 42 | +##### `sortslices` |
| 43 | +##### |
| 44 | + |
| 45 | +function rrule(::typeof(sortslices), x::AbstractArray; dims::Integer, kw...) |
| 46 | + p = sortperm(collect(eachslice(x; dims=dims)); kw...) |
| 47 | + inds = ntuple(d -> d == dims ? p : (:), ndims(x)) |
| 48 | + function sortslices_pullback(dy) |
| 49 | + # No actual need to zero this, and if you didn't, then you could widen eltype |
| 50 | + # Also, you could use similar(dy) here not x, same size? |
| 51 | + dx = _zerolike_writeat(x, unthunk(dy), (), inds...) |
| 52 | + return (NoTangent(), ProjectTo(x)(dx)) |
| 53 | + end |
| 54 | + return x[inds...], sortslices_pullback |
| 55 | +end |
| 56 | + |
| 57 | +##### |
| 58 | +##### `unique` |
| 59 | +##### |
| 60 | + |
| 61 | +function rrule(::typeof(unique), x::AbstractArray{<:Number}; dims=:) |
| 62 | + axes_x = axes(x) |
| 63 | + y = unique(x; dims=dims) # accepts only dims=: or dims::Integer |
| 64 | + function unique_pullback(dy_raw) |
| 65 | + dy = unthunk(dy_raw) |
| 66 | + if length(x) == length(y) |
| 67 | + # Short-circuit for the case of all unique, since `mask` is fairly expensive: |
| 68 | + dx = reshape(dy, axes_x) |
| 69 | + return (NoTangent(), ProjectTo(x)(dx)) |
| 70 | + end |
| 71 | + |
| 72 | + if dims isa Colon |
| 73 | + xs, ys = vec(x), y |
| 74 | + else |
| 75 | + xs, ys = collect(eachslice(x; dims=dims)), collect(eachslice(y; dims=dims)) |
| 76 | + end |
| 77 | + mask = isequal.(permutedims(ys), xs) # unique([0.0, -0.0, NaN, NaN]) |
| 78 | + mask .= (mask .== cumsum(mask, dims=1) .== true) # this implements findfirst(mask; dims=1) |
| 79 | + keep = map(I -> I[1], findall(mask)) |
| 80 | + if dims isa Colon |
| 81 | + # The function `_zerolike_writeat` allows second derivatives. |
| 82 | + # Should perhaps eventually be shared with `getindex`. |
| 83 | + dx = reshape(_zerolike_writeat(vec(x), vec(dy), (), keep), axes_x) |
| 84 | + else |
| 85 | + inds = ntuple(d -> d==dims ? keep : (:), length(axes_x)) |
| 86 | + dx = _zerolike_writeat(x, dy, (), inds...) |
| 87 | + end |
| 88 | + return (NoTangent(), ProjectTo(x)(dx)) |
| 89 | + end |
| 90 | + return y, unique_pullback |
| 91 | +end |
0 commit comments