diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 000000000..323237bab --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1 @@ +style = "blue" diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml new file mode 100644 index 000000000..f6f268c0e --- /dev/null +++ b/.github/workflows/format.yml @@ -0,0 +1,27 @@ +name: Format suggestions + +on: + pull_request: + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + format: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@latest + with: + version: 1 + - run: | + julia -e 'using Pkg; Pkg.add("JuliaFormatter")' + julia -e 'using JuliaFormatter; format("."; verbose=true)' + - uses: reviewdog/action-suggester@v1 + with: + tool_name: JuliaFormatter + fail_on_error: true + filter_mode: added diff --git a/docs/make.jl b/docs/make.jl index 1ef3a62a7..42e39a4c0 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -16,20 +16,20 @@ DocMeta.setdocmeta!( @scalar_rule(sin(x), cos(x)) # frule and rrule doctest @scalar_rule(sincos(x), @setup((sinx, cosx) = Ω), cosx, -sinx) # frule doctest @scalar_rule(hypot(x::Real, y::Real), (x / Ω, y / Ω)) # rrule doctest - end + end, ) indigo = DocThemeIndigo.install(ChainRulesCore) -makedocs( +makedocs(; modules=[ChainRulesCore], - format=Documenter.HTML( + format=Documenter.HTML(; prettyurls=false, assets=[indigo], mathengine=MathJax3( Dict( :tex => Dict( - "inlineMath" => [["\$","\$"], ["\\(","\\)"]], + "inlineMath" => [["\$", "\$"], ["\\(", "\\)"]], "tags" => "ams", # TODO: remove when using physics package "macros" => Dict( @@ -67,7 +67,4 @@ makedocs( checkdocs=:exports, ) -deploydocs( - repo = "github.com/JuliaDiff/ChainRulesCore.jl.git", - push_preview=true, -) +deploydocs(; repo="github.com/JuliaDiff/ChainRulesCore.jl.git", push_preview=true) diff --git a/docs/src/assets/make_logo.jl b/docs/src/assets/make_logo.jl index 5bbfd36c1..c023c308f 100644 --- a/docs/src/assets/make_logo.jl +++ b/docs/src/assets/make_logo.jl @@ -9,79 +9,76 @@ using Random const bridge_len = 50 function chain(jiggle=0) - shaky_rotate(θ) = rotate(θ + jiggle*(rand()-0.5)) - + shaky_rotate(θ) = rotate(θ + jiggle * (rand() - 0.5)) + ### 1 shaky_rotate(0) sethue(Luxor.julia_red) link() m1 = getmatrix() - - + ### 2 sethue(Luxor.julia_green) - translate(-50, 130); - shaky_rotate(π/3); + translate(-50, 130) + shaky_rotate(π / 3) link() m2 = getmatrix() - + setmatrix(m1) sethue(Luxor.julia_red) overlap(-1.3π) setmatrix(m2) - + ### 3 - shaky_rotate(-π/3); - translate(-120,80); + shaky_rotate(-π / 3) + translate(-120, 80) sethue(Luxor.julia_purple) link() - + setmatrix(m2) setcolor(Luxor.julia_green) - overlap(-1.5π) + return overlap(-1.5π) end - function link() sector(50, 90, π, 0, :fill) sector(Point(0, bridge_len), 50, 90, 0, -π, :fill) - - - rect(50,-3,40, bridge_len+6, :fill) - rect(-50-40,-3,40, bridge_len+6, :fill) - + + rect(50, -3, 40, bridge_len + 6, :fill) + rect(-50 - 40, -3, 40, bridge_len + 6, :fill) + sethue("black") move(Point(-50, bridge_len)) - arc(Point(0,0), 50, π, 0, :stoke) + arc(Point(0, 0), 50, π, 0, :stoke) arc(Point(0, bridge_len), 50, 0, -π, :stroke) - + move(Point(-90, bridge_len)) - arc(Point(0,0), 90, π, 0, :stoke) + arc(Point(0, 0), 90, π, 0, :stoke) arc(Point(0, bridge_len), 90, 0, -π, :stroke) - strokepath() + return strokepath() end function overlap(ang_end) - sector(Point(0, bridge_len), 50, 90, -0., ang_end, :fill) + sector(Point(0, bridge_len), 50, 90, -0.0, ang_end, :fill) sethue("black") arc(Point(0, bridge_len), 50, 0, ang_end, :stoke) move(Point(90, bridge_len)) arc(Point(0, bridge_len), 90, 0, ang_end, :stoke) - strokepath() + return strokepath() end # Actually draw it function save_logo(filename) Random.seed!(16) - Drawing(450,450, filename) + Drawing(450, 450, filename) origin() - translate(50, -130); + translate(50, -130) chain(0.5) finish() - preview() + return preview() end save_logo("logo.svg") -save_logo("logo.png") \ No newline at end of file +save_logo("logo.png") diff --git a/src/accumulation.jl b/src/accumulation.jl index 4bcc5c33f..c9a38956a 100644 --- a/src/accumulation.jl +++ b/src/accumulation.jl @@ -26,7 +26,7 @@ end add!!(x::AbstractArray, y::Thunk) = add!!(x, unthunk(y)) -function add!!(x::AbstractArray{<:Any, N}, y::AbstractArray{<:Any, N}) where N +function add!!(x::AbstractArray{<:Any,N}, y::AbstractArray{<:Any,N}) where {N} return if is_inplaceable_destination(x) x .+= y else @@ -34,7 +34,6 @@ function add!!(x::AbstractArray{<:Any, N}, y::AbstractArray{<:Any, N}) where N end end - """ is_inplaceable_destination(x) -> Bool @@ -64,7 +63,6 @@ end is_inplaceable_destination(::LinearAlgebra.Hermitian) = false is_inplaceable_destination(::LinearAlgebra.Symmetric) = false - function debug_add!(accumuland, t::InplaceableThunk) returned_value = t.add!(accumuland) if returned_value !== accumuland @@ -88,7 +86,7 @@ function Base.showerror(io::IO, err::BadInplaceException) if err.accumuland == err.returned_value println( io, - "Which in this case happenned to be equal. But they are not the same object." + "Which in this case happenned to be equal. But they are not the same object.", ) end end diff --git a/src/compat.jl b/src/compat.jl index 8204b66d5..fa66b1d0f 100644 --- a/src/compat.jl +++ b/src/compat.jl @@ -5,7 +5,7 @@ end if VERSION < v"1.1" # Note: these are actually *better* than the ones in julia 1.1, 1.2, 1.3,and 1.4 # See: https://github.com/JuliaLang/julia/issues/34292 - function fieldtypes(::Type{T}) where T + function fieldtypes(::Type{T}) where {T} if @generated ntuple(i -> fieldtype(T, i), fieldcount(T)) else @@ -13,7 +13,7 @@ if VERSION < v"1.1" end end - function fieldnames(::Type{T}) where T + function fieldnames(::Type{T}) where {T} if @generated ntuple(i -> fieldname(T, i), fieldcount(T)) else diff --git a/src/config.jl b/src/config.jl index 347e05c51..04757e838 100644 --- a/src/config.jl +++ b/src/config.jl @@ -64,7 +64,6 @@ that do not support performing forwards mode AD should be `RuleConfig{>:NoForwar """ struct NoForwardsMode <: ForwardsModeCapability end - """ frule_via_ad(::RuleConfig{>:HasForwardsMode}, ȧrgs, f, args...; kwargs...) diff --git a/src/deprecated.jl b/src/deprecated.jl index e69de29bb..8b1378917 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -0,0 +1 @@ + diff --git a/src/ignore_derivatives.jl b/src/ignore_derivatives.jl index c66d89d7e..18865f2c9 100644 --- a/src/ignore_derivatives.jl +++ b/src/ignore_derivatives.jl @@ -45,7 +45,9 @@ ignore_derivatives(x) = x Tells the AD system to ignore the expression. Equivalent to `ignore_derivatives() do (...) end`. """ macro ignore_derivatives(ex) - return :(ChainRulesCore.ignore_derivatives() do - $(esc(ex)) - end) + return :( + ChainRulesCore.ignore_derivatives() do + $(esc(ex)) + end + ) end diff --git a/src/projection.jl b/src/projection.jl index 4b07b2762..a017493e4 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -32,7 +32,7 @@ ProjectTo{P}() where {P} = ProjectTo{P}(EMPTY_NT) const Type_kwfunc = Core.kwftype(Type).instance function (::typeof(Type_kwfunc))(kws::Any, ::Type{ProjectTo{P}}) where {P} - ProjectTo{P}(NamedTuple(kws)) + return ProjectTo{P}(NamedTuple(kws)) end Base.getproperty(p::ProjectTo, name::Symbol) = getproperty(backing(p), name) @@ -131,7 +131,10 @@ ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pas # Also, any explicit construction with fields, where all fields project to zero, itself # projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]). const _PZ = ProjectTo{<:AbstractZero} -ProjectTo{P}(::NamedTuple{T, <:Tuple{_PZ, Vararg{<:_PZ}}}) where {P,T} = ProjectTo{NoTangent}() +const _PZ_Tuple = Tuple{_PZ,Vararg{<:_PZ}} # 1 or more ProjectTo{<:AbstractZeros} +function ProjectTo{P}(::NamedTuple{T,<:_PZ_Tuple}) where {P,T} + return ProjectTo{NoTangent}() +end # Tangent # We haven't entirely figured out when to convert Tangents to "natural" representations such as @@ -164,12 +167,16 @@ for T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) end # In these cases we can just `convert` as we know we are dealing with plain and simple types -(::ProjectTo{T})(dx::AbstractFloat) where T<:AbstractFloat = convert(T, dx) -(::ProjectTo{T})(dx::Integer) where T<:AbstractFloat = convert(T, dx) #needed to avoid ambiguity +(::ProjectTo{T})(dx::AbstractFloat) where {T<:AbstractFloat} = convert(T, dx) +(::ProjectTo{T})(dx::Integer) where {T<:AbstractFloat} = convert(T, dx) #needed to avoid ambiguity # simple Complex{<:AbstractFloat}} cases -(::ProjectTo{T})(dx::Complex{<:AbstractFloat}) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) +function (::ProjectTo{T})(dx::Complex{<:AbstractFloat}) where {T<:Complex{<:AbstractFloat}} + return convert(T, dx) +end (::ProjectTo{T})(dx::AbstractFloat) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) -(::ProjectTo{T})(dx::Complex{<:Integer}) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) +function (::ProjectTo{T})(dx::Complex{<:Integer}) where {T<:Complex{<:AbstractFloat}} + return convert(T, dx) +end (::ProjectTo{T})(dx::Integer) where {T<:Complex{<:AbstractFloat}} = convert(T, dx) # Other numbers, including e.g. ForwardDiff.Dual and Symbolics.Sym, should pass through. @@ -244,9 +251,11 @@ end # although really Ref() is probably a better structure. function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore from numbers if !(project.axes isa Tuple{}) - throw(DimensionMismatch( - "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number", - )) + throw( + DimensionMismatch( + "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number" + ), + ) end return fill(project.element(dx)) end @@ -298,7 +307,9 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray) return adjoint(project.parent(dy)) end -ProjectTo(x::LinearAlgebra.TransposeAbsVec) = ProjectTo{Transpose}(; parent=ProjectTo(parent(x))) +function ProjectTo(x::LinearAlgebra.TransposeAbsVec) + return ProjectTo{Transpose}(; parent=ProjectTo(parent(x))) +end function (project::ProjectTo{Transpose})(dx::LinearAlgebra.AdjOrTransAbsVec) return transpose(project.parent(transpose(dx))) end @@ -316,10 +327,8 @@ ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag)) (project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag)) # Symmetric -for (SymHerm, chk, fun) in ( - (:Symmetric, :issymmetric, :transpose), - (:Hermitian, :ishermitian, :adjoint), - ) +for (SymHerm, chk, fun) in + ((:Symmetric, :issymmetric, :transpose), (:Hermitian, :ishermitian, :adjoint)) @eval begin function ProjectTo(x::$SymHerm) sub = ProjectTo(parent(x)) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 911a32ddd..56e02b02a 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -99,11 +99,13 @@ macro scalar_rule(call, maybe_setup, partials...) rrule_expr = scalar_rrule_expr(__source__, f, call, [], inputs, derivatives) # Final return: building the expression to insert in the place of this macro - code = quote + return quote if !($f isa Type) && fieldcount(typeof($f)) > 0 - throw(ArgumentError( - "@scalar_rule cannot be used on closures/functors (such as $($f))" - )) + throw( + ArgumentError( + "@scalar_rule cannot be used on closures/functors (such as $($f))" + ), + ) end $(derivative_expr) @@ -112,7 +114,6 @@ macro scalar_rule(call, maybe_setup, partials...) end end - """ _normalize_scalarrules_macro_input(call, maybe_setup, partials) @@ -175,7 +176,9 @@ function derivatives_given_output end function scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials) return @strip_linenos quote - function ChainRulesCore.derivatives_given_output($(esc(:Ω)), ::Core.Typeof($f), $(inputs...)) + function ChainRulesCore.derivatives_given_output( + $(esc(:Ω)), ::Core.Typeof($f), $(inputs...) + ) $(__source__) $(setup_stmts...) return $(Expr(:tuple, partials...)) @@ -210,7 +213,9 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) - $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) + $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output( + $(esc(:Ω)), $f, $(inputs...) + ) return $(esc(:Ω)), $pushforward_returns end end @@ -225,7 +230,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) Δs = _propagator_inputs(n_outputs) # Make a projector for each argument - projs, psetup = _make_projectors(call.args[2:end]) + projs, psetup = _make_projectors(call.args[2:end]) append!(setup_stmts, psetup) # 1 partial derivative per input @@ -248,7 +253,9 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) - $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) + $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output( + $(esc(:Ω)), $f, $(inputs...) + ) return $(esc(:Ω)), $pullback end end @@ -262,7 +269,7 @@ _propagator_inputs(n) = [esc(gensym(Symbol(:Δ, i))) for i in 1:n] "given the variable names, escaped but without types, makes setup expressions for projection operators" function _make_projectors(xs) projs = map(x -> Symbol(:proj_, x.args[1]), xs) - setups = map((x,p) -> :($p = ProjectTo($x)), xs, projs) + setups = map((x, p) -> :($p = ProjectTo($x)), xs, projs) return projs, setups end @@ -288,7 +295,8 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity) # Apply `muladd` iteratively. # Explicit multiplication is only performed for the first pair of partial and gradient. init_expr = :(*($(_∂s[1]), $(Δs[1]))) - summed_∂_mul_Δs = foldl(Iterators.drop(zip(_∂s, Δs), 1); init=init_expr) do ex, (∂s_i, Δs_i) + _∂s_Δs_tail = Iterators.drop(zip(_∂s, Δs), 1) + summed_∂_mul_Δs = foldl(_∂s_Δs_tail; init=init_expr) do ex, (∂s_i, Δs_i) :(muladd($∂s_i, $Δs_i, $ex)) end return :($proj($summed_∂_mul_Δs)) @@ -366,7 +374,7 @@ macro non_differentiable(sig_expr) primal_invoke = if !has_vararg :($(primal_name)($(unconstrained_args...))) else - normal_args = unconstrained_args[1:end-1] + normal_args = unconstrained_args[1:(end - 1)] var_arg = unconstrained_args[end] :($(primal_name)($(normal_args...), $(var_arg)...)) end @@ -381,7 +389,10 @@ end function _with_kwargs_expr(call_expr::Expr, kwargs) @assert isexpr(call_expr, :call) return Expr( - :call, call_expr.args[1], Expr(:parameters, :($(kwargs)...)), call_expr.args[2:end]... + :call, + call_expr.args[1], + Expr(:parameters, :($(kwargs)...)), + call_expr.args[2:end]..., ) end @@ -389,11 +400,17 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) @gensym kwargs return @strip_linenos quote # Manually defined kw version to save compiler work. See explanation in rules.jl - function (::Core.kwftype(typeof(ChainRulesCore.frule)))(@nospecialize($kwargs::Any), - frule::typeof(ChainRulesCore.frule), @nospecialize(::Any), $(map(esc, primal_sig_parts)...)) + function (::Core.kwftype(typeof(ChainRulesCore.frule)))( + @nospecialize($kwargs::Any), + frule::typeof(ChainRulesCore.frule), + @nospecialize(::Any), + $(map(esc, primal_sig_parts)...), + ) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) end - function ChainRulesCore.frule(@nospecialize(::Any), $(map(esc, primal_sig_parts)...)) + function ChainRulesCore.frule( + @nospecialize(::Any), $(map(esc, primal_sig_parts)...) + ) $(__source__) # Julia functions always only have 1 output, so return a single NoTangent() return ($(esc(primal_invoke)), NoTangent()) @@ -408,7 +425,8 @@ function tuple_expression(primal_sig_parts) Expr(:tuple, ntuple(_ -> NoTangent(), num_primal_inputs)...) else num_primal_inputs = length(primal_sig_parts) - 1 # - vararg - length_expr = :($num_primal_inputs + length($(esc(_unconstrain(primal_sig_parts[end]))))) + length_expr = + :($num_primal_inputs + length($(esc(_unconstrain(primal_sig_parts[end]))))) @strip_linenos :(ntuple(i -> NoTangent(), $length_expr)) end end @@ -426,7 +444,9 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) @gensym kwargs return @strip_linenos quote # Manually defined kw version to save compiler work. See explanation in rules.jl - function (::Core.kwftype(typeof(rrule)))($(esc(kwargs))::Any, ::typeof(rrule), $(esc_primal_sig_parts...)) + function (::Core.kwftype(typeof(rrule)))( + $(esc(kwargs))::Any, ::typeof(rrule), $(esc_primal_sig_parts...) + ) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), $pullback_expr) end function ChainRulesCore.rrule($(esc_primal_sig_parts...)) @@ -436,7 +456,6 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) end end - ############################################################################################ # @opt_out @@ -481,7 +500,7 @@ end "Rewrite method sig Expr for `rrule` to be for `no_rrule`, and `frule` to be `no_frule`." function _no_rule_target_rewrite!(expr::Expr) - length(expr.args)===0 && error("Malformed method expression. $expr") + length(expr.args) === 0 && error("Malformed method expression. $expr") if expr.head === :call || expr.head === :where expr.args[1] = _no_rule_target_rewrite!(expr.args[1]) elseif expr.head == :(.) && expr.args[1] == :ChainRulesCore @@ -502,8 +521,6 @@ function _no_rule_target_rewrite!(call_target::Symbol) end end - - ############################################################################################ # Helpers @@ -555,13 +572,13 @@ and one to use for calling that function """ function _split_primal_name(primal_name) # e.g. f(x, y) - if primal_name isa Symbol || Meta.isexpr(primal_name, :(.)) || - Meta.isexpr(primal_name, :curly) - + is_plain = primal_name isa Symbol + is_qualified = Meta.isexpr(primal_name, :(.)) + is_parameterized = Meta.isexpr(primal_name, :curly) + if is_plain || is_qualified || is_parameterized primal_name_sig = :(::$Core.Typeof($primal_name)) return primal_name_sig, primal_name - # e.g. (::T)(x, y) - elseif Meta.isexpr(primal_name, :(::)) + elseif Meta.isexpr(primal_name, :(::)) # e.g. (::T)(x, y) _primal_name = gensym(Symbol(:instance_, primal_name.args[end])) primal_name_sig = Expr(:(::), _primal_name, primal_name.args[end]) return primal_name_sig, _primal_name @@ -575,14 +592,15 @@ _unconstrain(arg::Symbol) = arg function _unconstrain(arg::Expr) Meta.isexpr(arg, :(::), 2) && return arg.args[1] # drop constraint. Meta.isexpr(arg, :(...), 1) && return _unconstrain(arg.args[1]) - error("malformed arguments: $arg") + return error("malformed arguments: $arg") end "turn both `a` and `::constraint` into `a::constraint` etc" function _constrain_and_name(arg::Expr, _) Meta.isexpr(arg, :(::), 2) && return arg # it is already fine. Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) # add name - Meta.isexpr(arg, :(...), 1) && return Expr(:(...), _constrain_and_name(arg.args[1], :Any)) - error("malformed arguments: $arg") + Meta.isexpr(arg, :(...), 1) && + return Expr(:(...), _constrain_and_name(arg.args[1], :Any)) + return error("malformed arguments: $arg") end _constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type diff --git a/src/tangent_arithmetic.jl b/src/tangent_arithmetic.jl index 9c1378aab..c2bad7a77 100644 --- a/src/tangent_arithmetic.jl +++ b/src/tangent_arithmetic.jl @@ -81,7 +81,7 @@ LinearAlgebra.dot(::ZeroTangent, ::NoTangent) = ZeroTangent() Base.muladd(::ZeroTangent, x, y) = y Base.muladd(x, ::ZeroTangent, y) = y -Base.muladd(x, y, ::ZeroTangent) = x*y +Base.muladd(x, y, ::ZeroTangent) = x * y Base.muladd(::ZeroTangent, ::ZeroTangent, y) = y Base.muladd(x, ::ZeroTangent, ::ZeroTangent) = ZeroTangent() @@ -125,11 +125,11 @@ for T in (:Tangent, :Any) @eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b) end -function Base.:+(a::Tangent{P}, b::Tangent{P}) where P +function Base.:+(a::Tangent{P}, b::Tangent{P}) where {P} data = elementwise_add(backing(a), backing(b)) - return Tangent{P, typeof(data)}(data) + return Tangent{P,typeof(data)}(data) end -function Base.:+(a::P, d::Tangent{P}) where P +function Base.:+(a::P, d::Tangent{P}) where {P} net_backing = elementwise_add(backing(a), backing(d)) if debug_mode() try @@ -142,12 +142,12 @@ function Base.:+(a::P, d::Tangent{P}) where P end end Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d)) -Base.:+(a::Tangent{P}, b::P) where P = b + a +Base.:+(a::Tangent{P}, b::P) where {P} = b + a # We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful # In general one doesn't have to represent multiplications of 2 differentials # Only of a differential and a scaling factor (generally `Real`) for T in (:Any,) - @eval Base.:*(s::$T, tangent::Tangent) = map(x->s*x, tangent) - @eval Base.:*(tangent::Tangent, s::$T) = map(x->x*s, tangent) + @eval Base.:*(s::$T, tangent::Tangent) = map(x -> s * x, tangent) + @eval Base.:*(tangent::Tangent, s::$T) = map(x -> x * s, tangent) end diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 216357e91..5993d32b4 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -17,15 +17,15 @@ Base.iterate(x::AbstractZero) = (x, nothing) Base.iterate(::AbstractZero, ::Any) = nothing Base.Broadcast.broadcastable(x::AbstractZero) = Ref(x) -Base.Broadcast.broadcasted(::Type{T}) where T<:AbstractZero = T() +Base.Broadcast.broadcasted(::Type{T}) where {T<:AbstractZero} = T() # Linear operators Base.adjoint(z::AbstractZero) = z Base.transpose(z::AbstractZero) = z Base.:/(z::AbstractZero, ::Any) = z -Base.convert(::Type{T}, x::AbstractZero) where T <: Number = zero(T) -(::Type{T})(xs::AbstractZero...) where T <: Number = zero(T) +Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T) +(::Type{T})(xs::AbstractZero...) where {T<:Number} = zero(T) (::Type{Complex})(x::AbstractZero, y::Real) = Complex(false, y) (::Type{Complex})(x::Real, y::AbstractZero) = Complex(x, false) diff --git a/src/tangent_types/notimplemented.jl b/src/tangent_types/notimplemented.jl index a2044fbe1..a6b9cc5f9 100644 --- a/src/tangent_types/notimplemented.jl +++ b/src/tangent_types/notimplemented.jl @@ -44,9 +44,15 @@ Base.:/(::Any, x::NotImplemented) = throw(NotImplementedException(x)) Base.:/(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x)) Base.zero(x::NotImplemented) = throw(NotImplementedException(x)) -Base.zero(::Type{<:NotImplemented}) = throw(NotImplementedException(@not_implemented( - "`zero` is not defined for missing differentials of type `NotImplemented`" -))) +function Base.zero(::Type{<:NotImplemented}) + return throw( + NotImplementedException( + @not_implemented( + "`zero` is not defined for missing differentials of type `NotImplemented`" + ) + ), + ) +end Base.iterate(x::NotImplemented) = throw(NotImplementedException(x)) Base.iterate(x::NotImplemented, ::Any) = throw(NotImplementedException(x)) @@ -75,5 +81,5 @@ function Base.showerror(io::IO, e::NotImplementedException) if e.info !== nothing print(io, "\nInfo: ", e.info) end - return + return nothing end diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/tangent.jl index e4bbfb8c8..bb91e431e 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/tangent.jl @@ -21,42 +21,42 @@ Any fields not explictly present in the `Tangent` are treated as being set to `Z To make a `Tangent` have all the fields of the primal the [`canonicalize`](@ref) function is provided. """ -struct Tangent{P, T} <: AbstractTangent +struct Tangent{P,T} <: AbstractTangent # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict # (but potentially a different one, as it doesn't contain differentials) backing::T end -function Tangent{P}(; kwargs...) where P +function Tangent{P}(; kwargs...) where {P} backing = (; kwargs...) # construct as NamedTuple - return Tangent{P, typeof(backing)}(backing) + return Tangent{P,typeof(backing)}(backing) end -function Tangent{P}(args...) where P - return Tangent{P, typeof(args)}(args) +function Tangent{P}(args...) where {P} + return Tangent{P,typeof(args)}(args) end -function Tangent{P}() where P<:Tuple +function Tangent{P}() where {P<:Tuple} backing = () - return Tangent{P, typeof(backing)}(backing) + return Tangent{P,typeof(backing)}(backing) end function Tangent{P}(d::Dict) where {P<:Dict} - return Tangent{P, typeof(d)}(d) + return Tangent{P,typeof(d)}(d) end -function Base.:(==)(a::Tangent{P, T}, b::Tangent{P, T}) where {P, T} +function Base.:(==)(a::Tangent{P,T}, b::Tangent{P,T}) where {P,T} return backing(a) == backing(b) end -function Base.:(==)(a::Tangent{P}, b::Tangent{P}) where {P, T} +function Base.:(==)(a::Tangent{P}, b::Tangent{P}) where {P,T} all_fields = union(keys(backing(a)), keys(backing(b))) return all(getproperty(a, f) == getproperty(b, f) for f in all_fields) end -Base.:(==)(a::Tangent{P}, b::Tangent{Q}) where {P, Q} = false +Base.:(==)(a::Tangent{P}, b::Tangent{Q}) where {P,Q} = false Base.hash(a::Tangent, h::UInt) = Base.hash(backing(canonicalize(a)), h) -function Base.show(io::IO, tangent::Tangent{P}) where P +function Base.show(io::IO, tangent::Tangent{P}) where {P} print(io, "Tangent{") show(io, P) print(io, "}") @@ -68,15 +68,15 @@ function Base.show(io::IO, tangent::Tangent{P}) where P end end -function Base.getindex(tangent::Tangent{P, T}, idx::Int) where {P, T<:Union{Tuple, NamedTuple}} +function Base.getindex(tangent::Tangent{P,T}, idx::Int) where {P,T<:Union{Tuple,NamedTuple}} back = backing(canonicalize(tangent)) return unthunk(getfield(back, idx)) end -function Base.getindex(tangent::Tangent{P, T}, idx::Symbol) where {P, T<:NamedTuple} +function Base.getindex(tangent::Tangent{P,T}, idx::Symbol) where {P,T<:NamedTuple} hasfield(T, idx) || return ZeroTangent() return unthunk(getfield(backing(tangent), idx)) end -function Base.getindex(tangent::Tangent, idx) where {P, T<:AbstractDict} +function Base.getindex(tangent::Tangent, idx) where {P,T<:AbstractDict} return unthunk(getindex(backing(tangent), idx)) end @@ -84,7 +84,7 @@ function Base.getproperty(tangent::Tangent, idx::Int) back = backing(canonicalize(tangent)) return unthunk(getfield(back, idx)) end -function Base.getproperty(tangent::Tangent{P, T}, idx::Symbol) where {P, T<:NamedTuple} +function Base.getproperty(tangent::Tangent{P,T}, idx::Symbol) where {P,T<:NamedTuple} hasfield(T, idx) || return ZeroTangent() return unthunk(getfield(backing(tangent), idx)) end @@ -99,26 +99,26 @@ end Base.iterate(tangent::Tangent, args...) = iterate(backing(tangent), args...) Base.length(tangent::Tangent) = length(backing(tangent)) -Base.eltype(::Type{<:Tangent{<:Any, T}}) where T = eltype(T) +Base.eltype(::Type{<:Tangent{<:Any,T}}) where {T} = eltype(T) function Base.reverse(tangent::Tangent) rev_backing = reverse(backing(tangent)) - Tangent{typeof(rev_backing), typeof(rev_backing)}(rev_backing) + return Tangent{typeof(rev_backing),typeof(rev_backing)}(rev_backing) end function Base.indexed_iterate(tangent::Tangent{P,<:Tuple}, i::Int, state=1) where {P} return Base.indexed_iterate(backing(tangent), i, state) end -function Base.map(f, tangent::Tangent{P, <:Tuple}) where P +function Base.map(f, tangent::Tangent{P,<:Tuple}) where {P} vals::Tuple = map(f, backing(tangent)) - return Tangent{P, typeof(vals)}(vals) + return Tangent{P,typeof(vals)}(vals) end -function Base.map(f, tangent::Tangent{P, <:NamedTuple{L}}) where{P, L} +function Base.map(f, tangent::Tangent{P,<:NamedTuple{L}}) where {P,L} vals = map(f, Tuple(backing(tangent))) - named_vals = NamedTuple{L, typeof(vals)}(vals) - return Tangent{P, typeof(named_vals)}(named_vals) + named_vals = NamedTuple{L,typeof(vals)}(vals) + return Tangent{P,typeof(named_vals)}(named_vals) end -function Base.map(f, tangent::Tangent{P, <:Dict}) where {P<:Dict} +function Base.map(f, tangent::Tangent{P,<:Dict}) where {P<:Dict} return Tangent{P}(Dict(k => f(v) for (k, v) in backing(tangent))) end @@ -140,26 +140,28 @@ backing(x::Dict) = x backing(x::Tangent) = getfield(x, :backing) # For generic structs -function backing(x::T)::NamedTuple where T +function backing(x::T)::NamedTuple where {T} # note: all computation outside the if @generated happens at runtime. # so the first 4 lines of the branchs look the same, but can not be moved out. # see https://github.com/JuliaLang/julia/issues/34283 if @generated - !isstructtype(T) && throw(DomainError(T, "backing can only be used on struct types")) + !isstructtype(T) && + throw(DomainError(T, "backing can only be used on struct types")) nfields = fieldcount(T) names = fieldnames(T) types = fieldtypes(T) - vals = Expr(:tuple, ntuple(ii->:(getfield(x, $ii)), nfields)...) - return :(NamedTuple{$names, Tuple{$(types...)}}($vals)) + vals = Expr(:tuple, ntuple(ii -> :(getfield(x, $ii)), nfields)...) + return :(NamedTuple{$names,Tuple{$(types...)}}($vals)) else - !isstructtype(T) && throw(DomainError(T, "backing can only be used on struct types")) + !isstructtype(T) && + throw(DomainError(T, "backing can only be used on struct types")) nfields = fieldcount(T) names = fieldnames(T) types = fieldtypes(T) - vals = ntuple(ii->getfield(x, ii), nfields) - return NamedTuple{names, Tuple{types...}}(vals) + vals = ntuple(ii -> getfield(x, ii), nfields) + return NamedTuple{names,Tuple{types...}}(vals) end end @@ -170,36 +172,38 @@ Return the canonical `Tangent` for the primal type `P`. The property names of the returned `Tangent` match the field names of the primal, and all fields of `P` not present in the input `tangent` are explictly set to `ZeroTangent()`. """ -function canonicalize(tangent::Tangent{P, <:NamedTuple{L}}) where {P,L} +function canonicalize(tangent::Tangent{P,<:NamedTuple{L}}) where {P,L} nil = _zeroed_backing(P) combined = merge(nil, backing(tangent)) if length(combined) !== fieldcount(P) - throw(ArgumentError( - "Tangent fields do not match primal fields.\n" * - "Tangent fields: $L. Primal ($P) fields: $(fieldnames(P))" - )) + throw( + ArgumentError( + "Tangent fields do not match primal fields.\n" * + "Tangent fields: $L. Primal ($P) fields: $(fieldnames(P))", + ), + ) end - return Tangent{P, typeof(combined)}(combined) + return Tangent{P,typeof(combined)}(combined) end # Tuple tangents are always in their canonical form -canonicalize(tangent::Tangent{<:Tuple, <:Tuple}) = tangent +canonicalize(tangent::Tangent{<:Tuple,<:Tuple}) = tangent # Dict tangents are always in their canonical form. -canonicalize(tangent::Tangent{<:Any, <:AbstractDict}) = tangent +canonicalize(tangent::Tangent{<:Any,<:AbstractDict}) = tangent # Tangents of unspecified primal types (indicated by specifying exactly `Any`) # all combinations of type-params are specified here to avoid ambiguities -canonicalize(tangent::Tangent{Any, <:NamedTuple{L}}) where {L} = tangent -canonicalize(tangent::Tangent{Any, <:Tuple}) where {L} = tangent -canonicalize(tangent::Tangent{Any, <:AbstractDict}) where {L} = tangent +canonicalize(tangent::Tangent{Any,<:NamedTuple{L}}) where {L} = tangent +canonicalize(tangent::Tangent{Any,<:Tuple}) where {L} = tangent +canonicalize(tangent::Tangent{Any,<:AbstractDict}) where {L} = tangent """ _zeroed_backing(P) Returns a NamedTuple with same fields as `P`, and all values `ZeroTangent()`. """ -@generated function _zeroed_backing(::Type{P}) where P +@generated function _zeroed_backing(::Type{P}) where {P} nil_base = ntuple(fieldcount(P)) do i (fieldname(P, i), ZeroTangent()) end @@ -218,7 +222,7 @@ after an operation such as the addition of a primal to a tangent It should be overloaded, if `T` does not have a default constructor, or if `T` needs to maintain some invarients between its fields. """ -function construct(::Type{T}, fields::NamedTuple{L}) where {T, L} +function construct(::Type{T}, fields::NamedTuple{L}) where {T,L} # Tested and verified that that this avoids a ton of allocations if length(L) !== fieldcount(T) # if length is equal but names differ then we will catch that below anyway. @@ -233,12 +237,12 @@ function construct(::Type{T}, fields::NamedTuple{L}) where {T, L} end end -construct(::Type{T}, fields::T) where T<:NamedTuple = fields -construct(::Type{T}, fields::T) where T<:Tuple = fields +construct(::Type{T}, fields::T) where {T<:NamedTuple} = fields +construct(::Type{T}, fields::T) where {T<:Tuple} = fields elementwise_add(a::Tuple, b::Tuple) = map(+, a, b) -function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an, bn} +function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an,bn} # Rule of Tangent addition: any fields not present are implict hard Zeros # Base on the `merge(:;NamedTuple, ::NamedTuple)` code from Base. @@ -281,7 +285,7 @@ function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an, bn} end field => value end - return (;vals...) + return (; vals...) end end @@ -297,15 +301,16 @@ function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} println(io, "Could not construct $P after addition.") println(io, "This probably means no default constructor is defined.") println(io, "Either define a default constructor") - printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")", color=:blue) + printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")"; color=:blue) println(io, "\nor overload") - printstyled(io, + printstyled( + io, "ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.differential)))"; - color=:blue + color=:blue, ) println(io, "\nor overload") printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color=:blue) println(io, "\nOriginal Exception:") printstyled(io, err.original; color=:yellow) - println(io) + return println(io) end diff --git a/src/tangent_types/thunks.jl b/src/tangent_types/thunks.jl index 16384d69e..e065bea62 100644 --- a/src/tangent_types/thunks.jl +++ b/src/tangent_types/thunks.jl @@ -67,7 +67,7 @@ end function LinearAlgebra.diagm( m, n, kv::Pair{<:Integer,<:AbstractThunk}, kvs::Pair{<:Integer,<:AbstractThunk}... ) - return diagm(m, n, (k => unthunk(v) for (k, v) in (kv, kvs...))...) + return diagm(m, n, (k => unthunk(v) for (k, v) in (kv, kvs...))...) end LinearAlgebra.tril(a::AbstractThunk) = tril(unthunk(a)) @@ -197,7 +197,6 @@ Base.show(io::IO, x::Thunk) = print(io, "Thunk($(repr(x.f)))") Base.convert(::Type{<:Thunk}, a::AbstractZero) = @thunk(a) - """ InplaceableThunk(add!::Function, val::Thunk) diff --git a/test/accumulation.jl b/test/accumulation.jl index 1b41fea55..a796b5289 100644 --- a/test/accumulation.jl +++ b/test/accumulation.jl @@ -27,7 +27,7 @@ end @testset "misc AbstractTangent subtypes" begin - @test 16 == add!!(12, @thunk(2*2)) + @test 16 == add!!(12, @thunk(2 * 2)) @test 16 == add!!(16, ZeroTangent()) @test 16 == add!!(16, NoTangent()) # Should this be an error? @@ -37,15 +37,15 @@ @testset "LHS Array (inplace)" begin @testset "RHS Array" begin A = [1.0 2.0; 3.0 4.0] - accumuland = -1.0*ones(2,2) + accumuland = -1.0 * ones(2, 2) ret = add!!(accumuland, A) @test ret === accumuland # must be same object @test accumuland == [0.0 1.0; 2.0 3.0] end @testset "RHS StaticArray" begin - A = @SMatrix[1.0 2.0; 3.0 4.0] - accumuland = -1.0*ones(2,2) + A = @SMatrix [1.0 2.0; 3.0 4.0] + accumuland = -1.0 * ones(2, 2) ret = add!!(accumuland, A) @test ret === accumuland # must be same object @test accumuland == [0.0 1.0; 2.0 3.0] @@ -53,7 +53,7 @@ @testset "RHS Diagonal" begin A = Diagonal([1.0, 2.0]) - accumuland = -1.0*ones(2,2) + accumuland = -1.0 * ones(2, 2) ret = add!!(accumuland, A) @test ret === accumuland # must be same object @test accumuland == [0.0 -1.0; -1.0 1.0] @@ -79,17 +79,17 @@ @testset "Unhappy Path" begin # wrong length - @test_throws DimensionMismatch add!!(ones(4,4), ones(2,2)) + @test_throws DimensionMismatch add!!(ones(4, 4), ones(2, 2)) # wrong shape - @test_throws DimensionMismatch add!!(ones(4,4), ones(16)) + @test_throws DimensionMismatch add!!(ones(4, 4), ones(16)) # wrong type (adding scalar to array) @test_throws MethodError add!!(ones(4), 21.0) end end @testset "AbstractThunk $(typeof(thunk))" for thunk in ( - @thunk(-1.0*ones(2, 2)), - InplaceableThunk(x -> x .-= ones(2, 2), @thunk(-1.0*ones(2, 2))), + @thunk(-1.0 * ones(2, 2)), + InplaceableThunk(x -> x .-= ones(2, 2), @thunk(-1.0 * ones(2, 2))), ) @testset "in place" begin accumuland = [1.0 2.0; 3.0 4.0] @@ -111,12 +111,12 @@ @testset "not actually inplace but said it was" begin # thunk should never be used in this test ithunk = InplaceableThunk(@thunk(@assert false)) do x - 77*ones(2, 2) # not actually inplace (also wrong) + 77 * ones(2, 2) # not actually inplace (also wrong) end accumuland = ones(2, 2) @assert ChainRulesCore.debug_mode() == false # without debug being enabled should return the result, not error - @test 77*ones(2, 2) == add!!(accumuland, ithunk) + @test 77 * ones(2, 2) == add!!(accumuland, ithunk) ChainRulesCore.debug_mode() = true # enable debug mode # with debug being enabled should error @@ -127,7 +127,7 @@ @testset "showerror BadInplaceException" begin BadInplaceException = ChainRulesCore.BadInplaceException - ithunk = InplaceableThunk(x̄->nothing, @thunk(@assert false)) + ithunk = InplaceableThunk(x̄ -> nothing, @thunk(@assert false)) msg = sprint(showerror, BadInplaceException(ithunk, [22], [23])) @test occursin("22", msg) diff --git a/test/config.jl b/test/config.jl index 466baed9a..58d943252 100644 --- a/test/config.jl +++ b/test/config.jl @@ -1,7 +1,7 @@ # Define a bunch of configs for testing purposes struct MostBoringConfig <: RuleConfig{Union{}} end -struct MockForwardsConfig <: RuleConfig{Union{HasForwardsMode, NoReverseMode}} +struct MockForwardsConfig <: RuleConfig{Union{HasForwardsMode,NoReverseMode}} forward_calls::Vector end MockForwardsConfig() = MockForwardsConfig([]) @@ -11,7 +11,7 @@ function ChainRulesCore.frule_via_ad(config::MockForwardsConfig, ȧrgs, f, args. return f(args...; kws...), ȧrgs end -struct MockReverseConfig <: RuleConfig{Union{NoForwardsMode, HasReverseMode}} +struct MockReverseConfig <: RuleConfig{Union{NoForwardsMode,HasReverseMode}} reverse_calls::Vector end MockReverseConfig() = MockReverseConfig([]) @@ -22,8 +22,7 @@ function ChainRulesCore.rrule_via_ad(config::MockReverseConfig, f, args...; kws. return f(args...; kws...), pullback_via_ad end - -struct MockBothConfig <: RuleConfig{Union{HasForwardsMode, HasReverseMode}} +struct MockBothConfig <: RuleConfig{Union{HasForwardsMode,HasReverseMode}} forward_calls::Vector reverse_calls::Vector end @@ -47,18 +46,18 @@ end @testset "config.jl" begin @testset "basic fall to two arg verion for $Config" for Config in ( - MostBoringConfig, MockForwardsConfig, MockReverseConfig, MockBothConfig, + MostBoringConfig, MockForwardsConfig, MockReverseConfig, MockBothConfig ) counting_id_count = Ref(0) function counting_id(x) - counting_id_count[]+=1 + counting_id_count[] += 1 return x end function ChainRulesCore.rrule(::typeof(counting_id), x) counting_id_pullback(x̄) = x̄ return counting_id(x), counting_id_pullback end - function ChainRulesCore.frule((dself, dx),::typeof(counting_id), x) + function ChainRulesCore.frule((dself, dx), ::typeof(counting_id), x) return counting_id(x), dx end @testset "rrule" begin @@ -88,7 +87,7 @@ end end @testset "$Config" for Config in (MockBothConfig, MockForwardsConfig) - bconfig= Config() + bconfig = Config() @test nothing !== frule( bconfig, (NoTangent(), NoTangent(), 21.5), do_thing_2, identity, 32.1 ) @@ -104,13 +103,12 @@ end return (NoTangent(), rrule_via_ad(config, f, x)...) end - @testset "$Config" for Config in (MostBoringConfig, MockForwardsConfig) @test nothing === rrule(Config(), do_thing_3, identity, 32.1) end @testset "$Config" for Config in (MockBothConfig, MockReverseConfig) - bconfig= Config() + bconfig = Config() @test nothing !== rrule(bconfig, do_thing_3, identity, 32.1) @test bconfig.reverse_calls == [(identity, (32.1,))] end @@ -130,14 +128,14 @@ end ẋ = one(x) y, ẏ = frule_via_ad(config, (NoTangent(), ẋ), f, x) - pullback_via_forwards_ad(ȳ) = NoTangent(), NoTangent(), ẏ*ȳ + pullback_via_forwards_ad(ȳ) = NoTangent(), NoTangent(), ẏ * ȳ return y, pullback_via_forwards_ad end function ChainRulesCore.rrule( - config::RuleConfig{>:Union{HasReverseMode, NoForwardsMode}}, + config::RuleConfig{>:Union{HasReverseMode,NoForwardsMode}}, ::typeof(do_thing_4), f, - x + x, ) y, f_pullback = rrule_via_ad(config, f, x) do_thing_4_pullback(ȳ) = (NoTangent(), f_pullback(ȳ)...) @@ -147,18 +145,18 @@ end @test nothing === rrule(MostBoringConfig(), do_thing_4, identity, 32.1) @testset "$Config" for Config in (MockBothConfig, MockForwardsConfig) - bconfig= Config() + bconfig = Config() @test nothing !== rrule(bconfig, do_thing_4, identity, 32.1) @test bconfig.forward_calls == [(identity, (32.1,))] end - rconfig= MockReverseConfig() + rconfig = MockReverseConfig() @test nothing !== rrule(rconfig, do_thing_4, identity, 32.1) @test rconfig.reverse_calls == [(identity, (32.1,))] end @testset "RuleConfig broadcasts like a scaler" begin - @test (MostBoringConfig() .=> (1,2,3)) isa NTuple{3, Pair{MostBoringConfig,Int}} + @test (MostBoringConfig() .=> (1, 2, 3)) isa NTuple{3,Pair{MostBoringConfig,Int}} end @testset "fallbacks" begin @@ -174,16 +172,16 @@ end # Test that incorrect use of the fallback rules correctly throws MethodError @test_throws MethodError frule() - @test_throws MethodError frule(;kw="hello") + @test_throws MethodError frule(; kw="hello") @test_throws MethodError frule(sin) - @test_throws MethodError frule(sin;kw="hello") + @test_throws MethodError frule(sin; kw="hello") @test_throws MethodError frule(MostBoringConfig()) @test_throws MethodError frule(MostBoringConfig(); kw="hello") @test_throws MethodError frule(MostBoringConfig(), sin) @test_throws MethodError frule(MostBoringConfig(), sin; kw="hello") @test_throws MethodError rrule() - @test_throws MethodError rrule(;kw="hello") + @test_throws MethodError rrule(; kw="hello") @test_throws MethodError rrule(MostBoringConfig()) - @test_throws MethodError rrule(MostBoringConfig();kw="hello") + @test_throws MethodError rrule(MostBoringConfig(); kw="hello") end end diff --git a/test/deprecated.jl b/test/deprecated.jl index e69de29bb..8b1378917 100644 --- a/test/deprecated.jl +++ b/test/deprecated.jl @@ -0,0 +1 @@ + diff --git a/test/ignore_derivatives.jl b/test/ignore_derivatives.jl index 825287b9a..ad4fece9f 100644 --- a/test/ignore_derivatives.jl +++ b/test/ignore_derivatives.jl @@ -7,7 +7,7 @@ end @testset "function" begin f() = return 4.0 - y, ẏ = frule((1.0, ), ignore_derivatives, f) + y, ẏ = frule((1.0,), ignore_derivatives, f) @test y == f() @test ẏ == NoTangent() @@ -19,7 +19,7 @@ end @testset "argument" begin arg = 2.1 - y, ẏ = frule((1.0, ), ignore_derivatives, arg) + y, ẏ = frule((1.0,), ignore_derivatives, arg) @test y == arg @test ẏ == NoTangent() @@ -41,11 +41,11 @@ end @test pb(1.0) == (NoTangent(), NoTangent()) # when called - y, ẏ = frule((1.0,), ignore_derivatives, ()->mf(3.0)) + y, ẏ = frule((1.0,), ignore_derivatives, () -> mf(3.0)) @test y == mf(3.0) @test ẏ == NoTangent() - y, pb = rrule(ignore_derivatives, ()->mf(3.0)) + y, pb = rrule(ignore_derivatives, () -> mf(3.0)) @test y == mf(3.0) @test pb(1.0) == (NoTangent(), NoTangent()) end diff --git a/test/projection.jl b/test/projection.jl index ba61fb8da..cbfdcf6da 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -24,9 +24,9 @@ struct NoSuperType end # real / complex @test ProjectTo(1.0)(2.0 + 3im) === 2.0 @test ProjectTo(1.0 + 2.0im)(3.0) === 3.0 + 0.0im - @test ProjectTo(2.0+3.0im)(1+1im) === 1.0+1.0im - @test ProjectTo(2.0)(1+1im) === 1.0 - + @test ProjectTo(2.0 + 3.0im)(1 + 1im) === 1.0 + 1.0im + @test ProjectTo(2.0)(1 + 1im) === 1.0 + # storage @test ProjectTo(1)(pi) === pi @test ProjectTo(1 + im)(pi) === ComplexF64(pi) @@ -37,7 +37,8 @@ struct NoSuperType end @test ProjectTo(1.0)(2) === 2.0 # Tangents - ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(re=1, im=NoTangent())) === 1.0f0 + 0.0f0im + ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(; re=1, im=NoTangent())) === + 1.0f0 + 0.0f0im end @testset "Dual" begin # some weird Real subtype that we should basically leave alone @@ -46,9 +47,8 @@ struct NoSuperType end # real & complex @test ProjectTo(1.0 + 1im)(Dual(1.0, 2.0)) isa Complex{<:Dual} - @test ProjectTo(1.0 + 1im)( - Complex(Dual(1.0, 2.0), Dual(1.0, 2.0)) - ) isa Complex{<:Dual} + @test ProjectTo(1.0 + 1im)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa + Complex{<:Dual} @test ProjectTo(1.0)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa Dual # Tangent @@ -99,7 +99,7 @@ struct NoSuperType end # arrays of other things @test ProjectTo([:x, :y]) isa ProjectTo{NoTangent} @test ProjectTo(Any['x', "y"]) isa ProjectTo{NoTangent} - @test ProjectTo([(1,2), (3,4), (5,6)]) isa ProjectTo{AbstractArray} + @test ProjectTo([(1, 2), (3, 4), (5, 6)]) isa ProjectTo{AbstractArray} @test ProjectTo(Any[1, 2])(1:2) == [1.0, 2.0] # projects each number. @test Tuple(ProjectTo(Any[1, 2 + 3im])(1:2)) === (1.0, 2.0 + 0.0im) @@ -126,18 +126,18 @@ struct NoSuperType end @testset "Base: Ref" begin pref = ProjectTo(Ref(2.0)) @test pref(Ref(3 + im)).x === 3.0 - @test pref(Tangent{Base.RefValue}(x = 3 + im)).x === 3.0 + @test pref(Tangent{Base.RefValue}(; x=3 + im)).x === 3.0 @test pref(4).x === 4.0 # also re-wraps scalars @test pref(Ref{Any}(5.0)) isa Tangent{<:Base.RefValue} pref2 = ProjectTo(Ref{Any}(6 + 7im)) @test pref2(Ref(8)).x === 8.0 + 0.0im - @test pref2(Tangent{Base.RefValue}(x = 8)).x === 8.0 + 0.0im + @test pref2(Tangent{Base.RefValue}(; x=8)).x === 8.0 + 0.0im prefvec = ProjectTo(Ref([1, 2, 3 + 4im])) # recurses into contents @test prefvec(Ref(1:3)).x isa Vector{ComplexF64} - @test prefvec(Tangent{Base.RefValue}(x = 1:3)).x isa Vector{ComplexF64} - @test_skip @test_throws DimensionMismatch prefvec(Tangent{Base.RefValue}(x = 1:5)) + @test prefvec(Tangent{Base.RefValue}(; x=1:3)).x isa Vector{ComplexF64} + @test_skip @test_throws DimensionMismatch prefvec(Tangent{Base.RefValue}(; x=1:5)) @test ProjectTo(Ref(true)) isa ProjectTo{NoTangent} @test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent} @@ -341,13 +341,13 @@ struct NoSuperType end @testset "Tangent" begin x = 1:3.0 - dx = Tangent{typeof(x)}(; step=0.1, ref=NoTangent()); + dx = Tangent{typeof(x)}(; step=0.1, ref=NoTangent()) @test ProjectTo(x)(dx) isa Tangent @test ProjectTo(x)(dx).step === 0.1 @test ProjectTo(x)(dx).offset isa AbstractZero pref = ProjectTo(Ref(2.0)) - dy = Tangent{typeof(Ref(2.0))}(x = 3+4im) + dy = Tangent{typeof(Ref(2.0))}(; x=3 + 4im) @test pref(dy) isa Tangent{<:Base.RefValue} @test pref(dy).x === 3.0 end @@ -365,21 +365,21 @@ struct NoSuperType end # Each "@test 33 > ..." is zero on nightly, 32 on 1.5. pvec = ProjectTo(rand(10^3)) - @test 0 == @ballocated $pvec(dx) setup=(dx = rand(10^3)) # pass through - @test 90 > @ballocated $pvec(dx) setup=(dx = rand(10^3, 1)) # reshape + @test 0 == @ballocated $pvec(dx) setup = (dx = rand(10^3)) # pass through + @test 90 > @ballocated $pvec(dx) setup = (dx = rand(10^3, 1)) # reshape @test 33 > @ballocated ProjectTo(x)(dx) setup = (x = rand(10^3); dx = rand(10^3)) # including construction padj = ProjectTo(adjoint(rand(10^3))) - @test 0 == @ballocated $padj(dx) setup=(dx = adjoint(rand(10^3))) - @test 0 == @ballocated $padj(dx) setup=(dx = transpose(rand(10^3))) + @test 0 == @ballocated $padj(dx) setup = (dx = adjoint(rand(10^3))) + @test 0 == @ballocated $padj(dx) setup = (dx = transpose(rand(10^3))) @test 33 > @ballocated ProjectTo(x')(dx') setup = (x = rand(10^3); dx = rand(10^3)) pdiag = ProjectTo(Diagonal(rand(10^3))) - @test 0 == @ballocated $pdiag(dx) setup=(dx = Diagonal(rand(10^3))) + @test 0 == @ballocated $pdiag(dx) setup = (dx = Diagonal(rand(10^3))) psymm = ProjectTo(Symmetric(rand(10^3, 10^3))) - @test_broken 0 == @ballocated $psymm(dx) setup=(dx = Symmetric(rand(10^3, 10^3))) # 64 + @test_broken 0 == @ballocated $psymm(dx) setup = (dx = Symmetric(rand(10^3, 10^3))) # 64 end end diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 0d6d98535..ec9549e7e 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -19,7 +19,7 @@ macro test_macro_throws(err_expr, expr) end end # Reuse `@test_throws` logic - if err!==nothing + if err !== nothing @test_throws $(esc(err_expr)) ($(Meta.quot(expr)); throw(err)) else @test_throws $(esc(err_expr)) $(Meta.quot(expr)) @@ -37,7 +37,7 @@ struct NonDiffCounterExample end module NonDiffModuleExample - nondiff_2_1(x, y) = fill(7.5, 100)[x + y] +nondiff_2_1(x, y) = fill(7.5, 100)[x + y] end @testset "rule_definition_tools.jl" begin @@ -58,7 +58,7 @@ end res, pullback = rrule(nondiff_1_2, 3.1) @test res == (5.0, 3.0) @test isequal( - pullback(Tangent{Tuple{Float64, Float64}}(1.2, 3.2)), + pullback(Tangent{Tuple{Float64,Float64}}(1.2, 3.2)), (NoTangent(), NoTangent()), ) end @@ -81,7 +81,8 @@ end pointy_identity(x) = x @non_differentiable pointy_identity(::Vector{<:AbstractString}) - @test frule((ZeroTangent(), 1.2), pointy_identity, ["2"]) == (["2"], NoTangent()) + @test frule((ZeroTangent(), 1.2), pointy_identity, ["2"]) == + (["2"], NoTangent()) @test frule((ZeroTangent(), 1.2), pointy_identity, 2.0) == nothing res, pullback = rrule(pointy_identity, ["2"]) @@ -112,7 +113,8 @@ end @test res == 4.5 @test pullback(1.1) == (NoTangent(), NoTangent()) - @test frule((ZeroTangent(), 11.1), kw_demo, 1.5; kw=3.0) == (4.5, NoTangent()) + @test frule((ZeroTangent(), 11.1), kw_demo, 1.5; kw=3.0) == + (4.5, NoTangent()) end end @@ -121,7 +123,7 @@ end @test isequal( frule((ZeroTangent(), 1.2), NonDiffExample, 2.0), - (NonDiffExample(2.0), NoTangent()) + (NonDiffExample(2.0), NoTangent()), ) res, pullback = rrule(NonDiffExample, 2.0) @@ -151,7 +153,7 @@ end @test frule((1, 1), fvarargs, 1, 2.0) == (fvarargs(1, 2.0), NoTangent()) @test frule((1, 1), fvarargs, 1, 2) == nothing - @test rrule(fvarargs, 1, 2) == nothing + @test rrule(fvarargs, 1, 2) == nothing end @testset "::Float64..." begin @@ -196,8 +198,8 @@ end @testset "Functors" begin (f::NonDiffExample)(y) = fill(7.5, 100)[f.x + y] @non_differentiable (::NonDiffExample)(::Any) - @test frule((Tangent{NonDiffExample}(x=1.2), 2.3), NonDiffExample(3), 2) == - (7.5, NoTangent()) + @test frule((Tangent{NonDiffExample}(; x=1.2), 2.3), NonDiffExample(3), 2) == + (7.5, NoTangent()) res, pullback = rrule(NonDiffExample(3), 2) @test res == 7.5 @test pullback(4.5) == (NoTangent(), NoTangent()) @@ -205,8 +207,9 @@ end @testset "Module specified explicitly" begin @non_differentiable NonDiffModuleExample.nondiff_2_1(::Any, ::Any) - @test frule((ZeroTangent(), 1.2, 2.3), NonDiffModuleExample.nondiff_2_1, 3, 2) == - (7.5, NoTangent()) + @test frule( + (ZeroTangent(), 1.2, 2.3), NonDiffModuleExample.nondiff_2_1, 3, 2 + ) == (7.5, NoTangent()) res, pullback = rrule(NonDiffModuleExample.nondiff_2_1, 3, 2) @test res == 7.5 @test pullback(4.5) == (NoTangent(), NoTangent(), NoTangent()) @@ -216,7 +219,7 @@ end # Where clauses are not supported. @test_macro_throws( ErrorException, - (@non_differentiable where_identity(::Vector{T}) where T<:AbstractString) + (@non_differentiable where_identity(::Vector{T}) where {T<:AbstractString}) ) end end @@ -224,32 +227,33 @@ end @testset "@scalar_rule" begin @testset "@scalar_rule with multiple output" begin simo(x) = (x, 2x) - @scalar_rule(simo(x), 1f0, 2f0) + @scalar_rule(simo(x), 1.0f0, 2.0f0) y, simo_pb = rrule(simo, π) - @test simo_pb((10f0, 20f0)) == (NoTangent(), 50f0) + @test simo_pb((10.0f0, 20.0f0)) == (NoTangent(), 50.0f0) - y, ẏ = frule((NoTangent(), 50f0), simo, π) + y, ẏ = frule((NoTangent(), 50.0f0), simo, π) @test y == (π, 2π) - @test ẏ == Tangent{typeof(y)}(50f0, 100f0) + @test ẏ == Tangent{typeof(y)}(50.0f0, 100.0f0) # make sure type is exactly as expected: - @test ẏ isa Tangent{Tuple{Irrational{:π}, Float64}, Tuple{Float32, Float32}} + @test ẏ isa Tangent{Tuple{Irrational{:π},Float64},Tuple{Float32,Float32}} xs, Ω = (3,), (3, 6) - @test ChainRulesCore.derivatives_given_output(Ω, simo, xs...) == ((1f0,), (2f0,)) + @test ChainRulesCore.derivatives_given_output(Ω, simo, xs...) == + ((1.0f0,), (2.0f0,)) end @testset "@scalar_rule projection" begin - make_imaginary(x) = im*x + make_imaginary(x) = im * x @scalar_rule make_imaginary(x) im # note: the === will make sure that these are Float64, not ComplexF64 - @test (NoTangent(), 1.0) === rrule(make_imaginary, 2.0)[2](1.0*im) + @test (NoTangent(), 1.0) === rrule(make_imaginary, 2.0)[2](1.0 * im) @test (NoTangent(), 0.0) === rrule(make_imaginary, 2.0)[2](1.0) - @test (NoTangent(), 1.0+0.0im) === rrule(make_imaginary, 2.0im)[2](1.0*im) - @test (NoTangent(), 0.0-1.0im) === rrule(make_imaginary, 2.0im)[2](1.0) + @test (NoTangent(), 1.0 + 0.0im) === rrule(make_imaginary, 2.0im)[2](1.0 * im) + @test (NoTangent(), 0.0 - 1.0im) === rrule(make_imaginary, 2.0im)[2](1.0) end @testset "Regression tests against #276 and #265" begin @@ -263,7 +267,7 @@ end @scalar_rule(simo2(x), 1.0, 2.0) _, simo2_pb = rrule(simo2, 43.0) # make sure it infers: inferability implies type stability - @inferred simo2_pb(Tangent{Tuple{Float64, Float64}}(3.0, 6.0)) + @inferred simo2_pb(Tangent{Tuple{Float64,Float64}}(3.0, 6.0)) # Test no new globals were created @test length(names(ChainRulesCore; all=true)) == num_globals_before @@ -277,7 +281,8 @@ end end end - +#! format: off +# workaround for https://github.com/domluna/JuliaFormatter.jl/issues/484 module IsolatedModuleForTestingScoping # check that rules can be defined by macros without any additional imports using ChainRulesCore: @scalar_rule, @non_differentiable @@ -336,3 +341,4 @@ module IsolatedModuleForTestingScoping end end end +#! format: on diff --git a/test/rules.jl b/test/rules.jl index d43ca42d2..54c10b160 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -28,8 +28,7 @@ end mixed_vararg(x, y, z...) = x + y + sum(z) function ChainRulesCore.frule( - dargs::Tuple{Any, Any, Any, Vararg}, - ::typeof(mixed_vararg), x, y, z..., + dargs::Tuple{Any,Any,Any,Vararg}, ::typeof(mixed_vararg), x, y, z... ) Δx = dargs[2] Δy = dargs[3] @@ -39,16 +38,18 @@ end type_constraints(x::Int, y::Float64) = x + y function ChainRulesCore.frule( - (_, Δx, Δy)::Tuple{Any, Int, Float64}, - ::typeof(type_constraints), x::Int, y::Float64, + (_, Δx, Δy)::Tuple{Any,Int,Float64}, ::typeof(type_constraints), x::Int, y::Float64 ) return type_constraints(x, y), Δx + Δy end mixed_vararg_type_constaint(x::Float64, y::Real, z::Vararg{Float64}) = x + y + sum(z) function ChainRulesCore.frule( - dargs::Tuple{Any, Float64, Real, Vararg{Float64}}, - ::typeof(mixed_vararg_type_constaint), x::Float64, y::Real, z::Vararg{Float64}, + dargs::Tuple{Any,Float64,Real,Vararg{Float64}}, + ::typeof(mixed_vararg_type_constaint), + x::Float64, + y::Real, + z::Vararg{Float64}, ) Δx = dargs[2] Δy = dargs[3] @@ -76,8 +77,9 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test hasmethod(rrule, Tuple{typeof(cool),String}) # Ensure those are the *only* methods that have been defined cool_methods = Set(m.sig for m in methods(rrule) if _second(m.sig) == typeof(cool)) - only_methods = Set([Tuple{typeof(rrule),typeof(cool),Number}, - Tuple{typeof(rrule),typeof(cool),String}]) + only_methods = Set([ + Tuple{typeof(rrule),typeof(cool),Number}, Tuple{typeof(rrule),typeof(cool),String} + ]) @test cool_methods == only_methods frx, cool_pushforward = frule((dself, 1), cool, 1) @@ -94,25 +96,24 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) rrx, nice_pullback = rrule(nice, 1) @test (NoTangent(), ZeroTangent()) === nice_pullback(1) - # Test that these run. Do not care about numerical correctness. @test frule((nothing, 1.0, 1.0, 1.0), varargs_function, 0.5, 0.5, 0.5) == (1.5, 3.0) - @test frule((nothing, 1.0, 2.0, 3.0, 4.0), mixed_vararg, 1.0, 2.0, 3.0, 4.0) == (10.0, 10.0) + @test frule((nothing, 1.0, 2.0, 3.0, 4.0), mixed_vararg, 1.0, 2.0, 3.0, 4.0) == + (10.0, 10.0) @test frule((nothing, 3, 2.0), type_constraints, 5, 4.0) == (9.0, 5.0) @test frule((nothing, 3.0, 2.0im), type_constraints, 5, 4.0) == nothing - @test(frule( - (nothing, 3.0, 2.0, 1.0, 0.0), - mixed_vararg_type_constaint, 3.0, 2.0, 1.0, 0.0, - ) == (6.0, 6.0)) + @test( + frule( + (nothing, 3.0, 2.0, 1.0, 0.0), mixed_vararg_type_constaint, 3.0, 2.0, 1.0, 0.0 + ) == (6.0, 6.0) + ) # violates type constraints, thus an frule should not be found. - @test frule( - (nothing, 3, 2.0, 1.0, 5.0), - mixed_vararg_type_constaint, 3, 2.0, 1.0, 0, - ) == nothing + @test frule((nothing, 3, 2.0, 1.0, 5.0), mixed_vararg_type_constaint, 3, 2.0, 1.0, 0) == + nothing @test frule((nothing, nothing, 5.0), Core._apply, dummy_identity, 4.0) == (4.0, 5.0) @@ -149,31 +150,34 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test_skip ∂xr ≈ real(∂x) end - @testset "@opt_out" begin first_oa(x, y) = x @scalar_rule(first_oa(x, y), (1, 0)) - @opt_out ChainRulesCore.rrule(::typeof(first_oa), x::T, y::T) where T<:Float32 + @opt_out ChainRulesCore.rrule(::typeof(first_oa), x::T, y::T) where {T<:Float32} @opt_out( - ChainRulesCore.frule(::Any, ::typeof(first_oa), x::T, y::T) where T<:Float32 + ChainRulesCore.frule(::Any, ::typeof(first_oa), x::T, y::T) where {T<:Float32} ) @testset "rrule" begin @test rrule(first_oa, 3.0, 4.0)[2](1) == (NoTangent(), 1, 0) - @test rrule(first_oa, 3f0, 4f0) === nothing + @test rrule(first_oa, 3.0f0, 4.0f0) === nothing - @test !isempty(Iterators.filter(methods(ChainRulesCore.no_rrule)) do m - m.sig <:Tuple{Any, typeof(first_oa), T, T} where T<:Float32 - end) + @test !isempty( + Iterators.filter(methods(ChainRulesCore.no_rrule)) do m + m.sig <: Tuple{Any,typeof(first_oa),T,T} where {T<:Float32} + end, + ) end @testset "frule" begin - @test frule((NoTangent(), 1,0), first_oa, 3.0, 4.0) == (3.0, 1) - @test frule((NoTangent(), 1,0), first_oa, 3f0, 4f0) === nothing - - @test !isempty(Iterators.filter(methods(ChainRulesCore.no_frule)) do m - m.sig <:Tuple{Any, Any, typeof(first_oa), T, T} where T<:Float32 - end) + @test frule((NoTangent(), 1, 0), first_oa, 3.0, 4.0) == (3.0, 1) + @test frule((NoTangent(), 1, 0), first_oa, 3.0f0, 4.0f0) === nothing + + @test !isempty( + Iterators.filter(methods(ChainRulesCore.no_frule)) do m + m.sig <: Tuple{Any,Any,typeof(first_oa),T,T} where {T<:Float32} + end, + ) end end end diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 7e0ec9398..f8222d942 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -55,7 +55,7 @@ @test muladd(x, ZeroTangent(), ZeroTangent()) === ZeroTangent() @test muladd(ZeroTangent(), x, ZeroTangent()) === ZeroTangent() @test muladd(ZeroTangent(), ZeroTangent(), ZeroTangent()) === ZeroTangent() - + @test reim(z) === (ZeroTangent(), ZeroTangent()) @test real(z) === ZeroTangent() @test imag(z) === ZeroTangent() diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl index 694e43b53..26e6a2422 100644 --- a/test/tangent_types/tangent.jl +++ b/test/tangent_types/tangent.jl @@ -20,73 +20,73 @@ end @testset "Tangent" begin @testset "empty types" begin - @test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{}, Tuple{}} + @test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{},Tuple{}} end @testset "==" begin - @test Tangent{Foo}(x=0.1, y=2.5) == Tangent{Foo}(x=0.1, y=2.5) - @test Tangent{Foo}(x=0.1, y=2.5) == Tangent{Foo}(y=2.5, x=0.1) - @test Tangent{Foo}(y=2.5, x=ZeroTangent()) == Tangent{Foo}(y=2.5) + @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; x=0.1, y=2.5) + @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; y=2.5, x=0.1) + @test Tangent{Foo}(; y=2.5, x=ZeroTangent()) == Tangent{Foo}(; y=2.5) - @test Tangent{Tuple{Float64,}}(2.0) == Tangent{Tuple{Float64,}}(2.0) + @test Tangent{Tuple{Float64}}(2.0) == Tangent{Tuple{Float64}}(2.0) @test Tangent{Dict}(Dict(4 => 3)) == Tangent{Dict}(Dict(4 => 3)) tup = (1.0, 2.0) - @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, @thunk(2*1.0)) + @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, @thunk(2 * 1.0)) @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, 2) - @test Tangent{Foo}(;y=2.0,) == Tangent{Foo}(;x=ZeroTangent(), y=Float32(2.0),) + @test Tangent{Foo}(; y=2.0) == Tangent{Foo}(; x=ZeroTangent(), y=Float32(2.0)) end @testset "hash" begin - @test hash(Tangent{Foo}(x=0.1, y=2.5)) == hash(Tangent{Foo}(y=2.5, x=0.1)) - @test hash(Tangent{Foo}(y=2.5, x=ZeroTangent())) == hash(Tangent{Foo}(y=2.5)) + @test hash(Tangent{Foo}(; x=0.1, y=2.5)) == hash(Tangent{Foo}(; y=2.5, x=0.1)) + @test hash(Tangent{Foo}(; y=2.5, x=ZeroTangent())) == hash(Tangent{Foo}(; y=2.5)) end @testset "indexing, iterating, and properties" begin - @test keys(Tangent{Foo}(x=2.5)) == (:x,) - @test propertynames(Tangent{Foo}(x=2.5)) == (:x,) - @test haskey(Tangent{Foo}(x=2.5), :x) == true + @test keys(Tangent{Foo}(; x=2.5)) == (:x,) + @test propertynames(Tangent{Foo}(; x=2.5)) == (:x,) + @test haskey(Tangent{Foo}(; x=2.5), :x) == true if isdefined(Base, :hasproperty) - @test hasproperty(Tangent{Foo}(x=2.5), :y) == false + @test hasproperty(Tangent{Foo}(; x=2.5), :y) == false end - @test Tangent{Foo}(x=2.5).x == 2.5 - - @test keys(Tangent{Tuple{Float64,}}(2.0)) == Base.OneTo(1) - @test propertynames(Tangent{Tuple{Float64,}}(2.0)) == (1,) - @test getindex(Tangent{Tuple{Float64,}}(2.0), 1) == 2.0 - @test getindex(Tangent{Tuple{Float64,}}(@thunk 2.0^2), 1) == 4.0 - @test getproperty(Tangent{Tuple{Float64,}}(2.0), 1) == 2.0 - @test getproperty(Tangent{Tuple{Float64,}}(@thunk 2.0^2), 1) == 4.0 - - NT = NamedTuple{(:a, :b), Tuple{Float64, Float64}} - @test getindex(Tangent{NT}(a=(@thunk 2.0^2),), :a) == 4.0 - @test getindex(Tangent{NT}(a=(@thunk 2.0^2),), :b) == ZeroTangent() - @test getindex(Tangent{NT}(b=(@thunk 2.0^2),), 1) == ZeroTangent() - @test getindex(Tangent{NT}(b=(@thunk 2.0^2),), 2) == 4.0 - - @test getproperty(Tangent{NT}(a=(@thunk 2.0^2),), :a) == 4.0 - @test getproperty(Tangent{NT}(a=(@thunk 2.0^2),), :b) == ZeroTangent() - @test getproperty(Tangent{NT}(b=(@thunk 2.0^2),), 1) == ZeroTangent() - @test getproperty(Tangent{NT}(b=(@thunk 2.0^2),), 2) == 4.0 + @test Tangent{Foo}(; x=2.5).x == 2.5 + + @test keys(Tangent{Tuple{Float64}}(2.0)) == Base.OneTo(1) + @test propertynames(Tangent{Tuple{Float64}}(2.0)) == (1,) + @test getindex(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 + @test getindex(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 + @test getproperty(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 + @test getproperty(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 + + NT = NamedTuple{(:a, :b),Tuple{Float64,Float64}} + @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 + @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() + @test getindex(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent() + @test getindex(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 + + @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 + @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() + @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent() + @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 # TODO: uncomment this once https://github.com/JuliaLang/julia/issues/35516 @test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true @test_broken hasproperty(Tangent{Tuple{Float64}}(2.0), 2) == false - @test length(Tangent{Foo}(x=2.5)) == 1 - @test length(Tangent{Tuple{Float64,}}(2.0)) == 1 + @test length(Tangent{Foo}(; x=2.5)) == 1 + @test length(Tangent{Tuple{Float64}}(2.0)) == 1 - @test eltype(Tangent{Foo}(x=2.5)) == Float64 - @test eltype(Tangent{Tuple{Float64,}}(2.0)) == Float64 + @test eltype(Tangent{Foo}(; x=2.5)) == Float64 + @test eltype(Tangent{Tuple{Float64}}(2.0)) == Float64 # Testing iterate via collect - @test collect(Tangent{Foo}(x=2.5)) == [2.5] - @test collect(Tangent{Tuple{Float64,}}(2.0)) == [2.0] + @test collect(Tangent{Foo}(; x=2.5)) == [2.5] + @test collect(Tangent{Tuple{Float64}}(2.0)) == [2.0] # Test indexed_iterate ctup = Tangent{Tuple{Float64,Int64}}(2.0, 3) - _unpack2tuple = function(tangent) + _unpack2tuple = function (tangent) a, b = tangent return (a, b) end @@ -96,21 +96,21 @@ end # Test getproperty is inferrable _unpacknamedtuple = tangent -> (tangent.x, tangent.y) if VERSION ≥ v"1.2" - @inferred _unpacknamedtuple(Tangent{Foo}(x=2, y=3.0)) - @inferred _unpacknamedtuple(Tangent{Foo}(y=3.0)) + @inferred _unpacknamedtuple(Tangent{Foo}(; x=2, y=3.0)) + @inferred _unpacknamedtuple(Tangent{Foo}(; y=3.0)) end end @testset "reverse" begin - c = Tangent{Tuple{Int, Int, String}}(1, 2, "something") - cr = Tangent{Tuple{String, Int, Int}}("something", 2, 1) + c = Tangent{Tuple{Int,Int,String}}(1, 2, "something") + cr = Tangent{Tuple{String,Int,Int}}("something", 2, 1) @test reverse(c) === cr # can't reverse a named tuple or a dict - @test_throws MethodError reverse(Tangent{Foo}(;x=1.0, y=2.0)) + @test_throws MethodError reverse(Tangent{Foo}(; x=1.0, y=2.0)) d = Dict(:x => 1, :y => 2.0) - cdict = Tangent{Foo, typeof(d)}(d) + cdict = Tangent{Foo,typeof(d)}(d) @test_throws MethodError reverse(Tangent{Foo}()) end @@ -119,10 +119,9 @@ end end @testset "conj" begin - @test conj(Tangent{Foo}(x=2.0+3.0im)) == Tangent{Foo}(x=2.0-3.0im) + @test conj(Tangent{Foo}(; x=2.0 + 3.0im)) == Tangent{Foo}(; x=2.0 - 3.0im) @test ==( - conj(Tangent{Tuple{Float64,}}(2.0+3.0im)), - Tangent{Tuple{Float64,}}(2.0-3.0im) + conj(Tangent{Tuple{Float64}}(2.0 + 3.0im)), Tangent{Tuple{Float64}}(2.0 - 3.0im) ) @test ==( conj(Tangent{Dict}(Dict(4 => 2.0 + 3.0im))), @@ -132,26 +131,20 @@ end @testset "canonicalize" begin # Testing iterate via collect - @test ==( - canonicalize(Tangent{Tuple{Float64,}}(2.0)), - Tangent{Tuple{Float64,}}(2.0) - ) + @test ==(canonicalize(Tangent{Tuple{Float64}}(2.0)), Tangent{Tuple{Float64}}(2.0)) - @test ==( - canonicalize(Tangent{Dict}(Dict(4 => 3))), - Tangent{Dict}(Dict(4 => 3)), - ) + @test ==(canonicalize(Tangent{Dict}(Dict(4 => 3))), Tangent{Dict}(Dict(4 => 3))) # For structure it needs to match order and ZeroTangent() fill to match primal CFoo = Tangent{Foo} - @test canonicalize(CFoo(x=2.5, y=10)) == CFoo(x=2.5, y=10) - @test canonicalize(CFoo(y=10, x=2.5)) == CFoo(x=2.5, y=10) - @test canonicalize(CFoo(y=10)) == CFoo(x=ZeroTangent(), y=10) + @test canonicalize(CFoo(; x=2.5, y=10)) == CFoo(; x=2.5, y=10) + @test canonicalize(CFoo(; y=10, x=2.5)) == CFoo(; x=2.5, y=10) + @test canonicalize(CFoo(; y=10)) == CFoo(; x=ZeroTangent(), y=10) - @test_throws ArgumentError canonicalize(CFoo(q=99.0, x=2.5)) + @test_throws ArgumentError canonicalize(CFoo(; q=99.0, x=2.5)) @testset "unspecified primal type" begin - c1 = Tangent{Any}(;a=1, b=2) + c1 = Tangent{Any}(; a=1, b=2) c2 = Tangent{Any}(1, 2) c3 = Tangent{Any}(Dict(4 => 3)) @@ -164,30 +157,27 @@ end @testset "+ with other composites" begin @testset "Structs" begin CFoo = Tangent{Foo} - @test CFoo(x=1.5) + CFoo(x=2.5) == CFoo(x=4.0) - @test CFoo(y=1.5) + CFoo(x=2.5) == CFoo(y=1.5, x=2.5) - @test CFoo(y=1.5, x=1.5) + CFoo(x=2.5) == CFoo(y=1.5, x=4.0) + @test CFoo(; x=1.5) + CFoo(; x=2.5) == CFoo(; x=4.0) + @test CFoo(; y=1.5) + CFoo(; x=2.5) == CFoo(; y=1.5, x=2.5) + @test CFoo(; y=1.5, x=1.5) + CFoo(; x=2.5) == CFoo(; y=1.5, x=4.0) end @testset "Tuples" begin @test ==( - typeof(Tangent{Tuple{}}() + Tangent{Tuple{}}()), - Tangent{Tuple{}, Tuple{}} + typeof(Tangent{Tuple{}}() + Tangent{Tuple{}}()), Tangent{Tuple{},Tuple{}} ) @test ( - Tangent{Tuple{Float64, Float64}}(1.0, 2.0) + - Tangent{Tuple{Float64, Float64}}(1.0, 1.0) - ) == Tangent{Tuple{Float64, Float64}}(2.0, 3.0) + Tangent{Tuple{Float64,Float64}}(1.0, 2.0) + + Tangent{Tuple{Float64,Float64}}(1.0, 1.0) + ) == Tangent{Tuple{Float64,Float64}}(2.0, 3.0) end @testset "NamedTuples" begin - nt1 = (;a=1.5, b=0.0) - nt2 = (;a=0.0, b=2.5) - nt_sum = (a=1.5, b=2.5) - @test ( - Tangent{typeof(nt1)}(; nt1...) + - Tangent{typeof(nt2)}(; nt2...) - ) == Tangent{typeof(nt_sum)}(; nt_sum...) + make_tangent(nt::NamedTuple) = Tangent{typeof(nt)}(; nt...) + t1 = make_tangent((; a=1.5, b=0.0)) + t2 = make_tangent((; a=0.0, b=2.5)) + t_sum = make_tangent((a=1.5, b=2.5)) + @test t1 + t2 == t_sum end @testset "Dicts" begin @@ -199,8 +189,8 @@ end @testset "Fields of type NotImplemented" begin CFoo = Tangent{Foo} - a = CFoo(x=1.5) - b = CFoo(x=@not_implemented("")) + a = CFoo(; x=1.5) + b = CFoo(; x=@not_implemented("")) for (x, y) in ((a, b), (b, a), (b, b)) z = x + y @test z isa CFoo @@ -215,8 +205,8 @@ end @test first(z) isa ChainRulesCore.NotImplemented end - a = Tangent{NamedTuple{(:x,)}}(x=1.5) - b = Tangent{NamedTuple{(:x,)}}(x=@not_implemented("")) + a = Tangent{NamedTuple{(:x,)}}(; x=1.5) + b = Tangent{NamedTuple{(:x,)}}(; x=@not_implemented("")) for (x, y) in ((a, b), (b, a), (b, b)) z = x + y @test z isa Tangent{NamedTuple{(:x,)}} @@ -235,15 +225,15 @@ end @testset "+ with Primals" begin @testset "Structs" begin - @test Foo(3.5, 1.5) + Tangent{Foo}(x=2.5) == Foo(6.0, 1.5) - @test Tangent{Foo}(x=2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) + @test Foo(3.5, 1.5) + Tangent{Foo}(; x=2.5) == Foo(6.0, 1.5) + @test Tangent{Foo}(; x=2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) @test (@ballocated Bar(0.5) + Tangent{Bar}(; x=0.5)) == 0 end @testset "Tuples" begin @test Tangent{Tuple{}}() + () == () - @test ((1.0, 2.0) + Tangent{Tuple{Float64, Float64}}(1.0, 1.0)) == (2.0, 3.0) - @test (Tangent{Tuple{Float64, Float64}}(1.0, 1.0)) + (1.0, 2.0) == (2.0, 3.0) + @test ((1.0, 2.0) + Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) == (2.0, 3.0) + @test (Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) + (1.0, 2.0) == (2.0, 3.0) end @testset "NamedTuple" begin @@ -256,14 +246,14 @@ end @testset "Dicts" begin d_primal = Dict(4 => 3.0, 3 => 2.0) - d_tangent = Tangent{typeof(d_primal)}(Dict(4 =>5.0)) + d_tangent = Tangent{typeof(d_primal)}(Dict(4 => 5.0)) @test d_primal + d_tangent == Dict(4 => 3.0 + 5.0, 3 => 2.0) end end @testset "+ with Primals, with inner constructor" begin value = StructWithInvariant(10.0) - diff = Tangent{StructWithInvariant}(x=2.0, x2=6.0) + diff = Tangent{StructWithInvariant}(; x=2.0, x2=6.0) @testset "with and without debug mode" begin @assert ChainRulesCore.debug_mode() == false @@ -276,11 +266,10 @@ end ChainRulesCore.debug_mode() = false # disable it again end - # Now we define constuction for ChainRulesCore.jl's purposes: # It is going to determine the root quanity of the invarient function ChainRulesCore.construct(::Type{StructWithInvariant}, nt::NamedTuple) - x = (nt.x + nt.x2/2)/2 + x = (nt.x + nt.x2 / 2) / 2 return StructWithInvariant(x) end @test value + diff == StructWithInvariant(12.5) @@ -288,7 +277,7 @@ end end @testset "differential arithmetic" begin - c = Tangent{Foo}(y=1.5, x=2.5) + c = Tangent{Foo}(; y=1.5, x=2.5) @test NoTangent() * c == NoTangent() @test c * NoTangent() == NoTangent() @@ -310,14 +299,14 @@ end @testset "scaling" begin @test ( - 2 * Tangent{Foo}(y=1.5, x=2.5) - == Tangent{Foo}(y=3.0, x=5.0) - == Tangent{Foo}(y=1.5, x=2.5) * 2 + 2 * Tangent{Foo}(; y=1.5, x=2.5) == + Tangent{Foo}(; y=3.0, x=5.0) == + Tangent{Foo}(; y=1.5, x=2.5) * 2 ) @test ( - 2 * Tangent{Tuple{Float64, Float64}}(2.0, 4.0) - == Tangent{Tuple{Float64, Float64}}(4.0, 8.0) - == Tangent{Tuple{Float64, Float64}}(2.0, 4.0) * 2 + 2 * Tangent{Tuple{Float64,Float64}}(2.0, 4.0) == + Tangent{Tuple{Float64,Float64}}(4.0, 8.0) == + Tangent{Tuple{Float64,Float64}}(2.0, 4.0) * 2 ) d = Tangent{Dict}(Dict(4 => 3.0)) two_d = Tangent{Dict}(Dict(4 => 2 * 3.0)) @@ -325,7 +314,7 @@ end end @testset "show" begin - @test repr(Tangent{Foo}(x=1,)) == "Tangent{Foo}(x = 1,)" + @test repr(Tangent{Foo}(; x=1)) == "Tangent{Foo}(x = 1,)" # check for exact regex match not occurence( `^...$`) # and allowing optional whitespace (`\s?`) @test occursin( @@ -343,7 +332,8 @@ end @testset "Internals don't allocate a ton" begin bk = (; x=1.0, y=2.0) - VERSION >= v"1.5" && @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32 + VERSION >= v"1.5" && + @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32 # weaker version of the above (which should pass on all versions) @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 48 @@ -354,6 +344,6 @@ end @testset "non-same-typed differential arithmetic" begin nt = (; a=1, b=2.0) c = Tangent{typeof(nt)}(; a=NoTangent(), b=0.1) - @test nt + c == (; a=1, b=2.1); + @test nt + c == (; a=1, b=2.1) end end diff --git a/test/tangent_types/thunks.jl b/test/tangent_types/thunks.jl index 89461caa1..af4a747d1 100644 --- a/test/tangent_types/thunks.jl +++ b/test/tangent_types/thunks.jl @@ -141,7 +141,7 @@ # Check against accidential type piracy # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/472 @test Base.which(diagm, Tuple{}()).module != ChainRulesCore - @test Base.which(diagm, Tuple{Int, Int}).module != ChainRulesCore + @test Base.which(diagm, Tuple{Int,Int}).module != ChainRulesCore end @test tril(a) == tril(t) @test tril(a, 1) == tril(t, 1)