Skip to content

Commit 103bd6e

Browse files
authored
Add sizes to tape (#7)
* Add sizes computation * Fixes * Fixes * Fix tests * Fix format
1 parent 4584737 commit 103bd6e

File tree

6 files changed

+294
-18
lines changed

6 files changed

+294
-18
lines changed

src/ArrayDiff.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ import NaNMath:
5050

5151
include("Coloring/Coloring.jl")
5252
include("graph_tools.jl")
53+
include("sizes.jl")
5354
include("types.jl")
5455
include("utils.jl")
5556

src/reverse_mode.jl

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -109,44 +109,51 @@ function _forward_eval(
109109
# children, so a backwards pass through f.nodes is a forward pass through
110110
# the tree.
111111
children_arr = SparseArrays.rowvals(f.adj)
112+
fill!(f.partials_storage, zero(T))
112113
for k in length(f.nodes):-1:1
113114
node = f.nodes[k]
114-
f.partials_storage[k] = zero(T)
115+
# Storage index if scalar
116+
j = last(_storage_range(f.sizes, k))
115117
if node.type == Nonlinear.NODE_VARIABLE
116-
f.forward_storage[k] = x[node.index]
118+
f.forward_storage[j] = x[node.index]
117119
# This should never happen, because we will have replaced these by now.
118120
# elseif node.type == Nonlinear.NODE_MOI_VARIABLE
119121
# f.forward_storage[k] = x[node.index]
120122
elseif node.type == Nonlinear.NODE_VALUE
121-
f.forward_storage[k] = f.const_values[node.index]
123+
f.forward_storage[j] = f.const_values[node.index]
122124
elseif node.type == Nonlinear.NODE_SUBEXPRESSION
123-
f.forward_storage[k] = d.subexpression_forward_values[node.index]
125+
f.forward_storage[j] = d.subexpression_forward_values[node.index]
124126
elseif node.type == Nonlinear.NODE_PARAMETER
125-
f.forward_storage[k] = d.data.parameters[node.index]
127+
f.forward_storage[j] = d.data.parameters[node.index]
126128
elseif node.type == Nonlinear.NODE_CALL_MULTIVARIATE
127129
children_indices = SparseArrays.nzrange(f.adj, k)
128130
N = length(children_indices)
129131
# TODO(odow);
130132
# With appropriate benchmarking, the special-cased if-statements can
131133
# be removed in favor of the generic user-defined function case.
132134
if node.index == 1 # :+
133-
tmp_sum = zero(T)
134-
for c_idx in children_indices
135-
@inbounds ix = children_arr[c_idx]
136-
@inbounds f.partials_storage[ix] = one(T)
137-
@inbounds tmp_sum += f.forward_storage[ix]
135+
for j in _eachindex(f.sizes, k)
136+
tmp_sum = zero(T)
137+
for c_idx in children_indices
138+
ix = children_arr[c_idx]
139+
_setindex!(f.partials_storage, one(T), f.sizes, ix, j)
140+
f.partials_storage[ix] = one(T)
141+
tmp_sum += _getindex(f.forward_storage, f.sizes, ix, j)
142+
end
143+
_setindex!(f.forward_storage, tmp_sum, f.sizes, k, j)
138144
end
139-
f.forward_storage[k] = tmp_sum
140145
elseif node.index == 2 # :-
141146
@assert N == 2
142147
child1 = first(children_indices)
143148
@inbounds ix1 = children_arr[child1]
144149
@inbounds ix2 = children_arr[child1+1]
145-
@inbounds tmp_sub = f.forward_storage[ix1]
146-
@inbounds tmp_sub -= f.forward_storage[ix2]
147-
@inbounds f.partials_storage[ix1] = one(T)
148-
@inbounds f.partials_storage[ix2] = -one(T)
149-
f.forward_storage[k] = tmp_sub
150+
for j in _eachindex(f.sizes, k)
151+
tmp_sub = _getindex(f.forward_storage, f.sizes, ix1, j)
152+
tmp_sub -= _getindex(f.forward_storage, f.sizes, ix2, j)
153+
_setindex!(f.partials_storage, one(T), f.sizes, ix1, j)
154+
_setindex!(f.partials_storage, -one(T), f.sizes, ix2, j)
155+
_setindex!(f.forward_storage, tmp_sub, f.sizes, k, j)
156+
end
150157
elseif node.index == 3 # :*
151158
tmp_prod = one(T)
152159
for c_idx in children_indices
@@ -221,7 +228,7 @@ function _forward_eval(
221228
@inbounds f.partials_storage[children_arr[idx1+2]] =
222229
!(condition == 1)
223230
f.forward_storage[k] = ifelse(condition == 1, lhs, rhs)
224-
else
231+
else # atan, min, max or vect
225232
f_input = _UnsafeVectorView(d.jac_storage, N)
226233
∇f = _UnsafeVectorView(d.user_output_buffer, N)
227234
for (r, i) in enumerate(children_indices)

src/sizes.jl

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
"""
2+
struct Sizes
3+
ndims::Vector{Int}
4+
size_offset::Vector{Int}
5+
size::Vector{Int}
6+
storage_offset::Vector{Int}
7+
end
8+
9+
The node at index `k` is an array of `ndims[k]` dimensions and size `sizes[size_offset[k] .+ (1:ndims[k])]`.
10+
Note that `size_offset` is a nonincreasing vector so that `sizes` can be filled in a forward pass,
11+
which goes through the nodes in decreasing index order.
12+
"""
13+
struct Sizes
14+
ndims::Vector{Int}
15+
size_offset::Vector{Int}
16+
size::Vector{Int}
17+
storage_offset::Vector{Int}
18+
end
19+
20+
function _size(sizes::Sizes, k::Int, dim::Int)
21+
return sizes.size[sizes.size_offset[k]+dim]
22+
end
23+
24+
function _size(sizes::Sizes, k::Int)
25+
return view(sizes.size, sizes.size_offset[k] .+ Base.OneTo(sizes.ndims[k]))
26+
end
27+
28+
function _length(sizes::Sizes, k::Int)
29+
if sizes.ndims[k] == 0
30+
return 1
31+
else
32+
return prod(_size(sizes, k))
33+
end
34+
end
35+
36+
_eachindex(sizes::Sizes, k) = Base.OneTo(_length(sizes, k))
37+
38+
_length(sizes::Sizes) = sizes.storage_offset[end]
39+
40+
function _storage_range(sizes::Sizes, k::Int)
41+
return sizes.storage_offset[k] .+ _eachindex(sizes, k)
42+
end
43+
44+
function _getindex(x, sizes::Sizes, k::Int, j)
45+
return x[sizes.storage_offset[k]+j]
46+
end
47+
48+
function _setindex!(x, value, sizes::Sizes, k::Int, j)
49+
return x[sizes.storage_offset[k]+j] = value
50+
end
51+
52+
# /!\ Can only be called in decreasing `k` order
53+
function _add_size!(sizes::Sizes, k::Int, size::Tuple)
54+
sizes.ndims[k] = length(size)
55+
sizes.size_offset[k] = length(sizes.size)
56+
append!(sizes.size, size)
57+
return
58+
end
59+
60+
function _copy_size!(sizes::Sizes, k::Int, child::Int)
61+
sizes.ndims[k] = sizes.ndims[child]
62+
sizes.size_offset[k] = length(sizes.size)
63+
for i in (sizes.size_offset[child] .+ Base.OneTo(sizes.ndims[child]))
64+
push!(sizes.size, sizes.size[i])
65+
end
66+
return
67+
end
68+
69+
function _assert_scalar_children(sizes, children_arr, children_indices, op)
70+
for c_idx in children_indices
71+
@inbounds ix = children_arr[c_idx]
72+
# We don't support nested vectors of vectors,
73+
# we only support real numbers and array of real numbers
74+
@assert sizes.ndims[ix] == 0 "Array argument when expected scalar argument for operator `$op`"
75+
end
76+
end
77+
78+
function _infer_sizes(
79+
nodes::Vector{Nonlinear.Node},
80+
adj::SparseArrays.SparseMatrixCSC{Bool,Int},
81+
)
82+
sizes = Sizes(
83+
zeros(Int, length(nodes)),
84+
zeros(Int, length(nodes)),
85+
Int[],
86+
zeros(Int, length(nodes) + 1),
87+
)
88+
children_arr = SparseArrays.rowvals(adj)
89+
for k in length(nodes):-1:1
90+
node = nodes[k]
91+
children_indices = SparseArrays.nzrange(adj, k)
92+
N = length(children_indices)
93+
if node.type == Nonlinear.NODE_CALL_MULTIVARIATE
94+
if !(
95+
node.index in
96+
eachindex(MOI.Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS)
97+
)
98+
# TODO user-defined operators
99+
continue
100+
end
101+
op = MOI.Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS[node.index]
102+
if op == :vect
103+
_assert_scalar_children(
104+
sizes,
105+
children_arr,
106+
children_indices,
107+
op,
108+
)
109+
_add_size!(sizes, k, (N,))
110+
elseif op == :dot
111+
# TODO assert all arguments have same size
112+
elseif op == :+ || op == :-
113+
# TODO assert all arguments have same size
114+
_copy_size!(sizes, k, children_arr[first(children_indices)])
115+
elseif op == :*
116+
# TODO assert compatible sizes and all ndims should be 0 or 2
117+
first_matrix = findfirst(children_indices) do i
118+
return !iszero(sizes.ndims[children_arr[i]])
119+
end
120+
if !isnothing(first_matrix)
121+
last_matrix = findfirst(children_indices) do i
122+
return !iszero(sizes.ndims[children_arr[i]])
123+
end
124+
_add_size!(
125+
sizes,
126+
k,
127+
(
128+
_size(sizes, first_matrix, 1),
129+
_size(sizes, last_matrix, sizes.ndims[last_matrix]),
130+
),
131+
)
132+
end
133+
elseif op == :^ || op == :/
134+
@assert N == 2
135+
_assert_scalar_children(
136+
sizes,
137+
children_arr,
138+
children_indices[2:end],
139+
op,
140+
)
141+
_copy_size!(sizes, k, children_arr[first(children_indices)])
142+
else
143+
_assert_scalar_children(
144+
sizes,
145+
children_arr,
146+
children_indices,
147+
op,
148+
)
149+
end
150+
elseif node.type == Nonlinear.NODE_CALL_UNIVARIATE
151+
if !(
152+
node.index in
153+
eachindex(MOI.Nonlinear.DEFAULT_UNIVARIATE_OPERATORS)
154+
)
155+
# TODO user-defined operators
156+
continue
157+
end
158+
@assert N == 1
159+
op = MOI.Nonlinear.DEFAULT_UNIVARIATE_OPERATORS[node.index]
160+
if op == :+ || op == :-
161+
_copy_size!(sizes, k, children_arr[first(children_indices)])
162+
else
163+
_assert_scalar_children(
164+
sizes,
165+
children_arr,
166+
children_indices,
167+
op,
168+
)
169+
end
170+
end
171+
end
172+
for k in eachindex(nodes)
173+
sizes.storage_offset[k+1] = sizes.storage_offset[k] + _length(sizes, k)
174+
end
175+
return sizes
176+
end

src/types.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
struct _SubexpressionStorage
88
nodes::Vector{Nonlinear.Node}
99
adj::SparseArrays.SparseMatrixCSC{Bool,Int}
10+
sizes::Sizes
1011
const_values::Vector{Float64}
1112
forward_storage::Vector{Float64}
1213
partials_storage::Vector{Float64}
@@ -21,10 +22,12 @@ struct _SubexpressionStorage
2122
partials_storage_ϵ::Vector{Float64},
2223
linearity::Linearity,
2324
)
24-
N = length(nodes)
25+
sizes = _infer_sizes(nodes, adj)
26+
N = _length(sizes)
2527
return new(
2628
nodes,
2729
adj,
30+
_infer_sizes(nodes, adj),
2831
const_values,
2932
zeros(N), # forward_storage,
3033
zeros(N), # partials_storage,

test/ArrayDiff.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
module TestArrayDiff
2+
3+
using Test
4+
import LinearAlgebra
5+
import SparseArrays
6+
7+
import MathOptInterface as MOI
8+
const Nonlinear = MOI.Nonlinear
9+
10+
import ArrayDiff
11+
const Coloring = ArrayDiff.Coloring
12+
13+
function runtests()
14+
for name in names(@__MODULE__; all = true)
15+
if startswith("$(name)", "test_")
16+
@testset "$(name)" begin
17+
getfield(@__MODULE__, name)()
18+
end
19+
end
20+
end
21+
return
22+
end
23+
24+
function test_objective_quadratic_univariate()
25+
x = MOI.VariableIndex(1)
26+
scalar = Nonlinear.Model()
27+
Nonlinear.set_objective(model, :($x * $x))
28+
vector = Nonlinear.Model()
29+
Nonlinear.set_objective(vector, :([$x] * [$x]))
30+
evaluator = Nonlinear.Evaluator(model, ArrayDiff.Mode(), [x])
31+
MOI.initialize(evaluator, [:Grad, :Jac, :Hess])
32+
@test MOI.eval_objective(evaluator, [1.2]) == 1.2^2 + 1
33+
g = [NaN]
34+
MOI.eval_objective_gradient(evaluator, g, [1.2])
35+
@test g == [2.4]
36+
@test MOI.hessian_objective_structure(evaluator) == [(1, 1)]
37+
H = [NaN]
38+
MOI.eval_hessian_objective(evaluator, H, [1.2])
39+
@test H == [2.0]
40+
@test MOI.hessian_lagrangian_structure(evaluator) == [(1, 1)]
41+
H = [NaN]
42+
MOI.eval_hessian_lagrangian(evaluator, H, [1.2], 1.5, Float64[])
43+
@test H == 1.5 .* [2.0]
44+
MOI.eval_hessian_lagrangian_product(
45+
evaluator,
46+
H,
47+
[1.2],
48+
[1.2],
49+
1.5,
50+
Float64[],
51+
)
52+
@test H == [1.5 * 2.0 * 1.2]
53+
return
54+
end
55+
56+
end # module
57+
58+
TestArrayDiff.runtests()
59+
60+
import MathOptInterface as MOI
61+
import ArrayDiff
62+
using Test
63+
const Nonlinear = MOI.Nonlinear
64+
model = Nonlinear.Model()
65+
x = MOI.VariableIndex(1)
66+
Nonlinear.set_objective(model, :(dot([$x], [$x])))
67+
evaluator = Nonlinear.Evaluator(model, ArrayDiff.Mode(), [x])
68+
MOI.initialize(evaluator, [:Grad])
69+
sizes = evaluator.backend.objective.expr.sizes
70+
@test MOI.eval_objective(evaluator, [1.2]) == 1.2^2
71+
@test sizes.ndims == [0, 1, 0, 1, 0]
72+
@test sizes.size_offset == [0, 1, 0, 0, 0]
73+
@test sizes.size == [1, 1]
74+
@test sizes.storage_offset == [0, 1, 2, 3, 4, 5]
75+
76+
y = MOI.VariableIndex(1)
77+
Nonlinear.set_objective(model, :(dot([$x, $y] - [1, 2], -[1, 2] + [$x, $y])))
78+
MOI.initialize(evaluator, [:Grad])
79+
sizes = evaluator.backend.objective.expr.sizes
80+
@test MOI.eval_objective(evaluator, [1.2]) == 1.2^2
81+
@test sizes.ndims == [0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0]
82+
@test sizes.size_offset == [0, 6, 5, 0, 0, 4, 0, 0, 3, 2, 1, 0, 0, 0, 0, 0]
83+
@test sizes.size == [2, 2, 2, 2, 2, 2, 2]
84+
@test sizes.storage_offset ==
85+
[0, 1, 3, 5, 6, 7, 9, 10, 11, 13, 15, 17, 18, 19, 21, 22, 23]
86+
g = [NaN]
87+
MOI.eval_objective_gradient(evaluator, g, [1.2])

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
include("ReverseAD.jl")
2+
3+
import SparseArrays

0 commit comments

Comments
 (0)