Skip to content

Commit a0914fe

Browse files
committed
Use ReTestItems for parallel testing
1 parent b064cbe commit a0914fe

25 files changed

+427
-395
lines changed

.JuliaFormatter.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ whitespace_in_kwargs = false
33
format_docstrings = true
44
separate_kwargs_with_semicolon = true
55
format_markdown = true
6-
annotate_untyped_fields_with_any = false
6+
annotate_untyped_fields_with_any = false
7+
join_lines_based_on_source = false

.buildkite/pipeline.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,5 +54,7 @@ steps:
5454
timeout_in_minutes: 240
5555

5656
env:
57+
RETESTITEMS_NWORKERS: 4
58+
RETESTITEMS_NWORKER_THREADS: 2
5759
SECRET_CODECOV_TOKEN: "fbSN+ZbScLIWr1FOpAu1Z8PYWFobqbLGFayOgZE1ebhE8LIH/PILGXUMcdm9gkXVSwgdETDD0s33k14lBkJ90O4dV9w6k79F/pEgzVHV8baMoXZG03BPMxztlcoRXrKtRtAp+MwoATc3Ldb9H5vqgAnVNn5rhn4Rp0Z6LOVRC43hbhKBBKYh/N4gqpIQlcW4dBXmELhlnMFnUILjwGRVgEt/zh8H+vmf0qiIulNIQ/rfGISROHqFML0QDL4icloiqX08J76ZP/gZCeg6rJ0gl3ok3IspNPz51rlbvijqsPNyIHWi29OrAtWX3qKHfrAOoGIrE1d5Oy4wx4XaN/YBhg==;U2FsdGVkX188gcRjkUNMEC2Z5fEFfhsYY4WJbhhINOuCUgqq9XNHVDbJhzFUFVQ+UiuPHFg7CW/gn+3IkSVyOA=="
5860
SECRET_DOCUMENTER_KEY: "jzyAET5IdazYwPAEZAmYmnBALb2dC1GPizCDCdt8xpjIi4ce6QbGGJMKo00ZNzJ/A7ii4bhqysVPXniifFwIGl7x+GSCeavwcSr15pfxJSqPuQYLKxESzIo+SM+l2uJWUz8KYMJ1tSt/Z3Up3qQfLeQFtR+f43b9QrLfhgZGAAdxpwu5VHdI3Xm/gZo5d8xEJ1xs4gqVP0e2A5EFr/j/exaWJL9+AvgO+Gko8NaJGG5B89zP1W2NBlpjttbwzj2naBhDx8A43Qe4eXm+BZd9CIZImiEJnnqoGxLkAyLDksbA68getUHW5z3nGyhWTrg5yfRqq0uyZZGTIOFz6dJrRg==;U2FsdGVkX19QOxLLkdNoQf7Rid3mcSR/renIHQ+/X3o0WxTmU8KDDxzfKuWPeK1fxMon8y45HCJv3HlMuzyfvPWrOmUXccfHK272D8vHu1kk/qZZw8nPd7iYBU9+VAIxwfmI3Av2gC+8tUlOcuUTEVMtMbi/MiLHp+phLYcELKzzrxL8VdrLzna81M+8xVLu7zzNuyK0cUPWLxRHcZc/fewK5Nh7EQ2x8u1b6e5zR0/AcqjCzMayD1RiE7QhRVGdF5GJYnAxc1eoyCwIjXTRfFo0a0Q2h6DEz9FEat/ZCekIuWyVrUkGbpsRqXUTrSH0An7FRRqRlZ9lStRaQY4Z3XBkoIh94vQlXwwLUH20jC7yRTV73CeYmhfigQckHL0JsjjIENz04Ac346fCV6WNQtEak0m3pN/BucoiwRA8l+WU4AK1r84cwGSphKk4SnWRAqeZVuFHck7NkcmHDEkO4C7WTP400oui/5NDMtVZbtnZfLxVzQqijxXj7IflWqF1vKqGmW5aPFMVNeAqwNGu3xM4oIIeHRu0u+k2S5dp1wqRVlMxYXdPtcoFzE0CNsMQdWgsvPd2eet38YRc8ftXNjKzoUSRRCbjGbVr0iJXeNmPg3jfZoVdILHjCN/hcz4nY+61P11OlJAdfE/6HzEr4VoOS4CN+s/brjWycmAKZo2+1e4fSV1xBH7t1spOlESLvsBhZNtj9/zUKgWgMct5hnF4anQcPAeRpz/MBrkwX1gW3WOvCxaqVlRfgGSy6boPgRd3p/ZXN4Xnfeg9RFqKZn21d2gcrc3/1+PTUEkOIv+C9BGszo9IaUziW/Tz2mVP386kX86SF4fF4y3PofcUT2FLTm8Q9ZJBnslOsRP8bq3rIjDiQR3Iz3uGctkGZPs+GOtCR5OrhnnS6BXxkGwt/n9PJsnbXt0Z4tuXihC1B8KfP7mzDvZr3q9X/DGKyZ+oMHdDI+f2+lRwx42nJnsu+nZW9lyhdIwWla9F1rIoVz59HbUrmUhsVmFQYfjy7Nl18g8Wh5r9CkFL/vr6Zpy5lj1J/vhe1501X2FIkKOnLAM73GwtAa4GkbHyu5rNcij6YoozPrJWT4KRNFWGVAqNZ1atG8WwmziwIl2KfBn8jiuP/8o6rXQkmrAzBr6jVnto5FTWnIexEmnbELs20XDck8pO5WQxU1IR9YhKMbrDGbn0jWzVoRmCWpaJgV1AkWu09a++DxIec4+Zt+3SZLj/H57XsBchWHmkFz4NVTBeSans26VmdDd3LxprT8qeH6cioceakmu6yegsKQnJGLmSNyUkHqBqmsCcvyTUyaQUBTFkjLmDeZB3Ifu2kD7AFdx5n58wdJTMZxYviybOCgCV4qe95v5XfIqthp5mF/0F1Wt9ZcEreFSM2Paj5GrQ+M25cZ+kqOSlMet51Q+QBCfQyDF8jdu3j1hVniwpgMI1gqyb2alRfyNx52elTqRn9hPqpFptGH4uJXi8H72YPe4fYkFS7wwELeRIv+nKkNYNLPQAyQFvZ/qB/PRI1YoFBbpi0Vi6iE9xLRq7QVvhJde2EgNbvQk8uakwV630Tht2OuwVdJu/PIbXsQ5i+EuknIlPRdQdhbEIkpuBHFAzxBqA2K92gJ4bbcOjGtDHc0pt0RtvIVoyyJMkYVzr1yBeMWEmsL8qYJ5yzuAFGqpCTmJzXE0ETZLDDJtxwSKj5M2vG59wPNfo9DF+LgJLF+94VydYGNOHy9KuY2Oo3ejV7iFXUtsEV3Id9EkNGq8+t5KIAGk6lnDcM1TTOAc5W8fGGNhYzlqgWK1n3nwPJLykqY7VFHPZjF8Il/8E1IubnPCIyOCTJwKqQlBB5td/bt7YIDEFmkpl7OvUwyc2uYkFmrxGv81OtopsYZOJ+WnwSkqqZ3p2MyqNj3xp92p8itz5tM3tzjrkdfPXsx1QJGY+rkZhCsSf6DSG18AFqI4+Q8uWUwqO5/TJb2z/F2LT88+wJfGPtwGeR+98XgvwjsMWIA/TZfwTrTQsZX2YOIf0bg3yjlEbFM16xAFAA2oItBuvbC6d6NIit4Dukn2WamnOceoTyO6mdHYRh5SBOryr3AWnBJZsUPL3HsC+Xiibgixuwjjalj+HOrAzDlQc8L0Z77dZJhpST0x/gwCleSA3lOKs7MA8ASolCaPVL2pPJXkb97mBxZx8k1n6abhK1w3QVJuYvp7CyGhavsYEqcR+vYx/T0tN4MVOjfRhimqhNihz0VDfY97YS5XavZV07jycqoAlufmH5VSwNbiy8/NY6Q2djc46ISbqvKr6Pf0TZBuJti2gIpm02Btm4rMCawVPpEPieU3GI77nhQp6orq0Zjl5f4XfOKnfcxiqdgip4SVFTglHBTWTiRFnGTF0I3VX4V+RRmqJwwKPN8cxDsNd6wSpylhDUAMfxEvvb+0vAt1yGNUC52OB4bSOXOyZIAU8+08xl7mYGIVUnoWHaR6Y0aHdnywJUuzQ2q3dotfnI1j72MzlHsTK6Lro3YiolDNJpTqLtxmSzkWctw/PfijnoEXtmDnZKptZ7t0v7oTAkdE3kk0RrnFTnMAkyCOREcFcyxglROCoDHsZx3Q+MkWLG/tPMVpuMRhy9gJ1WZTpeExNgs5KgwtrS1HJg7KunWXguFH/zDgODTdKclgfvsVe/SCtlpbO6z5fZji0j1y9LRBVLyTN/LzeR40OBX3r0abk4SGyslAdZMgg2WJdSLVAJ+MtxxbnlKDXDNmu5YehpWdTvm/wIYwTKw+1A48plKburw8fBEofVy9Ubmc8E4z6hQRX2cwcNN9N/60aCwlpM7wVbYfBo4Hw9H/6EawjbRRN9UwmgsfmYUuqCTSi8fNi2dR36bqaoHHURgyqW7DiR7BYgVnOZ+B/2GM8uO2rYgSOhVJf+OK+2HsNly0MW5v3/ft6W7PEsab8IweYWmPLVvJNfHW5CDP6KotdDgm/DcD5owgQ12D95BGWawR5gQxpyjX9uIlxORPq6h0Z79j8gFFsYIfddIdxsJZS9r59FtZe2JL7nK3Dum07tXDGlBCUD4mwv+LNxOJLa8DM6YoEd4Nh8qosfQNJu505Vh/r+PgegnFvG9LRkwQnk8fgPTNKThB067s82YuVg0mv4O9q4Hlm13wTWuvlMr4k1ShBrNyy08YaFCu2hmZm7RizU1rU5MpieiwpQ6cGx+sHBTszB+c89045n2TM4VUedi+vjEq2KuhmIl9ID0EHvWDy5iwOQV5nJ7Rk/Alky2GRZ1CpnJPN29q5lrs6fhvfPquolJTyBTNgVjQ7f0z1zuUQFdhWQX5BFyq/bT46qc+X6dSOvlFenioPDe5MYpA3SZCi2lmVQqHnTOcRZp2HtYpoRUzNB6cT7P1wkRTvAR5PQyuAknkKB+T6HvRb2H8EBLfk+imlyu7mb2iAJNORpZ0Rz+no/5A3wd6qHfTq27h/CDZ91YqGamylrLBdFqyefFYfSbFa1BKikiavpZnYh19hHNl9v0Q4Tkb7ogy7biw8icOvMPWCrxd50zoY1EUe2maNKtkyytJeEOV+Yj6VaUa88M+7WSKaK8QbEB+fBKmwvGkknRKs8lYRoABChwqDZ7M+98pL057QbquxseKX6alzV+IOHlO7I8csRHnF6OpVQG6wdzJZhEwg/0n1K2qTExF9Z3WzUoiQ+NVC3gRQ9Au+x3fpkuLu85lOVjelX3JtdVt1T3623sqxmcEr/TCZT/+X1QyflAkJyw1EMq4sat7wjYK3ugyPPPjo/v2h4TuaoWc0X/+qJPV/o2Vu489loIQ+N59ABZOLldpbkPM7VJIOnnfg+/GMvaEor2YCYElDGXx4BdRmSfOyzFF2Wqz5iTxMbdCo3iZbPQqbFTdMOX7Hy3nT8vUOhCLo+Dkgb7B01nPnm7crmC+TOgi4iDLp4nCqx5OSiG5gd/m54gZHe6Cymwj+DbW303KcvpGBrk0xr2sGUkQiu4vxNz+uW09EyMNCp5cg2AUWG4w6ykTHXUeDDQG232C5K7/tTt8Z09Kp9v71PkwH7hmZUrjAERGvF29zucdTVRmzr++JMH82Sk0chAi5UFs/lbVcN/birI7OVl6okyO3+bKWMCuhje1huOgeZzSk5xKFrgJ1v69TyD1mOa5wYx6IskbWSrFW/sqrhURqpSlfdWVCZiaOHLb/UIgQ0s1xlHyZ0/YOBQFz1VCgKH217ALijV3FOr+q00761SKNFc/IhZLNtVhHhE7lutAjVqyme7RHKd4fjFOD6oREyDYXHULmDGPRTmlFSxwE4+n3N9AInajQLH82CGWO1nV3u7qSY5vSbuzQIxCr8OKQfW8AzTdNjUoEtU+ojprLZ4V4r1dr01eLqXSVJ12Yq9Sm/Ivu1SZkHJl9oIxNjWSbRRMoYIVH3yVv1HyGGajcmKuzIfORuBZm"

.github/workflows/CI.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ jobs:
4343
env:
4444
GROUP: "CPU"
4545
JULIA_NUM_THREADS: 12
46+
RETESTITEMS_NWORKERS: 4
47+
RETESTITEMS_NWORKER_THREADS: 2
4648
- uses: julia-actions/julia-processcoverage@v1
4749
with:
4850
directories: src,ext
File renamed without changes.

Project.toml

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,78 @@
11
name = "DeepEquilibriumNetworks"
22
uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
33
authors = ["Avik Pal <avikpal@mit.edu>"]
4-
version = "2.0.3"
4+
version = "2.1.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
9+
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
910
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1011
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1112
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1213
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
13-
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1414
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
15+
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
1516
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1617
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1718
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1819
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1920
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
20-
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
2121

2222
[weakdeps]
2323
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
2424
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
2525
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2626

2727
[extensions]
28-
DeepEquilibriumNetworksLinearSolveSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"]
28+
DeepEquilibriumNetworksSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"]
2929
DeepEquilibriumNetworksZygoteExt = "Zygote"
3030

3131
[compat]
32-
ADTypes = "0.2.5"
32+
ADTypes = "0.2.5, 1"
33+
Aqua = "0.8.7"
3334
ChainRulesCore = "1"
35+
CommonSolve = "0.2.4"
3436
ConcreteStructs = "0.2"
3537
ConstructionBase = "1"
3638
DiffEqBase = "6.119"
39+
ExplicitImports = "1.4.1"
3740
FastClosures = "0.3"
38-
LinearAlgebra = "1"
41+
Functors = "0.4.10"
3942
LinearSolve = "2.21.2"
40-
Lux = "0.5.11"
43+
Lux = "0.5.37"
44+
LuxCUDA = "0.3.2"
45+
LuxCore = "0.1.14"
46+
LuxTestUtils = "0.1.15"
47+
NLsolve = "4.5.1"
48+
NonlinearSolve = "3.10.0"
49+
OrdinaryDiffEq = "6.74.1"
4150
PrecompileTools = "1"
42-
Random = "1"
51+
Random = "1.10"
52+
ReTestItems = "1.23.1"
4353
SciMLBase = "2"
4454
SciMLSensitivity = "7.43"
45-
Statistics = "1"
55+
StableRNGs = "1.0.2"
56+
Statistics = "1.10"
4657
SteadyStateDiffEq = "2"
47-
TruncatedStacktraces = "1.1"
48-
Zygote = "0.6.67"
49-
julia = "1.9"
58+
Test = "1.10"
59+
Zygote = "0.6.69"
60+
julia = "1.10"
61+
62+
[extras]
63+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
64+
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
65+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
66+
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
67+
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
68+
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
69+
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
70+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
71+
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
72+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
73+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
74+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
75+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
76+
77+
[targets]
78+
test = ["Aqua", "ExplicitImports", "Functors", "LuxCUDA", "LuxTestUtils", "NLsolve", "NonlinearSolve", "OrdinaryDiffEq", "ReTestItems", "SciMLSensitivity", "StableRNGs", "Test", "Zygote"]

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ Random.seed!(rng, seed)
3434

3535
model = Chain(Dense(2 => 2),
3636
DeepEquilibriumNetwork(
37-
Parallel(+, Dense(2 => 2; use_bias=false),
38-
Dense(2 => 2; use_bias=false)),
37+
Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false)),
3938
NewtonRaphson()))
4039

4140
gdev = gpu_device()

docs/make.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,15 @@ bib = CitationBibliography(joinpath(@__DIR__, "ref.bib"); style=:authoryear)
77

88
include("pages.jl")
99

10-
makedocs(; sitename="Deep Equilibrium Networks", authors="Avik Pal et al.",
11-
modules=[DeepEquilibriumNetworks], clean=true, doctest=true, linkcheck=true,
10+
makedocs(; sitename="Deep Equilibrium Networks",
11+
authors="Avik Pal et al.",
12+
modules=[DeepEquilibriumNetworks],
13+
clean=true,
14+
doctest=true,
15+
linkcheck=true,
1216
format=Documenter.HTML(; assets=["assets/favicon.ico"],
1317
canonical="https://docs.sciml.ai/DeepEquilibriumNetworks/stable/"),
14-
plugins=[bib], pages)
18+
plugins=[bib],
19+
pages)
1520

1621
deploydocs(; repo="github.com/SciML/DeepEquilibriumNetworks.jl.git", push_preview=true)

docs/pages.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,3 @@
1-
pages = [
2-
"Home" => "index.md",
3-
"Tutorials" => [
4-
"tutorials/basic_mnist_deq.md",
5-
"tutorials/reduced_dim_deq.md"
6-
],
7-
"API References" => "api.md",
8-
"References" => "references.md"
9-
]
1+
pages = ["Home" => "index.md",
2+
"Tutorials" => ["tutorials/basic_mnist_deq.md", "tutorials/reduced_dim_deq.md"],
3+
"API References" => "api.md", "References" => "references.md"]

docs/src/index.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ Random.seed!(rng, seed)
2626
2727
model = Chain(Dense(2 => 2),
2828
DeepEquilibriumNetwork(
29-
Parallel(+, Dense(2 => 2; use_bias=false),
30-
Dense(2 => 2; use_bias=false)),
29+
Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false)),
3130
NewtonRaphson()))
3231
3332
gdev = gpu_device()

docs/src/tutorials/basic_mnist_deq.md

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ function construct_model(solver; model_type::Symbol=:deq)
6666
6767
# The input layer of the DEQ
6868
deq_model = Chain(
69-
Parallel(+,
70-
Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()),
69+
Parallel(+, Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()),
7170
Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad())),
7271
Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()))
7372
@@ -79,11 +78,11 @@ function construct_model(solver; model_type::Symbol=:deq)
7978
init = missing
8079
end
8180
82-
deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
83-
linsolve_kwargs=(; maxiters=10))
81+
deq = DeepEquilibriumNetwork(
82+
deq_model, solver; init, verbose=false, linsolve_kwargs=(; maxiters=10))
8483
85-
classifier = Chain(GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(),
86-
Dense(64, 10))
84+
classifier = Chain(
85+
GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(), Dense(64, 10))
8786
8887
model = Chain(; down, deq, classifier)
8988
@@ -132,8 +131,8 @@ function accuracy(model, data, ps, st)
132131
return total_correct / total
133132
end
134133
135-
function train_model(solver, model_type; data_train=zip(x_train, y_train),
136-
data_test=zip(x_test, y_test))
134+
function train_model(
135+
solver, model_type; data_train=zip(x_train, y_train), data_test=zip(x_test, y_test))
137136
model, ps, st = construct_model(solver; model_type)
138137
model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st)
139138

0 commit comments

Comments
 (0)