Description
At the moment we only define `ProjectTo` for differential types (with `Ref` being the first exception).
Consider a generic `rrule(::typeof(*), a::Number, b::Number)` which uses `ProjectTo` to ensure that the tangents are in the right subspace, i.e. something like
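```julia
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(*), a::Number, b::Number)
    function times_pullback(dy)
        da = dy * b
        db = a * dy
        # project each cotangent back onto the subspace of its primal input
        return NoTangent(), ProjectTo(a)(da), ProjectTo(b)(db)
    end
    return a * b, times_pullback
end
```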
This looks perfectly reasonable.
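To make "the right subspace" concrete, here is a minimal self-contained sketch that does not use ChainRulesCore. `ToyProject`, `toy_project` and `toy_times_rrule` are hypothetical names, and the toy only handles floating-point real and complex primals. The idea it illustrates is that a real input can only have a real tangent, so its projector drops the imaginary part of an incoming cotangent. Unlike the rule above, the toy pullback conjugates, which only matters for complex inputs.

```julia
# Hypothetical toy stand-in for ChainRulesCore.ProjectTo; NOT the library's implementation.
struct ToyProject{T} end

# Record the primal's type when the projector is built.
toy_project(::T) where {T<:AbstractFloat} = ToyProject{T}()
toy_project(::T) where {T<:Complex{<:AbstractFloat}} = ToyProject{T}()

# Real primal: a real input only has real tangents, so drop any imaginary part.
(::ToyProject{T})(dx::Number) where {T<:AbstractFloat} = convert(T, real(dx))
# Complex primal: keep both parts, just fix the element type.
(::ToyProject{T})(dx::Number) where {T<:Complex} = convert(T, dx)

# Toy multiplication rule in the shape of the rrule above (with conjugation added).
function toy_times_rrule(a::Number, b::Number)
    pa, pb = toy_project(a), toy_project(b)  # build projectors from the primals
    times_pullback(dy) = (pa(dy * conj(b)), pb(conj(a) * dy))
    return a * b, times_pullback
end

# A complex cotangent arriving at two real inputs is projected to real tangents.
y, back = toy_times_rrule(2.0, 3.0)
da, db = back(1.0 + 2.0im)
```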
However, if we create a type like
```julia
struct PositiveReal <: Number
    val::Float64
    PositiveReal(x) = x > 0 ? new(x) : error("must be larger than 0")
end
```
which is not its own differential type (the natural differential for it is a `Float64`), we are in trouble.
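A quick check of why `Float64`, not `PositiveReal`, is the natural differential: a tangent type has to hold zero and changes of either sign, and the constructor above rejects both zero and negative values. `can_represent` is a hypothetical helper written only for this sketch.

```julia
# Same definition as in the issue text.
struct PositiveReal <: Number
    val::Float64
    PositiveReal(x) = x > 0 ? new(x) : error("must be larger than 0")
end

# Hypothetical helper: can a value of type T be constructed from x?
function can_represent(T, x)
    try
        T(x)
        return true
    catch
        return false
    end
end

# A tangent type must hold zero (e.g. the cotangent of an input that does not
# affect the output) and changes in either direction.
for x in (0.0, -1.0, 1.0)
    println("x = ", x, ": PositiveReal ", can_represent(PositiveReal, x),
            ", Float64 ", can_represent(Float64, x))
end
```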
The problem is that we only promise that `ProjectTo` projects onto valid differential types, so we can't just define
```julia
function ChainRulesCore.ProjectTo(x::PositiveReal)
    return ProjectTo(x.val)
end
```
since `PositiveReal` is not a valid differential type (it has no zero). For a similar reason we do not define `ProjectTo(::Tuple)`, which would solve issues like #440.
The question is: should we loosen this requirement, so that `ProjectTo(x)` only has to project onto a valid differential type even when `x` itself is not one? Keeping the requirement restricts the use of `ProjectTo` to functions whose arguments are their own differentials. What bad things would happen if we dropped this `ProjectTo` requirement?
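To make the question concrete, here is a minimal self-contained sketch of what the loosened rule could look like. `project` is a hypothetical stand-in for `ProjectTo`, not ChainRulesCore's API: it returns a function that maps a raw cotangent onto the natural tangent type of the primal, which need not be the primal's own type. Tuples are projected element by element into a plain tuple, standing in for the structural `Tangent` that ChainRulesCore would use.

```julia
struct PositiveReal <: Number
    val::Float64
    PositiveReal(x) = x > 0 ? new(x) : error("must be larger than 0")
end

# Hypothetical projector: Float64 is its own tangent type; drop imaginary parts.
project(::Float64) = dx -> Float64(real(dx))

# Loosened rule: PositiveReal is not its own tangent type, so project onto
# the tangent type of its field (Float64) instead.
project(x::PositiveReal) = project(x.val)

# Tuples are not tangent types either: project each element separately and
# return a plain tuple as a stand-in for a structural tangent.
function project(x::Tuple)
    ps = map(project, x)
    return dx -> map((p, d) -> p(d), ps, dx)
end

p = project(PositiveReal(2.0))
pt = project((1.0, PositiveReal(3.0)))
```

In this sketch the projector for a `PositiveReal` returns a `Float64` and happily returns `0.0`, so code calling it can no longer assume the projected tangent has the same type as the primal.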