Skip to content

Commit 1e414cd

Browse files
committed
Remove uses of Static
1 parent 4d6c1ed commit 1e414cd

File tree

9 files changed

+25
-59
lines changed

9 files changed

+25
-59
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1818
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1919
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
2020
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
21-
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
2221
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2322
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
2423
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
@@ -38,7 +37,6 @@ Reexport = "1"
3837
SciMLBase = "2"
3938
SciMLSensitivity = "7.43"
4039
Setfield = "1"
41-
Static = "0.6, 0.7, 0.8"
4240
SteadyStateDiffEq = "1.16"
4341
TruncatedStacktraces = "1.1"
4442
Zygote = "0.6.34"

docs/src/manual/misc.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,5 @@
44
DeepEquilibriumSolution
55
EquilibriumSolution
66
DeepEquilibriumNetworks.split_and_reshape
7-
DeepEquilibriumNetworks.init_identity_matrix
87
DeepEquilibriumNetworks.estimate_jacobian_trace
98
```

src/DeepEquilibriumNetworks.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ import Reexport: @reexport
44

55
@reexport using Lux, NonlinearSolve, OrdinaryDiffEq, SciMLSensitivity
66

7-
using DiffEqBase, LinearAlgebra, LinearSolve, MLUtils, Random, SciMLBase, SciMLSensitivity,
8-
Setfield, Static, Statistics, SteadyStateDiffEq, Zygote, ZygoteRules
7+
using DiffEqBase, LinearAlgebra, LinearSolve, MLUtils, Random, SciMLBase,
8+
Setfield, Statistics, SteadyStateDiffEq, Zygote
99

1010
import DiffEqBase: AbstractSteadyStateProblem
1111
import SciMLBase: AbstractNonlinearSolution, AbstractSteadyStateAlgorithm
@@ -37,6 +37,7 @@ include("chainrules.jl")
3737
# Start of Weird Patches
3838
# Honestly no clue why this is needed! -- probably a whacky fix which shouldn't be ever
3939
# needed.
40+
using ZygoteRules
4041
ZygoteRules.gradtuple1(::NamedTuple{()}) = (nothing, nothing, nothing, nothing, nothing)
4142
ZygoteRules.gradtuple1(x::NamedTuple) = collect(values(x))
4243
# End of Weird Patches

src/layers/mdeq.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ end
77

88
@truncate_stacktrace MultiScaleInputLayer 1 2
99

10-
function MultiScaleInputLayer(model, split_idxs, scales)
11-
return MultiScaleInputLayer{length(scales)}(model, split_idxs, scales)
10+
function MultiScaleInputLayer(model, split_idxs, scales::Val{S}) where {S}
11+
return MultiScaleInputLayer{length(S)}(model, split_idxs, scales)
1212
end
1313

1414
@generated function (m::MultiScaleInputLayer{N})(z, ps, st) where {N}
@@ -99,8 +99,8 @@ function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Ma
9999
l1 = Parallel(nothing, main_layers...)
100100
l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)
101101

102-
scales = static(scales)
103-
split_idxs = static(Tuple(vcat(0, cumsum(prod.(scales))...)))
102+
scales = Val(scales)
103+
split_idxs = Val(Tuple(vcat(0, cumsum(prod.(SciMLBase._unwrap_val(scales)))...)))
104104
if post_fuse_layer === nothing
105105
model = MultiScaleInputLayer(Chain(l1, l2), split_idxs, scales)
106106
else
@@ -231,8 +231,8 @@ function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers
231231
l1 = Parallel(nothing, main_layers...)
232232
l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)
233233
shortcut = shortcut_layers === nothing ? nothing : Parallel(nothing, shortcut_layers...)
234-
scales = static(scales)
235-
split_idxs = static(Tuple(vcat(0, cumsum(prod.(scales))...)))
234+
scales = Val(scales)
235+
split_idxs = Val(Tuple(vcat(0, cumsum(prod.(SciMLBase._unwrap_val(scales)))...)))
236236
if post_fuse_layer === nothing
237237
model = MultiScaleInputLayer(Chain(l1, l2), split_idxs, scales)
238238
else
@@ -337,8 +337,8 @@ function MultiScaleNeuralODE(main_layers::Tuple, mapping_layers::Matrix,
337337
l1 = Parallel(nothing, main_layers...)
338338
l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)
339339

340-
scales = static(scales)
341-
split_idxs = static(Tuple(vcat(0, cumsum(prod.(scales))...)))
340+
scales = Val(scales)
341+
split_idxs = Val(Tuple(vcat(0, cumsum(prod.(SciMLBase._unwrap_val(scales)))...)))
342342
if post_fuse_layer === nothing
343343
model = MultiScaleInputLayer(Chain(l1, l2), split_idxs, scales)
344344
else
@@ -362,9 +362,8 @@ end
362362
@inline _fix_solution_output(::MultiScaleNeuralODE, x) = x[end]
363363

364364
# Shared Functions
365-
@generated function _get_zeros_initial_condition_mdeq(::S, x::AbstractArray{T, N},
366-
st::NamedTuple{fields}) where {S, T, N, fields}
367-
scales = known(S)
365+
@generated function _get_zeros_initial_condition_mdeq(::Val{scales}, x::AbstractArray{T, N},
366+
st::NamedTuple{fields}) where {scales, T, N, fields}
368367
sz = sum(prod.(scales))
369368
calls = []
370369
if :initial_condition fields

src/utils.jl

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
# For MultiScale DEQs
22
"""
3-
split_and_reshape(x::AbstractMatrix, ::Sidxs, ::Sshapes)
3+
split_and_reshape(x::AbstractMatrix, ::Val{idxs}, ::Val{shapes}) where {idxs, shapes}
44
55
Splits up the AbstractMatrix into chunks and reshapes them.
66
77
## Arguments
88
99
- `x`: Matrix to be split up.
10-
- `Sidxs`: Indices to partition the array at. (must be a `static` type).
11-
- `Sshapes`: Shapes to reshape the split the arrays. (must be a `static` type).
10+
- `Sidxs`: Indices to partition the array at. (must be a `Val` type).
11+
- `Sshapes`: Shapes to reshape the split the arrays. (must be a `Val` type).
1212
1313
## Example
1414
@@ -20,14 +20,14 @@ x2 = fill!(zeros(Float32, 2, 4), 0.5f0)
2020
x3 = zeros(Float32, 1, 4)
2121
2222
x = vcat(x1, x2, x3)
23-
split_idxs = static(cumsum((0, size(x1, 1), size(x2, 1), size(x3, 1))))
24-
shapes = static((size(x1, 1), size(x2, 1), size(x3, 1)))
23+
split_idxs = Val(cumsum((0, size(x1, 1), size(x2, 1), size(x3, 1))))
24+
shapes = Val((size(x1, 1), size(x2, 1), size(x3, 1)))
2525
2626
DEQs.split_and_reshape(x, split_idxs, shapes)
2727
```
2828
"""
29-
@generated function split_and_reshape(x::AbstractMatrix, ::T, ::S) where {T, S}
30-
idxs, shapes = known(T), known(S)
29+
@generated function split_and_reshape(x::AbstractMatrix, ::Val{idxs},
30+
::Val{shapes}) where {idxs, shapes}
3131
dims = [reshape((idxs[i] + 1):idxs[i + 1], shapes[i]...) for i in 1:(length(idxs) - 1)]
3232
varnames = [gensym("x_view") for _ in dims]
3333
calls = []
@@ -37,21 +37,3 @@ DEQs.split_and_reshape(x, split_idxs, shapes)
3737
push!(calls, :(return tuple($(varnames...))))
3838
return Expr(:block, calls...)
3939
end
40-
41-
# General Utils
42-
"""
43-
init_identity_matrix(x::AbstractArray, scale::T=1)
44-
45-
Create an identity matrix of shape `[length(x), length(x)]` and placed on the same device
46-
as `x`, and scale the matrix by `scale`.
47-
"""
48-
@inline function init_identity_matrix(x::AbstractArray{T}, scale::T=T(1)) where {T}
49-
x_ = vec(x)
50-
return _init_identity_matrix!(x_ .* x_', scale)
51-
end
52-
53-
@inline function _init_identity_matrix!(x::AbstractMatrix{T}, scale::T=T(1)) where {T}
54-
x .= zero(T)
55-
view(x, LinearAlgebra.diagind(x)) .= scale .* true
56-
return x
57-
end

test/Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1313
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1414
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
1515
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
16-
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
1716
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1817
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
1918
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/layers/mdeq.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ComponentArrays, DeepEquilibriumNetworks
1+
using DeepEquilibriumNetworks
22
using Test
33

44
include("../test_utils.jl")
@@ -170,7 +170,6 @@ function test_multiscale_neural_ode()
170170
abstol=0.01f0, reltol=0.01f0)
171171

172172
ps, st = Lux.setup(rng, model)
173-
ps = ComponentArray(ps)
174173

175174
@test st.solution === nothing
176175

test/solve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using DeepEquilibriumNetworks, SciMLBase, SteadyStateDiffEq
1+
using DeepEquilibriumNetworks, SteadyStateDiffEq
22
using Test
33

44
simple_dudt(u, p, t) = 0.9f0 .* u .- u

test/utils.jl

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using DeepEquilibriumNetworks, LinearAlgebra, Static
1+
using DeepEquilibriumNetworks, LinearAlgebra
22
using Test
33

44
include("test_utils.jl")
@@ -9,8 +9,8 @@ function test_split_and_reshape()
99
x3 = zeros(Float32, 1, 4)
1010

1111
x = vcat(x1, x2, x3)
12-
split_idxs = static(cumsum((0, size(x1, 1), size(x2, 1), size(x3, 1))))
13-
shapes = static((size(x1, 1), size(x2, 1), size(x3, 1)))
12+
split_idxs = Val(cumsum((0, size(x1, 1), size(x2, 1), size(x3, 1))))
13+
shapes = Val((size(x1, 1), size(x2, 1), size(x3, 1)))
1414
x_split = DEQs.split_and_reshape(x, split_idxs, shapes)
1515

1616
@test x1 == x_split[1]
@@ -23,17 +23,6 @@ function test_split_and_reshape()
2323
return nothing
2424
end
2525

26-
function test_init_identity_matrix()
27-
x = zeros(Float32, 5, 5, 2)
28-
imat = DEQs.init_identity_matrix(x, 0.5f0)
29-
30-
@test all(diag(imat) .== 0.5f0)
31-
return nothing
32-
end
33-
3426
@testset "split_and_reshape" begin
3527
test_split_and_reshape()
3628
end
37-
@testset "init identity matrix" begin
38-
test_init_identity_matrix()
39-
end

0 commit comments

Comments
 (0)