Skip to content

Commit 70847c8

Browse files
authored
Add names (#58)
1 parent 6a01bc5 commit 70847c8

File tree

1 file changed

+50
-11
lines changed

1 file changed

+50
-11
lines changed

src/grad.jl

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -118,57 +118,96 @@ j′vp(fdm, f, ȳ, xs...) = j′vp(fdm, xs->f(xs...), ȳ, xs)[1]
118118
119119
Transform `x` into a `Vector`, and return a closure which inverts the transformation.
120120
"""
121-
to_vec(x::Number) = ([x], first)
121+
function to_vec(x::Number)
122+
function Number_from_vec(x_vec)
123+
return first(x_vec)
124+
end
125+
return [x], Number_from_vec
126+
end
122127

123128
# Vectors
124129
to_vec(x::Vector{<:Number}) = (x, identity)
125130
function to_vec(x::Vector)
126131
x_vecs_and_backs = map(to_vec, x)
127132
x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
128-
return vcat(x_vecs...), function(x_vec)
133+
function Vector_from_vec(x_vec)
129134
sz = cumsum([map(length, x_vecs)...])
130135
return [backs[n](x_vec[sz[n]-length(x_vecs[n])+1:sz[n]]) for n in eachindex(x)]
131136
end
137+
return vcat(x_vecs...), Vector_from_vec
132138
end
133139

134140
# Arrays
135-
to_vec(x::Array{<:Number}) = vec(x), x_vec->reshape(x_vec, size(x))
141+
function to_vec(x::Array{<:Number})
142+
function Array_from_vec(x_vec)
143+
return reshape(x_vec, size(x))
144+
end
145+
return vec(x), Array_from_vec
146+
end
147+
136148
function to_vec(x::Array)
137149
x_vec, back = to_vec(reshape(x, :))
138-
return x_vec, x_vec->reshape(back(x_vec), size(x))
150+
function Array_from_vec(x_vec)
151+
return reshape(back(x_vec), size(x))
152+
end
153+
return x_vec, Array_from_vec
139154
end
140155

141156
# AbstractArrays
142157
function to_vec(x::T) where {T<:LinearAlgebra.AbstractTriangular}
143158
x_vec, back = to_vec(Matrix(x))
144-
return x_vec, x_vec->T(reshape(back(x_vec), size(x)))
159+
function AbstractTriangular_from_vec(x_vec)
160+
return T(reshape(back(x_vec), size(x)))
161+
end
162+
return x_vec, AbstractTriangular_from_vec
163+
end
164+
165+
function to_vec(x::Symmetric)
166+
function Symmetric_from_vec(x_vec)
167+
return Symmetric(reshape(x_vec, size(x)))
168+
end
169+
return vec(Matrix(x)), Symmetric_from_vec
170+
end
171+
172+
function to_vec(X::Diagonal)
173+
function Diagonal_from_vec(x_vec)
174+
return Diagonal(reshape(x_vec, size(X)...))
175+
end
176+
return vec(Matrix(X)), Diagonal_from_vec
145177
end
146-
to_vec(x::Symmetric) = vec(Matrix(x)), x_vec->Symmetric(reshape(x_vec, size(x)))
147-
to_vec(X::Diagonal) = vec(Matrix(X)), x_vec->Diagonal(reshape(x_vec, size(X)...))
148178

149179
function to_vec(X::Transpose)
150-
return vec(Matrix(X)), x_vec->Transpose(permutedims(reshape(x_vec, size(X))))
180+
function Transpose_from_vec(x_vec)
181+
return Transpose(permutedims(reshape(x_vec, size(X))))
182+
end
183+
return vec(Matrix(X)), Transpose_from_vec
151184
end
185+
152186
function to_vec(X::Adjoint)
153-
return vec(Matrix(X)), x_vec->Adjoint(conj!(permutedims(reshape(x_vec, size(X)))))
187+
function Adjoint_from_vec(x_vec)
188+
return Adjoint(conj!(permutedims(reshape(x_vec, size(X)))))
189+
end
190+
return vec(Matrix(X)), Adjoint_from_vec
154191
end
155192

156193
# Non-array data structures
157194

158195
function to_vec(x::Tuple)
159196
x_vecs, x_backs = zip(map(to_vec, x)...)
160197
sz = cumsum([map(length, x_vecs)...])
161-
return vcat(x_vecs...), function(v)
198+
function Tuple_from_vec(v)
162199
return ntuple(n->x_backs[n](v[sz[n]-length(x_vecs[n])+1:sz[n]]), length(x))
163200
end
201+
return vcat(x_vecs...), Tuple_from_vec
164202
end
165203

166204
# Convert to a vector-of-vectors to make use of existing functionality.
167205
function to_vec(d::Dict)
168206
d_vec_vec = [val for val in values(d)]
169207
d_vec, back = to_vec(d_vec_vec)
170-
return d_vec, function(v)
208+
function Dict_from_vec(v)
171209
v_vec_vec = back(v)
172210
return Dict([(key, v_vec_vec[n]) for (n, key) in enumerate(keys(d))])
173211
end
212+
return d_vec, Dict_from_vec
174213
end

0 commit comments

Comments
 (0)