Skip to content

Commit 47bcaa9

Browse files
authored
Merge pull request #152 from SciML/ap/unbreak_nested_ad
Housekeeping + Use Faster Nested AD
2 parents b064cbe + 7d9c2fa commit 47bcaa9

26 files changed

+513
-476
lines changed

.JuliaFormatter.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ whitespace_in_kwargs = false
33
format_docstrings = true
44
separate_kwargs_with_semicolon = true
55
format_markdown = true
6-
annotate_untyped_fields_with_any = false
6+
annotate_untyped_fields_with_any = false
7+
join_lines_based_on_source = false

.buildkite/pipeline.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,5 +54,7 @@ steps:
5454
timeout_in_minutes: 240
5555

5656
env:
57+
RETESTITEMS_NWORKERS: 4
58+
RETESTITEMS_NWORKER_THREADS: 2
5759
SECRET_CODECOV_TOKEN: "fbSN+ZbScLIWr1FOpAu1Z8PYWFobqbLGFayOgZE1ebhE8LIH/PILGXUMcdm9gkXVSwgdETDD0s33k14lBkJ90O4dV9w6k79F/pEgzVHV8baMoXZG03BPMxztlcoRXrKtRtAp+MwoATc3Ldb9H5vqgAnVNn5rhn4Rp0Z6LOVRC43hbhKBBKYh/N4gqpIQlcW4dBXmELhlnMFnUILjwGRVgEt/zh8H+vmf0qiIulNIQ/rfGISROHqFML0QDL4icloiqX08J76ZP/gZCeg6rJ0gl3ok3IspNPz51rlbvijqsPNyIHWi29OrAtWX3qKHfrAOoGIrE1d5Oy4wx4XaN/YBhg==;U2FsdGVkX188gcRjkUNMEC2Z5fEFfhsYY4WJbhhINOuCUgqq9XNHVDbJhzFUFVQ+UiuPHFg7CW/gn+3IkSVyOA=="
5860
SECRET_DOCUMENTER_KEY: "jzyAET5IdazYwPAEZAmYmnBALb2dC1GPizCDCdt8xpjIi4ce6QbGGJMKo00ZNzJ/A7ii4bhqysVPXniifFwIGl7x+GSCeavwcSr15pfxJSqPuQYLKxESzIo+SM+l2uJWUz8KYMJ1tSt/Z3Up3qQfLeQFtR+f43b9QrLfhgZGAAdxpwu5VHdI3Xm/gZo5d8xEJ1xs4gqVP0e2A5EFr/j/exaWJL9+AvgO+Gko8NaJGG5B89zP1W2NBlpjttbwzj2naBhDx8A43Qe4eXm+BZd9CIZImiEJnnqoGxLkAyLDksbA68getUHW5z3nGyhWTrg5yfRqq0uyZZGTIOFz6dJrRg==;U2FsdGVkX19QOxLLkdNoQf7Rid3mcSR/renIHQ+/X3o0WxTmU8KDDxzfKuWPeK1fxMon8y45HCJv3HlMuzyfvPWrOmUXccfHK272D8vHu1kk/qZZw8nPd7iYBU9+VAIxwfmI3Av2gC+8tUlOcuUTEVMtMbi/MiLHp+phLYcELKzzrxL8VdrLzna81M+8xVLu7zzNuyK0cUPWLxRHcZc/fewK5Nh7EQ2x8u1b6e5zR0/AcqjCzMayD1RiE7QhRVGdF5GJYnAxc1eoyCwIjXTRfFo0a0Q2h6DEz9FEat/ZCekIuWyVrUkGbpsRqXUTrSH0An7FRRqRlZ9lStRaQY4Z3XBkoIh94vQlXwwLUH20jC7yRTV73CeYmhfigQckHL0JsjjIENz04Ac346fCV6WNQtEak0m3pN/BucoiwRA8l+WU4AK1r84cwGSphKk4SnWRAqeZVuFHck7NkcmHDEkO4C7WTP400oui/5NDMtVZbtnZfLxVzQqijxXj7IflWqF1vKqGmW5aPFMVNeAqwNGu3xM4oIIeHRu0u+k2S5dp1wqRVlMxYXdPtcoFzE0CNsMQdWgsvPd2eet38YRc8ftXNjKzoUSRRCbjGbVr0iJXeNmPg3jfZoVdILHjCN/hcz4nY+61P11OlJAdfE/6HzEr4VoOS4CN+s/brjWycmAKZo2+1e4fSV1xBH7t1spOlESLvsBhZNtj9/zUKgWgMct5hnF4anQcPAeRpz/MBrkwX1gW3WOvCxaqVlRfgGSy6boPgRd3p/ZXN4Xnfeg9RFqKZn21d2gcrc3/1+PTUEkOIv+C9BGszo9IaUziW/Tz2mVP386kX86SF4fF4y3PofcUT2FLTm8Q9ZJBnslOsRP8bq3rIjDiQR3Iz3uGctkGZPs+GOtCR5OrhnnS6BXxkGwt/n9PJsnbXt0Z4tuXihC1B8KfP7mzDvZr3q9X/DGKyZ+oMHdDI+f2+lRwx42nJnsu+nZW9lyhdIwWla9F1rIoVz59HbUrmUhsVmFQYfjy7Nl18g8Wh5r9CkFL/vr6Zpy5lj1J/vhe1501X2FIkKOnLAM73GwtAa4GkbHyu5rNcij6YoozPrJWT4KRNFWGVAqNZ1atG8WwmziwIl2KfBn8jiuP/8o6rXQkmrAzBr6jVnto5FTWnIexEmnbELs20XDck8pO5WQxU1IR9YhKMbrDGbn0jWzVoRmCWpaJgV1AkWu09a++DxIec4+Zt+3SZLj/H57XsBchWHmkFz4NVTBeSans26VmdDd3LxprT8qeH6cioceakmu6yegsKQnJGLmSNyUkHqBqmsCcvyTUyaQUBTFkjLmDeZB3Ifu2kD7AFdx5n58wdJTMZxYviybOCgCV4qe95v5XfIqthp5mF/0F1Wt9ZcEreFSM2Paj5GrQ+M25cZ+kqOSlMet51Q+QBCfQyDF8jdu3j1hVniwpgMI1gqyb2alRfyNx52elTqRn9hPqpFptGH4uJXi8H72YPe4fYkFS7wwELeRIv+nKkNYNLPQAyQFvZ/qB/PRI1YoFBbpi0Vi6iE9xLRq7QVvhJde2EgNbvQk8uakwV630Tht2OuwVdJu/PIbXsQ5i+EuknIlPRdQdhbEIkpuBHFAzxBqA2K92gJ4bbcO
jGtDHc0pt0RtvIVoyyJMkYVzr1yBeMWEmsL8qYJ5yzuAFGqpCTmJzXE0ETZLDDJtxwSKj5M2vG59wPNfo9DF+LgJLF+94VydYGNOHy9KuY2Oo3ejV7iFXUtsEV3Id9EkNGq8+t5KIAGk6lnDcM1TTOAc5W8fGGNhYzlqgWK1n3nwPJLykqY7VFHPZjF8Il/8E1IubnPCIyOCTJwKqQlBB5td/bt7YIDEFmkpl7OvUwyc2uYkFmrxGv81OtopsYZOJ+WnwSkqqZ3p2MyqNj3xp92p8itz5tM3tzjrkdfPXsx1QJGY+rkZhCsSf6DSG18AFqI4+Q8uWUwqO5/TJb2z/F2LT88+wJfGPtwGeR+98XgvwjsMWIA/TZfwTrTQsZX2YOIf0bg3yjlEbFM16xAFAA2oItBuvbC6d6NIit4Dukn2WamnOceoTyO6mdHYRh5SBOryr3AWnBJZsUPL3HsC+Xiibgixuwjjalj+HOrAzDlQc8L0Z77dZJhpST0x/gwCleSA3lOKs7MA8ASolCaPVL2pPJXkb97mBxZx8k1n6abhK1w3QVJuYvp7CyGhavsYEqcR+vYx/T0tN4MVOjfRhimqhNihz0VDfY97YS5XavZV07jycqoAlufmH5VSwNbiy8/NY6Q2djc46ISbqvKr6Pf0TZBuJti2gIpm02Btm4rMCawVPpEPieU3GI77nhQp6orq0Zjl5f4XfOKnfcxiqdgip4SVFTglHBTWTiRFnGTF0I3VX4V+RRmqJwwKPN8cxDsNd6wSpylhDUAMfxEvvb+0vAt1yGNUC52OB4bSOXOyZIAU8+08xl7mYGIVUnoWHaR6Y0aHdnywJUuzQ2q3dotfnI1j72MzlHsTK6Lro3YiolDNJpTqLtxmSzkWctw/PfijnoEXtmDnZKptZ7t0v7oTAkdE3kk0RrnFTnMAkyCOREcFcyxglROCoDHsZx3Q+MkWLG/tPMVpuMRhy9gJ1WZTpeExNgs5KgwtrS1HJg7KunWXguFH/zDgODTdKclgfvsVe/SCtlpbO6z5fZji0j1y9LRBVLyTN/LzeR40OBX3r0abk4SGyslAdZMgg2WJdSLVAJ+MtxxbnlKDXDNmu5YehpWdTvm/wIYwTKw+1A48plKburw8fBEofVy9Ubmc8E4z6hQRX2cwcNN9N/60aCwlpM7wVbYfBo4Hw9H/6EawjbRRN9UwmgsfmYUuqCTSi8fNi2dR36bqaoHHURgyqW7DiR7BYgVnOZ+B/2GM8uO2rYgSOhVJf+OK+2HsNly0MW5v3/ft6W7PEsab8IweYWmPLVvJNfHW5CDP6KotdDgm/DcD5owgQ12D95BGWawR5gQxpyjX9uIlxORPq6h0Z79j8gFFsYIfddIdxsJZS9r59FtZe2JL7nK3Dum07tXDGlBCUD4mwv+LNxOJLa8DM6YoEd4Nh8qosfQNJu505Vh/r+PgegnFvG9LRkwQnk8fgPTNKThB067s82YuVg0mv4O9q4Hlm13wTWuvlMr4k1ShBrNyy08YaFCu2hmZm7RizU1rU5MpieiwpQ6cGx+sHBTszB+c89045n2TM4VUedi+vjEq2KuhmIl9ID0EHvWDy5iwOQV5nJ7Rk/Alky2GRZ1CpnJPN29q5lrs6fhvfPquolJTyBTNgVjQ7f0z1zuUQFdhWQX5BFyq/bT46qc+X6dSOvlFenioPDe5MYpA3SZCi2lmVQqHnTOcRZp2HtYpoRUzNB6cT7P1wkRTvAR5PQyuAknkKB+T6HvRb2H8EBLfk+imlyu7mb2iAJNORpZ0Rz+no/5A3wd6qHfTq27h/CDZ91YqGamylrLBdFqyefFYfSbFa1BKikiavpZnYh19hHNl9v0Q4Tkb7ogy7biw8icOvMPWCrxd50zoY1EUe2maNKtkyytJeEOV+Yj6VaUa88M+7WSKaK8QbEB+fBKmwvGkknRKs8lYRoABChwqDZ7M+98pL057QbquxseKX6alzV+IOHlO7I8csRHn
F6OpVQG6wdzJZhEwg/0n1K2qTExF9Z3WzUoiQ+NVC3gRQ9Au+x3fpkuLu85lOVjelX3JtdVt1T3623sqxmcEr/TCZT/+X1QyflAkJyw1EMq4sat7wjYK3ugyPPPjo/v2h4TuaoWc0X/+qJPV/o2Vu489loIQ+N59ABZOLldpbkPM7VJIOnnfg+/GMvaEor2YCYElDGXx4BdRmSfOyzFF2Wqz5iTxMbdCo3iZbPQqbFTdMOX7Hy3nT8vUOhCLo+Dkgb7B01nPnm7crmC+TOgi4iDLp4nCqx5OSiG5gd/m54gZHe6Cymwj+DbW303KcvpGBrk0xr2sGUkQiu4vxNz+uW09EyMNCp5cg2AUWG4w6ykTHXUeDDQG232C5K7/tTt8Z09Kp9v71PkwH7hmZUrjAERGvF29zucdTVRmzr++JMH82Sk0chAi5UFs/lbVcN/birI7OVl6okyO3+bKWMCuhje1huOgeZzSk5xKFrgJ1v69TyD1mOa5wYx6IskbWSrFW/sqrhURqpSlfdWVCZiaOHLb/UIgQ0s1xlHyZ0/YOBQFz1VCgKH217ALijV3FOr+q00761SKNFc/IhZLNtVhHhE7lutAjVqyme7RHKd4fjFOD6oREyDYXHULmDGPRTmlFSxwE4+n3N9AInajQLH82CGWO1nV3u7qSY5vSbuzQIxCr8OKQfW8AzTdNjUoEtU+ojprLZ4V4r1dr01eLqXSVJ12Yq9Sm/Ivu1SZkHJl9oIxNjWSbRRMoYIVH3yVv1HyGGajcmKuzIfORuBZm"

.github/workflows/CI.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ jobs:
4343
env:
4444
GROUP: "CPU"
4545
JULIA_NUM_THREADS: 12
46+
RETESTITEMS_NWORKERS: 4
47+
RETESTITEMS_NWORKER_THREADS: 2
4648
- uses: julia-actions/julia-processcoverage@v1
4749
with:
4850
directories: src,ext
File renamed without changes.

Project.toml

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,80 @@
11
name = "DeepEquilibriumNetworks"
22
uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
33
authors = ["Avik Pal <avikpal@mit.edu>"]
4-
version = "2.0.3"
4+
version = "2.1.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
9+
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
910
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1011
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1112
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1213
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
13-
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1414
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
15+
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
1516
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1617
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1718
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1819
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1920
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
20-
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
2121

2222
[weakdeps]
23+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2324
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
2425
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
2526
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2627

2728
[extensions]
28-
DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"]
29-
DeepEquilibriumNetworksZygoteExt = "Zygote"
29+
DeepEquilibriumNetworksSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"]
30+
DeepEquilibriumNetworksZygoteExt = ["ForwardDiff", "Zygote"]
3031

3132
[compat]
32-
ADTypes = "0.2.5"
33+
ADTypes = "0.2.5, 1"
34+
Aqua = "0.8.7"
3335
ChainRulesCore = "1"
36+
CommonSolve = "0.2.4"
3437
ConcreteStructs = "0.2"
3538
ConstructionBase = "1"
3639
DiffEqBase = "6.119"
40+
ExplicitImports = "1.4.1"
3741
FastClosures = "0.3"
38-
LinearAlgebra = "1"
42+
ForwardDiff = "0.10.36"
43+
Functors = "0.4.10"
3944
LinearSolve = "2.21.2"
40-
Lux = "0.5.11"
45+
Lux = "0.5.38"
46+
LuxCUDA = "0.3.2"
47+
LuxCore = "0.1.14"
48+
LuxTestUtils = "0.1.15"
49+
NLsolve = "4.5.1"
50+
NonlinearSolve = "3.10.0"
51+
OrdinaryDiffEq = "6.74.1"
4152
PrecompileTools = "1"
42-
Random = "1"
53+
Random = "1.10"
54+
ReTestItems = "1.23.1"
4355
SciMLBase = "2"
4456
SciMLSensitivity = "7.43"
45-
Statistics = "1"
57+
StableRNGs = "1.0.2"
58+
Statistics = "1.10"
4659
SteadyStateDiffEq = "2"
47-
TruncatedStacktraces = "1.1"
48-
Zygote = "0.6.67"
49-
julia = "1.9"
60+
Test = "1.10"
61+
Zygote = "0.6.69"
62+
julia = "1.10"
63+
64+
[extras]
65+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
66+
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
67+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
68+
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
69+
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
70+
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
71+
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
72+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
73+
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
74+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
75+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
76+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
77+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
78+
79+
[targets]
80+
test = ["Aqua", "ExplicitImports", "Functors", "LuxCUDA", "LuxTestUtils", "NLsolve", "NonlinearSolve", "OrdinaryDiffEq", "ReTestItems", "SciMLSensitivity", "StableRNGs", "Test", "Zygote"]

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ Random.seed!(rng, seed)
3434

3535
model = Chain(Dense(2 => 2),
3636
DeepEquilibriumNetwork(
37-
Parallel(+, Dense(2 => 2; use_bias=false),
38-
Dense(2 => 2; use_bias=false)),
37+
Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false)),
3938
NewtonRaphson()))
4039

4140
gdev = gpu_device()

docs/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
[deps]
2+
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
23
DeepEquilibriumNetworks = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
34
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
45
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
56
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
6-
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
77
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
88
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
99
MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
1010
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
1111
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
1212
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1313
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
14+
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1415
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1516
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
1617
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -21,7 +22,6 @@ DeepEquilibriumNetworks = "2"
2122
Documenter = "1"
2223
DocumenterCitations = "1"
2324
LinearSolve = "2"
24-
LoggingExtras = "1"
2525
Lux = "0.5"
2626
LuxCUDA = "0.3"
2727
MLDataUtils = "0.5"

docs/make.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,15 @@ bib = CitationBibliography(joinpath(@__DIR__, "ref.bib"); style=:authoryear)
77

88
include("pages.jl")
99

10-
makedocs(; sitename="Deep Equilibrium Networks", authors="Avik Pal et al.",
11-
modules=[DeepEquilibriumNetworks], clean=true, doctest=true, linkcheck=true,
10+
makedocs(; sitename="Deep Equilibrium Networks",
11+
authors="Avik Pal et al.",
12+
modules=[DeepEquilibriumNetworks],
13+
clean=true,
14+
doctest=true,
15+
linkcheck=true,
1216
format=Documenter.HTML(; assets=["assets/favicon.ico"],
1317
canonical="https://docs.sciml.ai/DeepEquilibriumNetworks/stable/"),
14-
plugins=[bib], pages)
18+
plugins=[bib],
19+
pages)
1520

1621
deploydocs(; repo="github.com/SciML/DeepEquilibriumNetworks.jl.git", push_preview=true)

docs/pages.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,3 @@
1-
pages = [
2-
"Home" => "index.md",
3-
"Tutorials" => [
4-
"tutorials/basic_mnist_deq.md",
5-
"tutorials/reduced_dim_deq.md"
6-
],
7-
"API References" => "api.md",
8-
"References" => "references.md"
9-
]
1+
pages = ["Home" => "index.md",
2+
"Tutorials" => ["tutorials/basic_mnist_deq.md", "tutorials/reduced_dim_deq.md"],
3+
"API References" => "api.md", "References" => "references.md"]

docs/src/index.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ Random.seed!(rng, seed)
2626
2727
model = Chain(Dense(2 => 2),
2828
DeepEquilibriumNetwork(
29-
Parallel(+, Dense(2 => 2; use_bias=false),
30-
Dense(2 => 2; use_bias=false)),
29+
Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false)),
3130
NewtonRaphson()))
3231
3332
gdev = gpu_device()

docs/src/tutorials/basic_mnist_deq.md

Lines changed: 28 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ We will train a simple Deep Equilibrium Model on MNIST. First we load a few pack
44

55
```@example basic_mnist_deq
66
using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
7-
Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras
7+
Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, Dates, Printf
88
using MLDatasets: MNIST
99
using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
1010
@@ -20,18 +20,6 @@ const cdev = cpu_device()
2020
const gdev = gpu_device()
2121
```
2222

23-
SciMLBase introduced a warning instead of depwarn which pollutes the output. We can suppress
24-
it with the following logger
25-
26-
```@example basic_mnist_deq
27-
function remove_syms_warning(log_args)
28-
return log_args.message !=
29-
"The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead."
30-
end
31-
32-
filtered_logger = ActiveFilteredLogger(remove_syms_warning, global_logger())
33-
```
34-
3523
We can now construct our dataloader.
3624

3725
```@example basic_mnist_deq
@@ -66,8 +54,7 @@ function construct_model(solver; model_type::Symbol=:deq)
6654
6755
# The input layer of the DEQ
6856
deq_model = Chain(
69-
Parallel(+,
70-
Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()),
57+
Parallel(+, Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()),
7158
Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad())),
7259
Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()))
7360
@@ -79,11 +66,11 @@ function construct_model(solver; model_type::Symbol=:deq)
7966
init = missing
8067
end
8168
82-
deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
83-
linsolve_kwargs=(; maxiters=10))
69+
deq = DeepEquilibriumNetwork(
70+
deq_model, solver; init, verbose=false, linsolve_kwargs=(; maxiters=10))
8471
85-
classifier = Chain(GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(),
86-
Dense(64, 10))
72+
classifier = Chain(
73+
GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(), Dense(64, 10))
8774
8875
model = Chain(; down, deq, classifier)
8976
@@ -95,12 +82,12 @@ function construct_model(solver; model_type::Symbol=:deq)
9582
x = randn(rng, Float32, 28, 28, 1, 128)
9683
y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev
9784
98-
model_ = Lux.Experimental.StatefulLuxLayer(model, ps, st)
99-
@info "warming up forward pass"
85+
model_ = StatefulLuxLayer(model, ps, st)
86+
@printf "[%s] warming up forward pass\n" string(now())
10087
logitcrossentropy(model_, x, ps, y)
101-
@info "warming up backward pass"
88+
@printf "[%s] warming up backward pass\n" string(now())
10289
Zygote.gradient(logitcrossentropy, model_, x, ps, y)
103-
@info "warmup complete"
90+
@printf "[%s] warmup complete\n" string(now())
10491
10592
return model, ps, st
10693
end
@@ -122,7 +109,7 @@ classify(x) = argmax.(eachcol(x))
122109
function accuracy(model, data, ps, st)
123110
total_correct, total = 0, 0
124111
st = Lux.testmode(st)
125-
model = Lux.Experimental.StatefulLuxLayer(model, ps, st)
112+
model = StatefulLuxLayer(model, ps, st)
126113
for (x, y) in data
127114
target_class = classify(cdev(y))
128115
predicted_class = classify(cdev(model(x)))
@@ -132,51 +119,48 @@ function accuracy(model, data, ps, st)
132119
return total_correct / total
133120
end
134121
135-
function train_model(solver, model_type; data_train=zip(x_train, y_train),
136-
data_test=zip(x_test, y_test))
122+
function train_model(
123+
solver, model_type; data_train=zip(x_train, y_train), data_test=zip(x_test, y_test))
137124
model, ps, st = construct_model(solver; model_type)
138-
model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st)
125+
model_st = StatefulLuxLayer(model, nothing, st)
139126
140-
@info "Training Model: $(model_type) with Solver: $(nameof(typeof(solver)))"
127+
@printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver))
141128
142129
opt_st = Optimisers.setup(Adam(0.001), ps)
143130
144131
acc = accuracy(model, data_test, ps, st) * 100
145-
@info "Starting Accuracy: $(acc)"
132+
@printf "[%s] Starting Accuracy: %.5f%%\n" string(now()) acc
146133
147-
@info "Pretrain with unrolling to a depth of 5"
134+
@printf "[%s] Pretrain with unrolling to a depth of 5\n" string(now())
148135
st = Lux.update_state(st, :fixed_depth, Val(5))
149-
model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
136+
model_st = StatefulLuxLayer(model, ps, st)
150137
151138
for (i, (x, y)) in enumerate(data_train)
152139
res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
153140
Optimisers.update!(opt_st, ps, res.grad[3])
154-
if i % 50 == 1
155-
@info "Pretraining Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
156-
end
141+
i % 50 == 1 &&
142+
@printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
157143
end
158144
159145
acc = accuracy(model, data_test, ps, model_st.st) * 100
160-
@info "Pretraining complete. Accuracy: $(acc)"
146+
@printf "[%s] Pretraining complete. Accuracy: %.5f%%\n" string(now()) acc
161147
162148
st = Lux.update_state(st, :fixed_depth, Val(0))
163-
model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
149+
model_st = StatefulLuxLayer(model, ps, st)
164150
165151
for epoch in 1:3
166152
for (i, (x, y)) in enumerate(data_train)
167153
res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
168154
Optimisers.update!(opt_st, ps, res.grad[3])
169-
if i % 50 == 1
170-
@info "Epoch: [$(epoch)/3] Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
171-
end
155+
i % 50 == 1 &&
156+
@printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
172157
end
173158
174159
acc = accuracy(model, data_test, ps, model_st.st) * 100
175-
@info "Epoch: [$(epoch)/3] Accuracy: $(acc)"
160+
@printf "[%s] Epoch: [%d/%d] Accuracy: %.5f%%\n" string(now()) epoch 3 acc
176161
end
177162
178-
@info "Training complete."
179-
println()
163+
@printf "[%s] Training complete.\n" string(now())
180164
181165
return model, ps, st
182166
end
@@ -188,19 +172,15 @@ and end up using solvers like `Broyden`, but we can simply slap in any of the fa
188172
from NonlinearSolve.jl. Here we will use the Newton-Krylov method:
189173

190174
```@example basic_mnist_deq
191-
with_logger(filtered_logger) do
192-
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq)
193-
end
175+
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq);
194176
nothing # hide
195177
```
196178

197179
We can also train a continuous DEQ by passing in an ODE solver. Here we will use `VCAB3()`
198180
which tends to be quite fast for continuous Neural Network problems.
199181

200182
```@example basic_mnist_deq
201-
with_logger(filtered_logger) do
202-
train_model(VCAB3(), :deq)
203-
end
183+
train_model(VCAB3(), :deq);
204184
nothing # hide
205185
```
206186

0 commit comments

Comments
 (0)