
Commit 031deac

Test with the new frules
1 parent 908c224 commit 031deac

8 files changed: +19 −14 lines

Manifest.toml

Lines changed: 5 additions & 3 deletions
@@ -2,7 +2,7 @@

 julia_version = "1.10.2"
 manifest_format = "2.0"
-project_hash = "df8a9208b4276382055ff54a66a4252730918e13"
+project_hash = "914538f40e552ac89a85de7921db9eaf76294f1a"

 [[deps.ADTypes]]
 git-tree-sha1 = "fcdb00b4d412b80ab08e39978e3bdef579e5e224"
@@ -574,9 +574,11 @@ version = "0.1.20"

 [[deps.LuxLib]]
 deps = ["ArrayInterface", "ChainRulesCore", "FastBroadcast", "FastClosures", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "LuxCore", "Markdown", "NNlib", "PrecompileTools", "Random", "Reexport", "Statistics", "Strided"]
-git-tree-sha1 = "edbf65f5ceb15ebbfad9d03c6a846d83b9a97baf"
+git-tree-sha1 = "8143e3dbdcfff587e9595b58c4b637e74c090fbf"
+repo-rev = "ap/more_frules"
+repo-url = "https://github.com/LuxDL/LuxLib.jl.git"
 uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
-version = "0.3.16"
+version = "0.3.17"

 [deps.LuxLib.extensions]
 LuxLibAMDGPUExt = "AMDGPU"
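
The Manifest now pins LuxLib to the ap/more_frules development branch instead of a registered release. A minimal sketch of how such a pin is typically produced from the REPL; the environment path is illustrative, the URL and branch come from the diff above:

    import Pkg
    Pkg.activate(".")   # the DeepEquilibriumNetworks.jl environment (illustrative path)
    Pkg.add(url="https://github.com/LuxDL/LuxLib.jl.git", rev="ap/more_frules")
    # Pkg records the repo-url/repo-rev pair plus the resolved git-tree-sha1 in Manifest.toml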

Project.toml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
 FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
+LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
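
Listing LuxLib under [deps] makes it a direct dependency that the package can load explicitly rather than reaching it through Lux. A hedged sketch of picking up the new dependency in a fresh clone; the path is illustrative:

    import Pkg
    Pkg.activate(".")   # DeepEquilibriumNetworks.jl environment (illustrative path)
    Pkg.instantiate()   # resolves LuxLib from the Manifest, including the branch pin above
    using LuxLib        # now loadable as a direct dependency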

docs/src/tutorials/basic_mnist_deq.md

Lines changed: 4 additions & 2 deletions
@@ -138,7 +138,8 @@ function train_model(
     for (i, (x, y)) in enumerate(data_train)
         res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
         Optimisers.update!(opt_st, ps, res.grad[3])
-        i % 50 == 1 && @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
+        i % 50 == 1 &&
+            @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
     end

     acc = accuracy(model, data_test, ps, model_st.st) * 100
@@ -151,7 +152,8 @@ function train_model(
     for (i, (x, y)) in enumerate(data_train)
         res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
         Optimisers.update!(opt_st, ps, res.grad[3])
-        i % 50 == 1 && @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
+        i % 50 == 1 &&
+            @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
     end

     acc = accuracy(model, data_test, ps, model_st.st) * 100
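
The change here is purely line length: the short-circuit logging expression is wrapped so the @printf call sits on its own line. A self-contained sketch of the idiom, with a stand-in loss value in place of res.val:

    using Printf, Dates

    for i in 1:200
        loss = 1 / i   # stand-in for res.val
        # `cond && expr` evaluates `expr` only when `cond` is true, so the log line
        # is printed on every 50th batch, starting with the first.
        i % 50 == 1 &&
            @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i 200 loss
    end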

docs/src/tutorials/reduced_dim_deq.md

Lines changed: 4 additions & 2 deletions
@@ -132,7 +132,8 @@ function train_model(
     for (i, (x, y)) in enumerate(data_train)
         res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
         Optimisers.update!(opt_st, ps, res.grad[3])
-        i % 50 == 1 && @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
+        i % 50 == 1 &&
+            @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
     end

     acc = accuracy(model, data_test, ps, model_st.st) * 100
@@ -145,7 +146,8 @@ function train_model(
     for (i, (x, y)) in enumerate(data_train)
         res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
         Optimisers.update!(opt_st, ps, res.grad[3])
-        i % 50 == 1 && @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
+        i % 50 == 1 &&
+            @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
     end

     acc = accuracy(model, data_test, ps, model_st.st) * 100
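
Same wrapping as in the MNIST tutorial. For context, a hedged sketch of the gradient step that surrounds the reformatted lines; the loss, parameters, and data below are stand-ins, not the tutorial's DEQ objects:

    using Zygote, Optimisers

    loss_fn(ps, x, y) = sum(abs2, ps.W * x .- y)
    ps = (; W = randn(Float32, 2, 2))
    opt_st = Optimisers.setup(Optimisers.Adam(0.001f0), ps)
    x, y = randn(Float32, 2, 4), randn(Float32, 2, 4)

    res = Zygote.withgradient(loss_fn, ps, x, y)              # returns (; val, grad)
    opt_st, ps = Optimisers.update!(opt_st, ps, res.grad[1])  # grad index matches ps's argument position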

ext/DeepEquilibriumNetworksZygoteExt.jl

Lines changed: 0 additions & 1 deletion
@@ -50,7 +50,6 @@ function CRC.rrule(
 end

 ## Don't remove `ad`. See https://github.com/ericphanson/ExplicitImports.jl/issues/33
-## FIXME: This will be broken in the new Lux release let's fix this
 function DEQs.__estimate_jacobian_trace(ad::AutoZygote, model, z, x, rng)
     return mean(abs2, __compute_vector_jacobian_product(model, model.ps, z, x, rng))
 end
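
__estimate_jacobian_trace and __compute_vector_jacobian_product are internal DEQs.jl helpers; the removed FIXME had flagged them as likely to break on the Lux update. A hedged sketch of the underlying pattern, a vector–Jacobian product pulled back through Zygote and reduced with mean(abs2, ...); the function f and the probe vector here are illustrative:

    using Zygote, Statistics, Random

    f(z) = tanh.(2.0f0 .* z)

    function vjp_norm_estimate(rng, f, z)
        y, back = Zygote.pullback(f, z)
        v = randn(rng, eltype(y), size(y)...)   # random probe vector
        vjp, = back(v)                          # vᵀ J(z)
        return mean(abs2, vjp)
    end

    vjp_norm_estimate(Random.default_rng(), f, randn(Float32, 4))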

src/layers.jl

Lines changed: 2 additions & 3 deletions
@@ -314,9 +314,8 @@ julia> model(x, ps, st);

 ```
 """
-function MultiScaleDeepEquilibriumNetwork(
-        main_layers::Tuple, mapping_layers::Matrix, post_fuse_layer::Union{Nothing, Tuple},
-        solver, scales; kwargs...)
+function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
+        post_fuse_layer::Union{Nothing, Tuple}, solver, scales; kwargs...)
     l1 = Parallel(nothing, main_layers...)
     l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)
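
Only the signature formatting changes; the constructor body is untouched. For reference, a hedged sketch of the wiring visible in the context lines: the mapping-layer matrix is split into rows, and broadcasting Parallel(+, ...) over those rows yields one summed branch per column, which BranchLayer then evaluates in parallel. The layer sizes below are illustrative, not taken from the package:

    using Lux, Random

    mapping_layers = [NoOpLayer()    Dense(4 => 3)
                      Dense(3 => 4)  NoOpLayer()]

    rows = map(x -> tuple(x...), eachrow(mapping_layers))   # one tuple of layers per row
    branches = Parallel.(+, rows...)                        # one Parallel(+, ...) per column
    l2 = BranchLayer(branches...)

    ps, st = Lux.setup(Random.default_rng(), l2)
    x = (randn(Float32, 4, 2), randn(Float32, 3, 2))        # one input per scale
    y, st = l2(x, ps, st)                                   # tuple of per-scale outputs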

src/utils.jl

Lines changed: 2 additions & 2 deletions
@@ -87,8 +87,8 @@ CRC.@non_differentiable __zeros_init(::Any, ::Any)
 ## Don't rely on SciMLSensitivity's choice
 @inline __default_sensealg(prob) = nothing

-@inline function __gaussian_like(rng::AbstractRNG, x)
-    y = similar(x)
+@inline function __gaussian_like(rng::AbstractRNG, x::AbstractArray)
+    y = similar(x)::typeof(x)
     randn!(rng, y)
     return y
 end
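
The argument is now restricted to AbstractArray and the similar call is type-asserted, so the sampled noise array is guaranteed (and inferred) to have exactly the type of x. A standalone sketch of the same pattern outside the package:

    using Random

    function gaussian_like(rng::AbstractRNG, x::AbstractArray)
        y = similar(x)::typeof(x)   # fails loudly if `similar` returns a different array type
        randn!(rng, y)
        return y
    end

    gaussian_like(Random.default_rng(), zeros(Float32, 3, 3))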

test/layers_tests.jl

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ end
     jacobian_regularizations = ongpu ? _jacobian_regularizations[1:(end - 1)] :
                                _jacobian_regularizations

-    @testset "Solver: $(__nameof(solver))" for solver in SOLVERS,
+    @testset "Solver: $(__nameof(solver)) | Model Type: $(mtype) | Jac. Reg: $(jacobian_regularization)" for solver in SOLVERS,
         mtype in model_type,
         jacobian_regularization in jacobian_regularizations
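
Interpolating mtype and jacobian_regularization into the testset name makes each combination appear as its own labelled entry in the test log. A minimal sketch of the pattern with stand-in values in place of the package's solvers and model types:

    using Test

    SOLVERS = (:NewtonRaphson, :Broyden)
    model_type = (:deq, :skipdeq)

    @testset "Solver: $(solver) | Model Type: $(mtype)" for solver in SOLVERS, mtype in model_type
        @test true   # placeholder body; the real tests build and differentiate the model
    end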

0 commit comments
