Skip to content

Commit 8281939

Browse files
committed
reapply formatter
1 parent 7e42d0e commit 8281939

File tree

8 files changed

+45
-32
lines changed

8 files changed

+45
-32
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@ rng = Random.default_rng()
3333
Random.seed!(rng, seed)
3434

3535
model = Chain(Dense(2 => 2),
36-
DeepEquilibriumNetwork(Parallel(+, Dense(2 => 2; use_bias=false),
37-
Dense(2 => 2; use_bias=false)), NewtonRaphson()))
36+
DeepEquilibriumNetwork(
37+
Parallel(+, Dense(2 => 2; use_bias=false),
38+
Dense(2 => 2; use_bias=false)),
39+
NewtonRaphson()))
3840

3941
gdev = gpu_device()
4042
cdev = cpu_device()

docs/pages.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ pages = [
22
"Home" => "index.md",
33
"Tutorials" => [
44
"tutorials/basic_mnist_deq.md",
5-
"tutorials/reduced_dim_deq.md",
5+
"tutorials/reduced_dim_deq.md"
66
],
77
"API References" => "api.md",
8-
"References" => "references.md",
8+
"References" => "references.md"
99
]

docs/src/index.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@ rng = Random.default_rng()
2525
Random.seed!(rng, seed)
2626
2727
model = Chain(Dense(2 => 2),
28-
DeepEquilibriumNetwork(Parallel(+, Dense(2 => 2; use_bias=false),
29-
Dense(2 => 2; use_bias=false)), NewtonRaphson()))
28+
DeepEquilibriumNetwork(
29+
Parallel(+, Dense(2 => 2; use_bias=false),
30+
Dense(2 => 2; use_bias=false)),
31+
NewtonRaphson()))
3032
3133
gdev = gpu_device()
3234
cdev = cpu_device()

docs/src/tutorials/basic_mnist_deq.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ We will train a simple Deep Equilibrium Model on MNIST. First we load a few pack
44

55
```@example basic_mnist_deq
66
using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
7-
Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras
7+
Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras
88
using MLDatasets: MNIST
99
using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
1010
@@ -65,7 +65,8 @@ function construct_model(solver; model_type::Symbol=:deq)
6565
Conv((4, 4), 64 => 64; stride=2, pad=1))
6666
6767
# The input layer of the DEQ
68-
deq_model = Chain(Parallel(+,
68+
deq_model = Chain(
69+
Parallel(+,
6970
Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()),
7071
Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad())),
7172
Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()))

docs/src/tutorials/reduced_dim_deq.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ same MNIST example as before, but this time we will use a reduced state size.
66

77
```@example reduced_dim_mnist
88
using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
9-
Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras
9+
Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras
1010
using MLDatasets: MNIST
1111
using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
1212

src/DeepEquilibriumNetworks.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@ import PrecompileTools: @recompile_invalidations
44

55
@recompile_invalidations begin
66
using ADTypes, DiffEqBase, FastClosures, LinearAlgebra, Lux, Random, SciMLBase,
7-
Statistics, SteadyStateDiffEq
7+
Statistics, SteadyStateDiffEq
88

99
import ChainRulesCore as CRC
1010
import ConcreteStructs: @concrete
1111
import ConstructionBase: constructorof
1212
import Lux: AbstractExplicitLayer, AbstractExplicitContainerLayer
1313
import SciMLBase: AbstractNonlinearAlgorithm,
14-
AbstractODEAlgorithm, _unwrap_val, NonlinearSolution
14+
AbstractODEAlgorithm, _unwrap_val, NonlinearSolution
1515
import TruncatedStacktraces: @truncate_stacktrace
1616
end
1717

@@ -23,7 +23,7 @@ include("utils.jl")
2323

2424
# Exports
2525
export DEQs, DeepEquilibriumSolution, DeepEquilibriumNetwork, SkipDeepEquilibriumNetwork,
26-
MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork,
27-
MultiScaleNeuralODE
26+
MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork,
27+
MultiScaleNeuralODE
2828

2929
end

src/layers.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,10 @@ Deep Equilibrium Network as proposed in [baideep2019](@cite) and [pal2022mixing]
152152
```julia
153153
julia> using DeepEquilibriumNetworks, Lux, Random, OrdinaryDiffEq
154154
155-
julia> model = DeepEquilibriumNetwork(Parallel(+, Dense(2, 2; use_bias=false),
156-
Dense(2, 2; use_bias=false)), VCABM3(); verbose=false)
155+
julia> model = DeepEquilibriumNetwork(
156+
Parallel(+, Dense(2, 2; use_bias=false),
157+
Dense(2, 2; use_bias=false)),
158+
VCABM3(); verbose=false)
157159
DeepEquilibriumNetwork(
158160
model = Parallel(
159161
+
@@ -233,15 +235,17 @@ For keyword arguments, see [`DeepEquilibriumNetwork`](@ref).
233235
```julia
234236
julia> using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve
235237
236-
julia> main_layers = (Parallel(+, Dense(4 => 4, tanh; use_bias=false),
237-
Dense(4 => 4, tanh; use_bias=false)), Dense(3 => 3, tanh), Dense(2 => 2, tanh),
238+
julia> main_layers = (
239+
Parallel(+, Dense(4 => 4, tanh; use_bias=false),
240+
Dense(4 => 4, tanh; use_bias=false)),
241+
Dense(3 => 3, tanh), Dense(2 => 2, tanh),
238242
Dense(1 => 1, tanh))
239243
(Parallel(), Dense(3 => 3, tanh_fast), Dense(2 => 2, tanh_fast), Dense(1 => 1, tanh_fast))
240244
241245
julia> mapping_layers = [NoOpLayer() Dense(4 => 3, tanh) Dense(4 => 2, tanh) Dense(4 => 1, tanh);
242-
Dense(3 => 4, tanh) NoOpLayer() Dense(3 => 2, tanh) Dense(3 => 1, tanh);
243-
Dense(2 => 4, tanh) Dense(2 => 3, tanh) NoOpLayer() Dense(2 => 1, tanh);
244-
Dense(1 => 4, tanh) Dense(1 => 3, tanh) Dense(1 => 2, tanh) NoOpLayer()]
246+
Dense(3 => 4, tanh) NoOpLayer() Dense(3 => 2, tanh) Dense(3 => 1, tanh);
247+
Dense(2 => 4, tanh) Dense(2 => 3, tanh) NoOpLayer() Dense(2 => 1, tanh);
248+
Dense(1 => 4, tanh) Dense(1 => 3, tanh) Dense(1 => 2, tanh) NoOpLayer()]
245249
4×4 Matrix{LuxCore.AbstractExplicitLayer}:
246250
NoOpLayer() … Dense(4 => 1, tanh_fast)
247251
Dense(3 => 4, tanh_fast) Dense(3 => 1, tanh_fast)

test/layers.jl

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using ADTypes, DeepEquilibriumNetworks, DiffEqBase, NonlinearSolve, OrdinaryDiffEq,
2-
SciMLSensitivity, SciMLBase, Test
2+
SciMLSensitivity, SciMLBase, Test
33

44
include("test_utils.jl")
55

@@ -16,7 +16,7 @@ end
1616

1717
base_models = [
1818
Parallel(+, __get_dense_layer(2 => 2), __get_dense_layer(2 => 2)),
19-
Parallel(+, __get_conv_layer((1, 1), 1 => 1), __get_conv_layer((1, 1), 1 => 1)),
19+
Parallel(+, __get_conv_layer((1, 1), 1 => 1), __get_conv_layer((1, 1), 1 => 1))
2020
]
2121
init_models = [__get_dense_layer(2 => 2), __get_conv_layer((1, 1), 1 => 1)]
2222
x_sizes = [(2, 14), (3, 3, 1, 3)]
@@ -31,7 +31,8 @@ end
3131
@testset "Solver: $(__nameof(solver))" for solver in solvers,
3232
mtype in model_type, jacobian_regularization in jacobian_regularizations
3333

34-
@testset "x_size: $(x_size)" for (base_model, init_model, x_size) in zip(base_models,
34+
@testset "x_size: $(x_size)" for (base_model, init_model, x_size) in zip(
35+
base_models,
3536
init_models, x_sizes)
3637
model = if mtype === :deq
3738
DeepEquilibriumNetwork(base_model, solver; jacobian_regularization)
@@ -86,20 +87,20 @@ end
8687

8788
main_layers = [
8889
(Parallel(+, __get_dense_layer(4 => 4), __get_dense_layer(4 => 4)),
89-
__get_dense_layer(3 => 3), __get_dense_layer(2 => 2),
90-
__get_dense_layer(1 => 1)),
90+
__get_dense_layer(3 => 3), __get_dense_layer(2 => 2),
91+
__get_dense_layer(1 => 1))
9192
]
9293

9394
mapping_layers = [
9495
[NoOpLayer() __get_dense_layer(4 => 3) __get_dense_layer(4 => 2) __get_dense_layer(4 => 1);
95-
__get_dense_layer(3 => 4) NoOpLayer() __get_dense_layer(3 => 2) __get_dense_layer(3 => 1);
96-
__get_dense_layer(2 => 4) __get_dense_layer(2 => 3) NoOpLayer() __get_dense_layer(2 => 1);
97-
__get_dense_layer(1 => 4) __get_dense_layer(1 => 3) __get_dense_layer(1 => 2) NoOpLayer()],
96+
__get_dense_layer(3 => 4) NoOpLayer() __get_dense_layer(3 => 2) __get_dense_layer(3 => 1);
97+
__get_dense_layer(2 => 4) __get_dense_layer(2 => 3) NoOpLayer() __get_dense_layer(2 => 1);
98+
__get_dense_layer(1 => 4) __get_dense_layer(1 => 3) __get_dense_layer(1 => 2) NoOpLayer()]
9899
]
99100

100101
init_layers = [
101102
(__get_dense_layer(4 => 4), __get_dense_layer(4 => 3), __get_dense_layer(4 => 2),
102-
__get_dense_layer(4 => 1)),
103+
__get_dense_layer(4 => 1))
103104
]
104105

105106
x_sizes = [(4, 3)]
@@ -113,16 +114,19 @@ end
113114

114115
for mtype in model_type, jacobian_regularization in jacobian_regularizations
115116
@testset "Solver: $(__nameof(solver))" for solver in solvers
116-
@testset "x_size: $(x_size)" for (main_layer, mapping_layer, init_layer, x_size, scale) in zip(main_layers,
117+
@testset "x_size: $(x_size)" for (main_layer, mapping_layer, init_layer, x_size, scale) in zip(
118+
main_layers,
117119
mapping_layers, init_layers, x_sizes, scales)
118120
model = if mtype === :deq
119121
MultiScaleDeepEquilibriumNetwork(main_layer, mapping_layer, nothing,
120122
solver, scale; jacobian_regularization)
121123
elseif mtype === :skipdeq
122-
MultiScaleSkipDeepEquilibriumNetwork(main_layer, mapping_layer, nothing,
124+
MultiScaleSkipDeepEquilibriumNetwork(
125+
main_layer, mapping_layer, nothing,
123126
init_layer, solver, scale; jacobian_regularization)
124127
elseif mtype === :skipregdeq
125-
MultiScaleSkipDeepEquilibriumNetwork(main_layer, mapping_layer, nothing,
128+
MultiScaleSkipDeepEquilibriumNetwork(
129+
main_layer, mapping_layer, nothing,
126130
solver, scale; jacobian_regularization)
127131
elseif mtype === :node
128132
solver isa SciMLBase.AbstractODEAlgorithm || continue

0 commit comments

Comments
 (0)