Skip to content

Commit ba0695a

Browse files
authored
Merge pull request #527 from mcabbott/returns
Use `Returns`
2 parents 6bf0f30 + 2b186f7 commit ba0695a

File tree

4 files changed

+8
-8
lines changed

4 files changed

+8
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1212
[compat]
1313
ChainRulesCore = "1.1"
1414
ChainRulesTestUtils = "1"
15-
Compat = "3.33"
15+
Compat = "3.35"
1616
FiniteDifferences = "0.12.8"
1717
JuliaInterpreter = "0.8"
1818
StaticArrays = "1.2"

src/rulesets/Base/array.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Union{Colon,Int}...)
5959
A_dims = size(A)
6060
function reshape_pullback(Ȳ)
6161
∂A = reshape(Ȳ, A_dims)
62-
∂dims = broadcast(_ -> NoTangent(), dims)
62+
∂dims = broadcast(Returns(NoTangent()), dims)
6363
return (NoTangent(), ∂A, ∂dims...)
6464
end
6565
return reshape(A, dims...), reshape_pullback
@@ -68,7 +68,7 @@ end
6868
#####
6969
##### `repeat`
7070
#####
71-
function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(_->1, ndims(xs)), outer=ntuple(_->1, ndims(xs)))
71+
function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(Returns(1), ndims(xs)), outer=ntuple(Returns(1), ndims(xs)))
7272

7373
project_Xs = ProjectTo(xs)
7474
S = size(xs)
@@ -98,7 +98,7 @@ function rrule(::typeof(repeat), xs::AbstractArray, counts::Integer...)
9898
size2ndims = ntuple(d -> isodd(d) ? get(S, 1+d÷2, 1) : get(counts, d÷2, 1), 2*ndims(dY))
9999
reduced = sum(reshape(dY, size2ndims); dims = ntuple(d -> 2d, ndims(dY)))
100100
= project_Xs(reshape(reduced, S))
101-
return (NoTangent(), x̄, map(_->NoTangent(), counts)...)
101+
return (NoTangent(), x̄, map(Returns(NoTangent()), counts)...)
102102
end
103103
return repeat(xs, counts...), repeat_pullback
104104
end
@@ -303,7 +303,7 @@ function frule((_, xdot), ::typeof(reverse), x::AbstractArray, args...; kw...)
303303
end
304304

305305
function rrule(::typeof(reverse), x::AbstractArray, args...; kw...)
306-
nots = map(_ -> NoTangent(), args)
306+
nots = map(Returns(NoTangent()), args)
307307
function reverse_pullback(dy)
308308
dx = @thunk reverse(unthunk(dy), args...; kw...)
309309
return (NoTangent(), dx, nots...)
@@ -338,7 +338,7 @@ end
338338

339339
function rrule(::typeof(fill), x::Any, dims...)
340340
project = x isa Union{Number, AbstractArray{<:Number}} ? ProjectTo(x) : identity
341-
nots = map(_ -> NoTangent(), dims)
341+
nots = map(Returns(NoTangent()), dims)
342342
fill_pullback(Ȳ) = (NoTangent(), project(sum(Ȳ)), nots...)
343343
return fill(x, dims...), fill_pullback
344344
end

src/rulesets/Base/indexing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ function rrule(::typeof(getindex), x::Array{<:Number}, inds...)
2020
getindex_add!,
2121
@thunk(getindex_add!(zero(x))),
2222
)
23-
īnds = broadcast(_ -> NoTangent(), inds)
23+
īnds = broadcast(Returns(NoTangent()), inds)
2424
return (NoTangent(), x̄, īnds...)
2525
end
2626

src/rulesets/Random/random.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ frule(Δargs, T::Type{<:AbstractRNG}, args...) = T(args...), ZeroTangent()
22

33
function rrule(T::Type{<:AbstractRNG}, args...)
44
function AbstractRNG_pullback(ΔΩ)
5-
return (NoTangent(), map(_ -> ZeroTangent(), args)...)
5+
return (NoTangent(), map(Returns(ZeroTangent()), args)...)
66
end
77
return T(args...), AbstractRNG_pullback
88
end

0 commit comments

Comments
 (0)