Skip to content

Commit 3700773

Browse files
author
Avik Pal
committed
Incorrect trace
1 parent e773e63 commit 3700773

File tree

4 files changed

+13
-9
lines changed

4 files changed

+13
-9
lines changed

Manifest.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,9 +271,9 @@ version = "6.144.0"
271271

272272
[[deps.DiffEqCallbacks]]
273273
deps = ["DataStructures", "DiffEqBase", "ForwardDiff", "Functors", "LinearAlgebra", "Markdown", "NLsolve", "Parameters", "RecipesBase", "RecursiveArrayTools", "SciMLBase", "StaticArraysCore"]
274-
git-tree-sha1 = "d0b94b3694d55e7eedeee918e7daee9e3b873399"
274+
git-tree-sha1 = "e48b985459d1cbe8c809de192529f1e25c3382a6"
275275
uuid = "459566f4-90b8-5000-8ac3-15dfb0a30def"
276-
version = "2.35.0"
276+
version = "2.36.0"
277277

278278
[deps.DiffEqCallbacks.weakdeps]
279279
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"

src/utils.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
::Val{shapes}) where {idxs, shapes}
33
dims = [reshape((idxs[i] + 1):idxs[i + 1], shapes[i]...) for i in 1:(length(idxs) - 1)]
44
varnames = map(_ -> gensym("x_view"), dims)
5-
calls = [:($(varnames[i]) = view(x, $(dims[i]), :)) for i in 1:length(dims)]
5+
calls = [:($(varnames[i]) = x[$(dims[i]), :]) for i in 1:length(dims)]
66
return quote
77
$(calls...)
88
return tuple($(varnames...))
@@ -14,7 +14,8 @@ __split_and_reshape(x::AbstractArray, ::Nothing, ::Nothing) = x
1414
function __split_and_reshape(y::AbstractMatrix, x)
1515
szs = [prod(size(xᵢ)[1:(end - 1)]) for xᵢ in x]
1616
counters = vcat(0, cumsum(szs)[1:(end - 1)])
17-
return map((sz, c, xᵢ) -> reshape(view(y, (c + 1):(c + sz), :), size(xᵢ)),
17+
# Make the data contiguous
18+
return map((sz, c, xᵢ) -> copy(reshape(view(y, (c + 1):(c + sz), :), size(xᵢ))),
1819
szs, counters, x)
1920
end
2021

@@ -95,7 +96,7 @@ CRC.@non_differentiable __gaussian_like(::Any...)
9596

9697
# Jacobian Stabilization
9798
function __estimate_jacobian_trace(::AutoFiniteDiff, model, ps, z, x, rng)
98-
__f = u -> first(model((u, x), ps))
99+
__f = u -> model((u, x), ps)
99100
res = zero(eltype(x))
100101
ϵ = cbrt(eps(typeof(res)))
101102
ϵ⁻¹ = inv(ϵ)

test/layers.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ end
2323

2424
model_type = (:deq, :skipdeq, :skipregdeq)
2525
solvers = (VCAB3(), Tsit5(), NewtonRaphson(), SimpleLimitedMemoryBroyden())
26-
jacobian_regularizations = (nothing, AutoFiniteDiff(), AutoZygote())
26+
jacobian_regularizations = Any[nothing, AutoZygote()]
27+
!ongpu && push!(jacobian_regularizations, AutoFiniteDiff())
2728

2829
@testset "Solver: $(__nameof(solver))" for solver in solvers,
2930
mtype in model_type, jacobian_regularization in jacobian_regularizations
@@ -133,10 +134,10 @@ end
133134
jacobian_regularization)
134135
end
135136

136-
ps, st = Lux.setup(rng, model)
137+
ps, st = Lux.setup(rng, model) |> dev
137138
@test st.solution == DeepEquilibriumSolution()
138139

139-
x = randn(rng, Float32, x_size...)
140+
x = randn(rng, Float32, x_size...) |> dev
140141
z, st = model(x, ps, st)
141142
z_ = DEQs.__flatten_vcat(z)
142143

@@ -157,7 +158,7 @@ end
157158
@test __is_finite_gradient(gs_x)
158159
@test __is_finite_gradient(gs_ps)
159160

160-
ps, st = Lux.setup(rng, model)
161+
ps, st = Lux.setup(rng, model) |> dev
161162
st = Lux.update_state(st, :fixed_depth, Val(10))
162163
@test st.solution == DeepEquilibriumSolution()
163164

test/test_utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ using DeepEquilibriumNetworks, Functors, Lux, Random, StableRNGs, Zygote
22
import LuxTestUtils: @jet
33
using LuxCUDA
44

5+
CUDA.allowscalar(false)
6+
57
__nameof(::X) where {X} = nameof(X)
68

79
__get_prng(seed::Int) = StableRNG(seed)

0 commit comments

Comments
 (0)