Skip to content

Commit cfe3a81

Browse files
authored
Reduce complexity of ProjectTo constructor (#422)
By avoiding a lot of the kwarg machinery when it is not needed. Helps get some Diffractor performance back after the ProjectTo change.
1 parent f68224a commit cfe3a81

File tree

3 files changed

+25
-13
lines changed

3 files changed

+25
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.0.2"
3+
version = "1.1.0"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/projection.jl

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ The type `T` is meant to encode the largest acceptable space, so usually
77
this enforces `p(dx)::T`. But some subspaces which aren't subtypes of `T` may
88
be allowed, and in particular `dx::AbstractZero` always passes through.
99
10-
Usually `T` is the "outermost" part of the type, and `p` stores additional
10+
Usually `T` is the "outermost" part of the type, and `p` stores additional
1111
properties such as projectors for each constituent field.
12-
Arrays have either one projector `p.element` expressing the element type for
12+
Arrays have either one projector `p.element` expressing the element type for
1313
an array of numbers, or else an array of projectors `p.elements`.
1414
These properties can be supplied as keyword arguments on construction,
1515
`p = ProjectTo{T}(; field=data, element=Projector(x))`. For each `T` in use,
@@ -21,7 +21,19 @@ struct ProjectTo{P,D<:NamedTuple}
2121
info::D
2222
end
2323
ProjectTo{P}(info::D) where {P,D<:NamedTuple} = ProjectTo{P,D}(info)
24-
ProjectTo{P}(; kwargs...) where {P} = ProjectTo{P}(NamedTuple(kwargs))
24+
25+
# We'd like to write
26+
# ProjectTo{P}(; kwargs...) where {P} = ProjectTo{P}(NamedTuple(kwargs))
27+
#
28+
# but the kwarg dispatcher has non-trivial complexity. See rules.jl for an
29+
# explanation of this trick.
30+
const EMPTY_NT = NamedTuple()
31+
ProjectTo{P}() where {P} = ProjectTo{P}(EMPTY_NT)
32+
33+
const Type_kwfunc = Core.kwftype(Type).instance
34+
function (::typeof(Type_kwfunc))(kws::Any, ::Type{ProjectTo{P}}) where {P}
35+
ProjectTo{P}(NamedTuple(kws))
36+
end
2537

2638
Base.getproperty(p::ProjectTo, name::Symbol) = getproperty(backing(p), name)
2739
Base.propertynames(p::ProjectTo) = propertynames(backing(p))
@@ -41,13 +53,13 @@ function Base.show(io::IO, project::ProjectTo{T}) where {T}
4153
end
4254

4355
# Structs
44-
# Generic method is to recursively make `ProjectTo`s for all their fields. Not actually
56+
# Generic method is to recursively make `ProjectTo`s for all their fields. Not actually
4557
# used on unknown structs, but useful for handling many known ones in the same manner.
4658
function generic_projector(x::T; kw...) where {T}
4759
fields_nt::NamedTuple = backing(x)
4860
fields_proj = map(_maybe_projector, fields_nt)
4961
# We can't use `T` because if we have `Foo{Matrix{E}}` it should be allowed to make a
50-
# `Foo{Diagaonal{E}}` etc. We assume it has a default constructor that has all fields
62+
# `Foo{Diagaonal{E}}` etc. We assume it has a default constructor that has all fields
5163
# but if it doesn't `construct` will give a good error message.
5264
wrapT = T.name.wrapper
5365
# Official API for this? https://github.com/JuliaLang/julia/issues/35543
@@ -100,8 +112,8 @@ true
100112
101113
julia> unthunk(pd(th))
102114
3×3 Diagonal{Float64, Vector{Float64}}:
103-
1.0 ⋅ ⋅
104-
⋅ 5.0 ⋅
115+
1.0 ⋅ ⋅
116+
⋅ 5.0 ⋅
105117
⋅ ⋅ 9.0
106118
107119
julia> ProjectTo([1 2; 3 4]') # no special structure, integers are promoted to float(x)
@@ -156,7 +168,7 @@ end
156168
# We assume (lacking evidence to the contrary) that it is the right subspace of numebers
157169
# The (::ProjectTo{T})(::T) method doesn't work because we are allowing a different
158170
# Number type that might not be a subtype of the `project_type`.
159-
(::ProjectTo{<:Number})(dx::Number) = dx
171+
(::ProjectTo{<:Number})(dx::Number) = dx
160172

161173
(project::ProjectTo{<:Real})(dx::Complex) = project(real(dx))
162174
(project::ProjectTo{<:Complex})(dx::Real) = project(complex(dx))
@@ -407,7 +419,7 @@ function (project::ProjectTo{SparseVector})(dx::SparseVector)
407419
# When sparsity pattern is unchanged, all the time is in checking this,
408420
# perhaps some simple hash/checksum might be good enough?
409421
samepattern = project.nzind == dx.nzind
410-
# samepattern = length(project.nzind) == length(dx.nzind)
422+
# samepattern = length(project.nzind) == length(dx.nzind)
411423
if eltype(dx) <: project_type(project.element) && samepattern
412424
return dx
413425
elseif samepattern

src/rule_definition_tools.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials)
169169
return @strip_linenos quote
170170
# _ is the input derivative w.r.t. function internals. since we do not
171171
# allow closures/functors with @scalar_rule, it is always ignored
172-
function ChainRulesCore.frule((_, $(Δs...)), ::typeof($f), $(inputs...))
172+
function ChainRulesCore.frule((_, $(Δs...)), ::Core.Typeof($f), $(inputs...))
173173
$(__source__)
174174
$(esc()) = $call
175175
$(setup_stmts...)
@@ -206,7 +206,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials)
206206
end
207207

208208
return @strip_linenos quote
209-
function ChainRulesCore.rrule(::typeof($f), $(inputs...))
209+
function ChainRulesCore.rrule(::Core.Typeof($f), $(inputs...))
210210
$(__source__)
211211
$(esc()) = $call
212212
$(setup_stmts...)
@@ -233,7 +233,7 @@ end
233233
Returns the expression for the propagation of
234234
the input gradient `Δs` though the partials `∂s`.
235235
Specify `_conj = true` to conjugate the partials.
236-
Projector `proj` is a function that will be applied at the end;
236+
Projector `proj` is a function that will be applied at the end;
237237
for `rrules` it is usually a `ProjectTo(x)`, for `frules` it is `identity`
238238
"""
239239
function propagation_expr(Δs, ∂s, _conj=false, proj=identity)

0 commit comments

Comments
 (0)