Skip to content

Commit 0da97a0

Browse files
committed
add is_non_differentiable
1 parent e364b81 commit 0da97a0

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

src/ChainRulesCore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ export frule_via_ad, rrule_via_ad
1313
export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented
1414
export ProjectTo, canonicalize, unthunk # tangent operations
1515
export add!! # gradient accumulation operations
16-
export ignore_derivatives, @ignore_derivatives
16+
export ignore_derivatives, @ignore_derivatives, is_non_differentiable
1717
# tangents
1818
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
1919

src/projection.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,16 +141,42 @@ end
141141
# dx::AbstractArray (when both are possible), or the reverse. So for now we just pass them through:
142142
(::ProjectTo{T})(dx::Tangent{<:T}) where {T} = dx
143143

144+
#####
145+
##### A related utility which wants to live nearby
146+
#####
147+
148+
"""
149+
is_non_differentiable(x) == is_non_differentiable(typeof(x))
150+
151+
Returns `true` if `x` is known from its type not to have derivatives, else `false`.
152+
153+
Should mostly agree with whether `ProjectTo(x)` maps to `AbstractZero`,
154+
which is what the fallback method checks. The exception is that it will not look
155+
inside abstractly typed containers like `x = Any[true, false]`.
156+
"""
157+
is_non_differentiable(x) = is_non_differentiable(typeof(x))
158+
159+
is_non_differentiable(::Type{<:Number}) = false
160+
is_non_differentiable(::Type{<:NTuple{N,T}}) where {N,T} = is_non_differentiable(T)
161+
is_non_differentiable(::Type{<:AbstractArray{T}}) where {T} = is_non_differentiable(T)
162+
163+
function is_non_differentiable(::Type{T}) where {T} # fallback
164+
PT = Base._return_type(ProjectTo, Tuple{T}) # might be Union{} if unstable
165+
return isconcretetype(PT) && PT <: ProjectTo{<:AbstractZero}
166+
end
167+
144168
#####
145169
##### `Base`
146170
#####
147171

148172
# Bool
149173
ProjectTo(::Bool) = ProjectTo{NoTangent}() # same projector as ProjectTo(::AbstractZero) above
174+
is_non_differentiable(::Type{Bool}) = true
150175

151176
# Other never-differentiable types
152177
for T in (:Symbol, :Char, :AbstractString, :RoundingMode, :IndexStyle)
153178
@eval ProjectTo(::$T) = ProjectTo{NoTangent}()
179+
@eval is_non_differentiable(::Type{<:$T}) = true
154180
end
155181

156182
# Numbers

0 commit comments

Comments
 (0)