Skip to content

Commit e773e63

Browse files
author
Avik Pal
committed
Initial version of CUDA working again
1 parent 0e2eb79 commit e773e63

File tree

7 files changed

+35
-19
lines changed

7 files changed

+35
-19
lines changed

Manifest.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ version = "0.2.0"
6161

6262
[[deps.ArrayInterface]]
6363
deps = ["Adapt", "LinearAlgebra", "Requires", "SparseArrays", "SuiteSparse"]
64-
git-tree-sha1 = "247efbccf92448be332d154d6ca56b9fcdd93c31"
64+
git-tree-sha1 = "bbec08a37f8722786d87bedf84eae19c020c4efa"
6565
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
66-
version = "7.6.1"
66+
version = "7.7.0"
6767

6868
[deps.ArrayInterface.extensions]
6969
ArrayInterfaceBandedMatricesExt = "BandedMatrices"
@@ -696,7 +696,7 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
696696

697697
[[deps.LinearSolve]]
698698
deps = ["ArrayInterface", "ConcreteStructs", "DocStringExtensions", "EnumX", "FastLapackInterface", "GPUArraysCore", "InteractiveUtils", "KLU", "Krylov", "Libdl", "LinearAlgebra", "MKL_jll", "PrecompileTools", "Preferences", "RecursiveFactorization", "Reexport", "Requires", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Sparspak", "StaticArraysCore", "UnPack"]
699-
git-tree-sha1 = "7c9f62dc6f7d11b0f8dfb3c1f3637cb220dc9866"
699+
git-tree-sha1 = "b1148a6596c5cd9b2d848c26b500c79d102ffc5d"
700700
repo-rev = "ap/normal_cholesky_dispatches"
701701
repo-url = "https://github.com/SciML/LinearSolve.jl.git"
702702
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
@@ -1248,7 +1248,7 @@ version = "0.3.7"
12481248

12491249
[[deps.SciMLSensitivity]]
12501250
deps = ["ADTypes", "Adapt", "ArrayInterface", "ChainRulesCore", "DiffEqBase", "DiffEqCallbacks", "DiffEqNoiseProcess", "Distributions", "EllipsisNotation", "Enzyme", "FiniteDiff", "ForwardDiff", "FunctionProperties", "FunctionWrappersWrappers", "Functors", "GPUArraysCore", "LinearAlgebra", "LinearSolve", "Markdown", "OrdinaryDiffEq", "Parameters", "PreallocationTools", "QuadGK", "Random", "RandomNumbers", "RecursiveArrayTools", "Reexport", "ReverseDiff", "SciMLBase", "SciMLOperators", "SparseDiffTools", "StaticArrays", "StaticArraysCore", "Statistics", "StochasticDiffEq", "Tracker", "TruncatedStacktraces", "Zygote"]
1251-
git-tree-sha1 = "71efe41945d384c3b4873ba681ef8905cae79123"
1251+
git-tree-sha1 = "4120f3ef35508d2f9d967f40364765da4950cefe"
12521252
repo-rev = "ap/ssadjoint_fix"
12531253
repo-url = "https://github.com/SciML/SciMLSensitivity.jl.git"
12541254
uuid = "1ed8b502-d754-442c-8d5d-10ac956f44a1"

docs/src/index.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@ ps, st = Lux.setup(rng, model) |> gdev
3535
x = rand(rng, Float32, 2, 3) |> gdev
3636
y = rand(rng, Float32, 2, 3) |> gdev
3737
38-
y, st_ = model(x, ps, st)
38+
res, st_ = model(x, ps, st)
3939
st_.layer_2.solution
4040
```
4141

4242
```@example quickstart
43-
gs = only(Zygote.gradient(p -> sum(abs2, first(first(model(x, p, st))) .- y), ps))
43+
gs = only(Zygote.gradient(p -> sum(abs2, first(model(x, p, st)) .- y), ps))
4444
```
4545

4646
## Citation

src/layers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ julia> model(x, ps, st);
308308
function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
309309
post_fuse_layer::Union{Nothing, Tuple}, solver, scales;
310310
jacobian_regularization=nothing, kwargs...)
311-
@assert jacobian_regularization === nothing "Jacobian Regularization is not supported yet for MultiScale Models."
311+
@assert jacobian_regularization===nothing "Jacobian Regularization is not supported yet for MultiScale Models."
312312
l1 = Parallel(nothing, main_layers...)
313313
l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)
314314

test/layers.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ function loss_function(model, x, ps, st)
1111
return l1 + l2 + l3
1212
end
1313

14-
@testset "DeepEquilibriumNetwork" begin
14+
@testset "DeepEquilibriumNetwork: $(mode)" for (mode, aType, dev, ongpu) in MODES
1515
rng = __get_prng(0)
1616

1717
base_models = [
@@ -41,10 +41,10 @@ end
4141
SkipDeepEquilibriumNetwork(base_model, solver; jacobian_regularization)
4242
end
4343

44-
ps, st = Lux.setup(rng, model)
44+
ps, st = Lux.setup(rng, model) |> dev
4545
@test st.solution == DeepEquilibriumSolution()
4646

47-
x = randn(rng, Float32, x_size...)
47+
x = randn(rng, Float32, x_size...) |> dev
4848
z, st = model(x, ps, st)
4949

5050
opt_broken = solver isa NewtonRaphson ||
@@ -62,7 +62,7 @@ end
6262
@test __is_finite_gradient(gs_x)
6363
@test __is_finite_gradient(gs_ps)
6464

65-
ps, st = Lux.setup(rng, model)
65+
ps, st = Lux.setup(rng, model) |> dev
6666
st = Lux.update_state(st, :fixed_depth, Val(10))
6767
@test st.solution == DeepEquilibriumSolution()
6868

@@ -82,7 +82,7 @@ end
8282
end
8383
end
8484

85-
@testset "MultiScaleDeepEquilibriumNetworks" begin
85+
@testset "MultiScaleDeepEquilibriumNetwork: $(mode)" for (mode, aType, dev, ongpu) in MODES
8686
rng = __get_prng(0)
8787

8888
main_layers = [

test/runtests.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
using SafeTestsets, Test
22

3-
# TODO: CUDA Testing
4-
const GROUP = get(ENV, "GROUP", "ALL")
5-
63
@testset "Deep Equilibrium Networks" begin
74
@safetestset "Quality Assurance" begin
85
include("qa.jl")

test/test_utils.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using DeepEquilibriumNetworks, Functors, Lux, Random, StableRNGs, Zygote
22
import LuxTestUtils: @jet
3+
using LuxCUDA
34

45
__nameof(::X) where {X} = nameof(X)
56

@@ -26,3 +27,21 @@ function __get_conv_layer(args...; kwargs...)
2627
init_weight(rng::AbstractRNG, dims...) = randn(rng, Float32, dims) .* 0.001f0
2728
return Conv(args...; init_weight, use_bias=false, kwargs...)
2829
end
30+
31+
const GROUP = get(ENV, "GROUP", "All")
32+
33+
cpu_testing() = GROUP == "All" || GROUP == "CPU"
34+
cuda_testing() = LuxCUDA.functional() && (GROUP == "All" || GROUP == "CUDA")
35+
36+
if !@isdefined(MODES)
37+
const MODES = begin
38+
cpu_mode = ("CPU", Array, LuxCPUDevice(), false)
39+
cuda_mode = ("CUDA", CuArray, LuxCUDADevice(), true)
40+
41+
modes = []
42+
cpu_testing() && push!(modes, cpu_mode)
43+
cuda_testing() && push!(modes, cuda_mode)
44+
45+
modes
46+
end
47+
end

test/utils.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ using DeepEquilibriumNetworks, LinearAlgebra, SciMLBase, Test
22

33
include("test_utils.jl")
44

5-
@testset "split_and_reshape" begin
6-
x1 = ones(Float32, 4, 4)
7-
x2 = fill(0.5f0, 2, 4)
8-
x3 = zeros(Float32, 1, 4)
5+
@testset "split_and_reshape: $mode" for (mode, aType, dev, ongpu) in MODES
6+
x1 = ones(Float32, 4, 4) |> aType
7+
x2 = fill(0.5f0, 2, 4) |> aType
8+
x3 = zeros(Float32, 1, 4) |> aType
99

1010
x = vcat(x1, x2, x3)
1111
split_idxs = Val(cumsum((0, size(x1, 1), size(x2, 1), size(x3, 1))))

0 commit comments

Comments
 (0)