Skip to content

Commit f8b1878

Browse files
authored
Merge pull request #23 from invenia/wct/dict_to_vec
Dicts, Vectors, and Arrays
2 parents 574136a + fdcac8a commit f8b1878

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

src/grad.jl

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,25 @@ Transform `x` into a `Vector`, and return a closure which inverts the transforma
8282
"""
8383
to_vec(x::Real) = ([x], first)
8484

85-
# Arrays.
85+
# Vectors
8686
to_vec(x::Vector{<:Real}) = (x, identity)
87-
to_vec(x::Array) = vec(x), x_vec->reshape(x_vec, size(x))
87+
function to_vec(x::Vector)
88+
x_vecs_and_backs = map(to_vec, x)
89+
x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
90+
return vcat(x_vecs...), function(x_vec)
91+
sz = cumsum([map(length, x_vecs)...])
92+
return [backs[n](x_vec[sz[n]-length(x_vecs[n])+1:sz[n]]) for n in eachindex(x)]
93+
end
94+
end
95+
96+
# Arrays
97+
to_vec(x::Array{<:Real}) = vec(x), x_vec->reshape(x_vec, size(x))
98+
function to_vec(x::Array)
99+
x_vec, back = to_vec(reshape(x, :))
100+
return x_vec, x_vec->reshape(back(x_vec), size(x))
101+
end
88102

89-
# AbstractArrays.
103+
# AbstractArrays
90104
function to_vec(x::T) where {T<:LinearAlgebra.AbstractTriangular}
91105
x_vec, back = to_vec(Matrix(x))
92106
return x_vec, x_vec->T(reshape(back(x_vec), size(x)))
@@ -99,11 +113,22 @@ function to_vec(X::T) where T<:Union{Adjoint,Transpose}
99113
return vec(Matrix(X)), x_vec->U(permutedims(reshape(x_vec, size(X))))
100114
end
101115

102-
# Non-array data structures.
116+
# Non-array data structures
117+
103118
function to_vec(x::Tuple)
104119
x_vecs, x_backs = zip(map(to_vec, x)...)
105120
sz = cumsum([map(length, x_vecs)...])
106121
return vcat(x_vecs...), function(v)
107122
return ntuple(n->x_backs[n](v[sz[n]-length(x_vecs[n])+1:sz[n]]), length(x))
108123
end
109124
end
125+
126+
# Convert to a vector-of-vectors to make use of existing functionality.
127+
function to_vec(d::Dict)
128+
d_vec_vec = [val for val in values(d)]
129+
d_vec, back = to_vec(d_vec_vec)
130+
return d_vec, function(v)
131+
v_vec_vec = back(v)
132+
return Dict([(key, v_vec_vec[n]) for (n, key) in enumerate(keys(d))])
133+
end
134+
end

test/grad.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ Base.length(x::DummyType) = size(x.X, 1)
5858
test_to_vec(randn(5, 11))
5959
test_to_vec(randn(13, 17, 19))
6060
test_to_vec(randn(13, 0, 19))
61+
test_to_vec([1.0, randn(2), randn(1), 2.0])
62+
test_to_vec([randn(5, 4, 3), (5, 4, 3), 2.0])
63+
test_to_vec(reshape([1.0, randn(5, 4, 3), randn(4, 3), 2.0], 2, 2))
6164
test_to_vec(UpperTriangular(randn(13, 13)))
6265
test_to_vec(Symmetric(randn(11, 11)))
6366
test_to_vec(Diagonal(randn(7)))
@@ -78,6 +81,9 @@ Base.length(x::DummyType) = size(x.X, 1)
7881
test_to_vec((DummyType(randn(2, 7)), DummyType(randn(3, 9))))
7982
test_to_vec((DummyType(randn(3, 2)), randn(11, 8)))
8083
end
84+
@testset "Dictionary" begin
85+
test_to_vec(Dict(:a=>5, :b=>randn(10, 11), :c=>(5, 4, 3)))
86+
end
8187
end
8288

8389
@testset "jvp" begin

0 commit comments

Comments
 (0)