Skip to content

Commit bd3a278

Browse files
authored
Merge pull request #103 from SciML/auto-juliaformatter-pr
Automatic JuliaFormatter.jl run
2 parents ddc5efb + 3d4e169 commit bd3a278

File tree

8 files changed

+23
-23
lines changed

8 files changed

+23
-23
lines changed

src/chainrules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
function CRC.rrule(::Type{<:DeepEquilibriumSolution}, z_star::T, u0::T, residual::T,
2-
jacobian_loss::R, nfe::Int) where {T, R <: AbstractFloat}
2+
jacobian_loss::R, nfe::Int) where {T, R <: AbstractFloat}
33
function ∇deep_equilibrium_solution(dsol)
44
return (∂∅, dsol.z_star, dsol.u0, dsol.residual, dsol.jacobian_loss, dsol.nfe)
55
end

src/layers/deq.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ end
4343
@truncate_stacktrace DeepEquilibriumNetwork 1 2
4444

4545
function DeepEquilibriumNetwork(model, solver; jacobian_regularization::Bool=false,
46-
sensealg=SteadyStateAdjoint(), kwargs...)
46+
sensealg=SteadyStateAdjoint(), kwargs...)
4747
return DeepEquilibriumNetwork{jacobian_regularization}(model, solver, sensealg, kwargs)
4848
end
4949

@@ -112,15 +112,15 @@ end
112112
@truncate_stacktrace SkipDeepEquilibriumNetwork 1 2 3
113113

114114
function SkipDeepEquilibriumNetwork(model, shortcut, solver; sensealg=SteadyStateAdjoint(),
115-
jacobian_regularization::Bool=false, kwargs...)
115+
jacobian_regularization::Bool=false, kwargs...)
116116
return SkipDeepEquilibriumNetwork{jacobian_regularization}(model, shortcut, solver,
117117
sensealg, kwargs)
118118
end
119119

120120
_jacobian_regularization(::SkipDeepEquilibriumNetwork{J}) where {J} = J
121121

122122
function _get_initial_condition(deq::SkipDeepEquilibriumNetwork{J, M, Nothing}, x, ps,
123-
st) where {J, M}
123+
st) where {J, M}
124124
z, st_ = deq.model((zero(x), x), ps.model, st.model)
125125
return z, merge(st, (; model=st_))
126126
end

src/layers/evaluate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@generated function _evaluate_unrolled_model(::AbstractDEQs, model, z_star, x, ps, st,
2-
::Val{d}) where {d}
2+
::Val{d}) where {d}
33
calls = [:((z_star, st) = model((z_star, x), ps, st)) for _ in 1:d]
44
push!(calls, :(return z_star, st))
55
return Expr(:block, calls...)

src/layers/jacobian_stabilization.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@ Estimates the trace of the jacobian matrix wrt `z`.
2929
Stochastic Estimate of the trace of the Jacobian.
3030
"""
3131
function estimate_jacobian_trace(::Val{:reverse}, model::Lux.AbstractExplicitLayer,
32-
ps, st::NamedTuple, z::AbstractArray, x::AbstractArray, rng::AbstractRNG)
32+
ps, st::NamedTuple, z::AbstractArray, x::AbstractArray, rng::AbstractRNG)
3333
_, back = Zygote.pullback(u -> model((u, x), ps, st)[1], z)
3434
vjp_z = back(_gaussian_like(rng, x))[1]
3535
return mean(abs2, vjp_z)
3636
end
3737

3838
function estimate_jacobian_trace(::Val{:finite_diff}, model::Lux.AbstractExplicitLayer,
39-
ps, st::NamedTuple, z::AbstractArray, x::AbstractArray, rng::AbstractRNG)
39+
ps, st::NamedTuple, z::AbstractArray, x::AbstractArray, rng::AbstractRNG)
4040
f = u -> model((u, x), ps, st)[1]
4141
res = convert(eltype(z), 0)
4242
epsilon = cbrt(eps(typeof(res)))

src/layers/mdeq.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ function Lux.initialstates(rng::AbstractRNG, deq::MultiScaleDeepEquilibriumNetwo
9494
end
9595

9696
function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
97-
post_fuse_layer::Union{Nothing, Tuple}, solver, scales::NTuple{N, NTuple{L, Int64}};
98-
sensealg=SteadyStateAdjoint(), kwargs...) where {N, L}
97+
post_fuse_layer::Union{Nothing, Tuple}, solver, scales::NTuple{N, NTuple{L, Int64}};
98+
sensealg=SteadyStateAdjoint(), kwargs...) where {N, L}
9999
l1 = Parallel(nothing, main_layers...)
100100
l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)
101101

@@ -210,7 +210,7 @@ See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref)
210210
end
211211

212212
function MultiScaleSkipDeepEquilibriumNetwork(model::MultiScaleInputLayer{N},
213-
args...) where {N}
213+
args...) where {N}
214214
return MultiScaleSkipDeepEquilibriumNetwork{N}(model, args...)
215215
end
216216

@@ -225,9 +225,9 @@ function Lux.initialstates(rng::AbstractRNG, deq::MultiScaleSkipDeepEquilibriumN
225225
end
226226

227227
function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
228-
post_fuse_layer::Union{Nothing, Tuple}, shortcut_layers::Union{Nothing, Tuple},
229-
solver, scales::NTuple{N, NTuple{L, Int64}};
230-
sensealg=SteadyStateAdjoint(), kwargs...) where {N, L}
228+
post_fuse_layer::Union{Nothing, Tuple}, shortcut_layers::Union{Nothing, Tuple},
229+
solver, scales::NTuple{N, NTuple{L, Int64}};
230+
sensealg=SteadyStateAdjoint(), kwargs...) where {N, L}
231231
l1 = Parallel(nothing, main_layers...)
232232
l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)
233233
shortcut = shortcut_layers === nothing ? nothing : Parallel(nothing, shortcut_layers...)
@@ -247,7 +247,7 @@ end
247247
_jacobian_regularization(::MultiScaleSkipDeepEquilibriumNetwork) = false
248248

249249
function _get_initial_condition(deq::MultiScaleSkipDeepEquilibriumNetwork{N, M, Nothing},
250-
x, ps, st) where {N, M}
250+
x, ps, st) where {N, M}
251251
u0, st = _get_zeros_initial_condition_mdeq(deq.scales, x, st)
252252
z, st_ = deq.model((u0, x), ps.model, st.model)
253253
return z, merge(st, (; model=st_))
@@ -330,8 +330,8 @@ model(x, ps, st)
330330
See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref), [`MultiScaleSkipDeepEquilibriumNetwork`](@ref)
331331
"""
332332
function MultiScaleNeuralODE(main_layers::Tuple, mapping_layers::Matrix,
333-
post_fuse_layer::Union{Nothing, Tuple}, solver, scales::NTuple{N, NTuple{L, Int64}};
334-
sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), kwargs...) where {N, L}
333+
post_fuse_layer::Union{Nothing, Tuple}, solver, scales::NTuple{N, NTuple{L, Int64}};
334+
sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), kwargs...) where {N, L}
335335
l1 = Parallel(nothing, main_layers...)
336336
l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)
337337

@@ -361,7 +361,7 @@ end
361361

362362
# Shared Functions
363363
@generated function _get_zeros_initial_condition_mdeq(::Val{scales}, x::AbstractArray{T, N},
364-
st::NamedTuple{fields}) where {scales, T, N, fields}
364+
st::NamedTuple{fields}) where {scales, T, N, fields}
365365
sz = sum(prod.(scales))
366366
calls = []
367367
if :initial_condition fields
@@ -377,6 +377,6 @@ end
377377
CRC.@non_differentiable _get_zeros_initial_condition_mdeq(::Any...)
378378

379379
@inline function _postprocess_output(deq::Union{MultiScaleDeepEquilibriumNetwork,
380-
MultiScaleSkipDeepEquilibriumNetwork, MultiScaleNeuralODE}, z_star)
380+
MultiScaleSkipDeepEquilibriumNetwork, MultiScaleNeuralODE}, z_star)
381381
return split_and_reshape(z_star, deq.split_idxs, deq.scales)
382382
end

src/solve.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ struct ContinuousDEQSolver{A <: DynamicSS} <: AbstractDEQSolver
3030
end
3131

3232
function ContinuousDEQSolver(alg=VCAB3(); mode=NLSolveTerminationMode.RelSafeBest,
33-
abstol=1.0f-8, reltol=1.0f-6, abstol_termination=abstol, reltol_termination=reltol,
34-
tspan=Inf32, kwargs...)
33+
abstol=1.0f-8, reltol=1.0f-6, abstol_termination=abstol, reltol_termination=reltol,
34+
tspan=Inf32, kwargs...)
3535
termination_condition = NLSolveTerminationCondition(mode; abstol=abstol_termination,
3636
reltol=reltol_termination, kwargs...)
3737
return ContinuousDEQSolver(DynamicSS(alg; abstol, reltol, tspan, termination_condition))
@@ -81,7 +81,7 @@ end
8181
@truncate_stacktrace EquilibriumSolution 1 2
8282

8383
function SciMLBase.__solve(prob::AbstractSteadyStateProblem, alg::AbstractDEQSolver,
84-
args...; kwargs...)
84+
args...; kwargs...)
8585
# FIXME: Remove this handle
8686
sol = solve(prob, alg.alg, args...; kwargs...)
8787

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ DEQs.split_and_reshape(x, split_idxs, shapes)
2727
```
2828
"""
2929
@generated function split_and_reshape(x::AbstractMatrix, ::Val{idxs},
30-
::Val{shapes}) where {idxs, shapes}
30+
::Val{shapes}) where {idxs, shapes}
3131
dims = [reshape((idxs[i] + 1):idxs[i + 1], shapes[i]...) for i in 1:(length(idxs) - 1)]
3232
varnames = [gensym("x_view") for _ in dims]
3333
calls = []

test/adjoint.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function loss_function(model::DEQs.AbstractSkipDeepEquilibriumNetwork, x, ps, st
1515
end
1616

1717
function loss_function(model::Union{MultiScaleDeepEquilibriumNetwork, MultiScaleNeuralODE},
18-
x, ps, st)
18+
x, ps, st)
1919
y, st_ = model(x, ps, st)
2020
return sum(sum, y) + st_.solution.jacobian_loss
2121
end

0 commit comments

Comments
 (0)