Skip to content

Commit c33c2fa

Browse files
committed
Fix the projection
1 parent e7b50f5 commit c33c2fa

File tree

3 files changed

+6
-9
lines changed

3 files changed

+6
-9
lines changed

src/utils.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,10 @@ end
2626

2727
function CRC.rrule(::typeof(__flatten_vcat), x)
2828
y = __flatten_vcat(x)
29-
projects = CRC.ProjectTo.(x)
29+
project_x = CRC.ProjectTo(x)
3030
function ∇__flatten_vcat(∂y)
3131
∂y isa CRC.NoTangent && return (CRC.NoTangent(), CRC.NoTangent())
32-
∂x = __split_and_reshape(∂y, x)
33-
∂x = map((∂xᵢ, project) -> project(∂xᵢ), ∂x, projects)
34-
return CRC.NoTangent(), ∂x
32+
return CRC.NoTangent(), project_x(__split_and_reshape(∂y, x))
3533
end
3634
return y, ∇__flatten_vcat
3735
end

test/Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
66
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
9+
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
910
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
1011
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
1112
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
@@ -21,4 +22,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2122
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2223

2324
[compat]
24-
Aqua = "0.8"
25+
Aqua = "0.8"

test/layers.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,8 @@ end
111111
jacobian_regularizations = (nothing, AutoFiniteDiff(), AutoZygote())
112112

113113
for mtype in model_type, jacobian_regularization in jacobian_regularizations
114-
# @testset "Solver: $(__nameof(solver))"
115-
for solver in solvers
116-
# @testset "x_size: $(x_size)"
117-
for (main_layer, mapping_layer, init_layer, x_size, scale) in zip(main_layers,
114+
@testset "Solver: $(__nameof(solver))" for solver in solvers
115+
@testset "x_size: $(x_size)" for (main_layer, mapping_layer, init_layer, x_size, scale) in zip(main_layers,
118116
mapping_layers, init_layers, x_sizes, scales)
119117
@info solver, mtype, jacobian_regularization, main_layer, mapping_layer,
120118
init_layer, x_size, scale

0 commit comments

Comments
 (0)