This repository was archived by the owner on Dec 18, 2023. It is now read-only.

Commit b542555

Small clarifications and improvements (#1611)
Summary: Pull Request resolved: #1611

Some small improvements and clarifications (mostly to documentation) made while working on something else; committing them separately.

Reviewed By: gafter
Differential Revision: D36923195
fbshipit-source-id: 515951ca5b9f33626bd5d135ed45dcf7d8d5002f
Parent: 2c80b0a

File tree

7 files changed: +73 additions, -34 deletions


src/beanmachine/graph/distribution/distribution.h

Lines changed: 9 additions & 13 deletions
@@ -112,11 +112,13 @@ class Distribution : public graph::Node {
   the log prob of the distribution w.r.t. the sampled value.
   :param value: value of the child Sample operator, a single draw from the
   distribution
-  :param back_grad: back_grad1 of the child Sample operator, to be incremented
-  :param adjunct: a multiplier that represents the gradient of the target
-  function w.r.t the log prob of this distribution. It uses the default value
-  1.0 if the direct child is a StochasticOperator, but requires input if the
-  direct child is a mixture distribution.
+  :param back_grad: variable to which the gradient will be added.
+  :param adjunct: if we are interested in df(log_prob)/dvalue, then
+  adjunct must be df(log_prob)/dlog_prob.
+  If we are interested in dlog_prob/dvalue then the adjunct is 1,
+  which is the default.
+  For other cases (such as this distribution being a component of a
+  mixture distribution), the appropriate adjunct must be provided.
   */
   virtual void backward_value(
       const graph::NodeValue& /* value */,
@@ -134,14 +136,8 @@ class Distribution : public graph::Node {
       graph::DoubleMatrix& /* back_grad */,
       Eigen::MatrixXd& /* adjunct */) const {}
   /*
-  In backward gradient propagation, increments the back_grad1 of each parent
-  node w.r.t. the log prob of the distribution, evaluated at the given value.
-  :param value: value of the child Sample operator, a single draw from the
-  distribution
-  :param adjunct: a multiplier that represents the gradient of the
-  target function w.r.t the log prob of this distribution. It uses the default
-  value 1.0 if the direct child is a StochasticOperator, but requires input if
-  the direct child is a mixture distribution.
+  Analogous to backward_value, but computes the gradient
+  wrt each parameter, and adds results to their back_grad1 field.
   */
   virtual void backward_param(
       const graph::NodeValue& /* value */,
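
To make the new adjunct documentation concrete: backward_value accumulates adjunct * d(log_prob)/d(value) into back_grad. Below is a minimal scalar sketch for a hypothetical Normal(mu, sigma) distribution; the function name and double-based signature are illustrative only, not the actual Bean Machine API (which works with NodeValue and DoubleMatrix):

// Illustrative sketch, not Bean Machine code.
void normal_backward_value(
    double value,
    double mu,
    double sigma,
    double& back_grad,
    double adjunct = 1.0) {
  // d log N(value; mu, sigma) / d value = -(value - mu) / sigma^2
  back_grad += adjunct * (-(value - mu) / (sigma * sigma));
}

With the default adjunct = 1 this accumulates dlog_prob/dvalue directly; when some function f of the log prob is the real target (e.g. a mixture), the caller passes adjunct = df/dlog_prob, so that what gets added is df/dvalue = df/dlog_prob * dlog_prob/dvalue by the chain rule.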

src/beanmachine/graph/graph.cpp

Lines changed: 3 additions & 0 deletions
@@ -23,6 +23,9 @@
 namespace beanmachine {
 namespace graph {

+NATURAL_TYPE NATURAL_ZERO = 0ull;
+NATURAL_TYPE NATURAL_ONE = 1ull;
+
 std::string ValueType::to_string() const {
   std::string vtype;
   std::string atype;

src/beanmachine/graph/graph.h

Lines changed: 17 additions & 3 deletions
@@ -124,6 +124,9 @@ struct ValueType {

 typedef NATURAL_TYPE natural_t;

+extern NATURAL_TYPE NATURAL_ZERO;
+extern NATURAL_TYPE NATURAL_ONE;
+
 class NodeValue {
  public:
   ValueType type;
@@ -374,6 +377,9 @@ enum class DistributionType {
   LKJ_CHOLESKY
 };

+// TODO: do we really need DistributionType? Can't we know the type of a
+// Distribution from its class alone?
+
 enum class FactorType {
   UNKNOWN,
   EXP_PRODUCT,
@@ -449,9 +455,17 @@ class Node {
   virtual bool needs_gradient() const {
     return true;
   }
-  // gradient_log_prob is also only valid for stochastic nodes
-  // TODO: shouldn't we then restrict them to those classes? See above.
-  // this function adds the gradients to the passed in gradients
+  // gradient_log_prob is also only valid for stochastic nodes.
+  // (TODO: shouldn't we then restrict them to those classes? See above.)
+  // It computes the first and second gradients of the log prob
+  // of this node with respect to a given target node and
+  // adds them to the passed-in gradient parameters.
+  // Note that for this computation to be correct,
+  // gradients (the grad1 and grad2 properties of nodes)
+  // must have been updated all the way from the
+  // target node to this node.
+  // This is because this method only performs a local computation
+  // and relies on the grad1 and grad2 attributes of nodes.
   virtual void gradient_log_prob(
       const graph::Node* target_node,
       double& /* grad1 */,
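
The locality caveat in the new comment can be seen in a small worked example. Suppose this stochastic node is x ~ Normal(mu, sigma) with fixed sigma, and both x and mu depend on a target node t; the chain rule then gives the local update sketched below. The Grad struct and function are a hypothetical sketch, not the actual Node API:

// Hypothetical sketch of a local gradient_log_prob computation.
struct Grad {
  double grad1; // d(node value)/dt, assumed already propagated
  double grad2; // d^2(node value)/dt^2, assumed already propagated
};

// log p(x) = -0.5 * ((x - mu) / sigma)^2 + const, with sigma fixed
void normal_gradient_log_prob(
    double x, Grad gx,   // sampled value and its gradients w.r.t. t
    double mu, Grad gmu, // mean and its gradients w.r.t. t
    double sigma,
    double& grad1, double& grad2) {
  double z = (x - mu) / (sigma * sigma); // dlogp/dmu; dlogp/dx = -z
  // first-order chain rule: dlogp/dt = dlogp/dx * dx/dt + dlogp/dmu * dmu/dt
  grad1 += -z * gx.grad1 + z * gmu.grad1;
  // second-order chain rule, using the inputs' grad2 as well
  double dz = (gx.grad1 - gmu.grad1) / (sigma * sigma); // dz/dt
  grad2 += -dz * gx.grad1 - z * gx.grad2 + dz * gmu.grad1 + z * gmu.grad2;
}

If gx.grad1 or gmu.grad1 is stale, grad1 and grad2 come out wrong even though every local formula is correct, which is exactly why the comment requires grad1/grad2 to have been updated all the way from the target node to this node.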

src/beanmachine/graph/marginalization/marginalized_graph.cpp

Lines changed: 12 additions & 9 deletions
@@ -63,13 +63,13 @@ all of the nodes required to compute the MarginalDistribution
 4. the stochastic children nodes of the discrete sample
 5. the parents (a node not in 1-4 that has a child in 1-4)

-The original graph will contain
+The original graph will be modified to contain
 1. the MarginalDistribution node (to replace #1-3 from the
 subgraph above)
-2. the children of the MarginalDistribution are the
+2. the children of the MarginalDistribution, which are the
 stochastic children nodes of the discrete node
 (the same as #4 from the subgraph)
-3. the parents of the MarginalDistribution are the parents
+3. the parents of the MarginalDistribution, which are the parents
 of the subgraph (same as #5 from the subgraph)

 In order to keep the original graph and the subgraph completely
@@ -87,7 +87,8 @@ same as the parent node in the graph.
 CHILDREN:
 The children of the MarginalDistribution are the stochastic children of
 the discrete sample node.
-The stochastic children are needed to compute the MarginalDistribution,
+The stochastic children are needed to compute the
+log prob of the MarginalDistribution,
 so they are part of the subgraph.
 However, a "copy" of these children also needs to be added to the graph.
 This "copy" node is a SAMPLE node of MarginalDistribution whose value
@@ -104,11 +105,13 @@ void marginalize_graph(Graph& graph, uint discrete_sample_node_id) {
   std::vector<uint> sto_node_ids;
   std::tie(det_node_ids, sto_node_ids) =
       compute_children(graph, discrete_sample->index);
+  // TODO: do we need to rename the above compute_affected_nodes,
+  // or even use Graph's methods for that instead of computing it ourselves?

   // create MarginalDistribution
   std::unique_ptr<distribution::DummyMarginal> marginal_distribution_ptr =
       std::make_unique<distribution::DummyMarginal>(std::move(subgraph_ptr));
-  // TODO: support the correct sample type for multiple children
+  // TODO: support multiple children
   if (sto_node_ids.size() > 0) {
     // @lint-ignore
     marginal_distribution_ptr->sample_type =
@@ -119,17 +122,15 @@ void marginalize_graph(Graph& graph, uint discrete_sample_node_id) {
       marginal_distribution_ptr.get();
   SubGraph* subgraph = marginal_distribution->subgraph_ptr.get();

-  // add nodes to subgraph
   add_nodes_to_subgraph(
       subgraph,
       discrete_distribution,
       discrete_sample,
       det_node_ids,
       sto_node_ids);

-  // connect parents to MarginalDistribution in graph
   connect_parents_to_marginal_distribution(graph, marginal_distribution);
-  // add copy of parents to subgraph
+
   add_copy_of_parent_nodes_to_subgraph(subgraph, marginal_distribution);

   // create and connect children to MarginalDistribution
@@ -139,6 +140,7 @@ void marginalize_graph(Graph& graph, uint discrete_sample_node_id) {

   // list of all created nodes to add to `nodes` of current graph
   std::vector<std::unique_ptr<Node>> created_graph_nodes;
+
   // add MarginalDistribution to list of created nodes
   created_graph_nodes.push_back(std::move(marginal_distribution_ptr));
   // add created nodes to list of created_graph_nodes
@@ -154,6 +156,7 @@ void marginalize_graph(Graph& graph, uint discrete_sample_node_id) {
   // the created nodes should be inserted right after the largest parent index
   uint marginal_distribution_index =
       compute_largest_parent_index(marginal_distribution) + 1;
+
   // insert created nodes into graph at "marginalized_node_index"
   for (uint i = 0; i < created_graph_nodes.size(); i++) {
     graph.nodes.insert(
@@ -164,7 +167,7 @@
 }

 /*
-returns <determinisitc_node_ids, stochastic_node_ids>
+returns <deterministic_node_ids, stochastic_node_ids>
 1. deterministic_node_ids are all of the deterministic nodes up until the
 2. stochastic_node_ids children are reached
 */
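
For intuition on what the MarginalDistribution must compute: marginalizing the discrete sample d out means its log prob over the stochastic children is log p(children) = log sum_d exp(log p(d) + log p(children | d)). Here is a minimal sketch under the assumption of a finite-support discrete node; marginal_log_prob and its parameters are illustrative names, not the DummyMarginal API:

#include <algorithm>
#include <cmath>
#include <functional>
#include <vector>

double marginal_log_prob(
    const std::vector<double>& discrete_log_probs,          // log p(d)
    const std::function<double(size_t)>& children_log_prob  // log p(children | d)
) {
  std::vector<double> terms;
  terms.reserve(discrete_log_probs.size());
  for (size_t d = 0; d != discrete_log_probs.size(); d++) {
    terms.push_back(discrete_log_probs[d] + children_log_prob(d));
  }
  // log-sum-exp with the max subtracted out, as in util.cpp, for stability
  double max = *std::max_element(terms.begin(), terms.end());
  double sum = 0;
  for (double t : terms) {
    sum += std::exp(t - max);
  }
  return std::log(sum) + max;
}

This is also why the stochastic children belong in the subgraph: evaluating children_log_prob(d) requires setting the discrete value to d and re-evaluating the deterministic nodes between it and the children.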

src/beanmachine/graph/operator/stochasticop.cpp

Lines changed: 2 additions & 1 deletion
@@ -42,7 +42,8 @@ void StochasticOperator::gradient_log_prob(
   // but it is not represented as an in-node.
   //
   // This makes the computation of this derivative less uniform
-  // and less directly corresponding to the typical use of the chain rule.
+  // and less directly corresponding to the typical use of the chain rule
+  // for the operations explicitly represented as nodes.
   //
   // Still, it should be possible to simply apply the chain rule
   // and find an expression involving the gradient of this stochastic node's

src/beanmachine/graph/util.cpp

Lines changed: 19 additions & 7 deletions
@@ -76,13 +76,10 @@ double Phi_approx_inv(double z) {
 }

 double log_sum_exp(const std::vector<double>& values) {
-  // find the max and subtract it out
-  double max = values[0];
-  for (std::vector<double>::size_type idx = 1; idx < values.size(); idx++) {
-    if (values[idx] > max) {
-      max = values[idx];
-    }
-  }
+  // See "log-sum-exp trick for log-domain calculations" in
+  // https://en.wikipedia.org/wiki/LogSumExp
+  assert(values.size() != 0);
+  double max = *std::max_element(values.begin(), values.end());
   double sum = 0;
   for (auto value : values) {
     sum += std::exp(value - max);
@@ -96,6 +93,21 @@ double log_sum_exp(double a, double b) {
   return std::log(sum) + max_val;
 }

+std::vector<double> probs_given_log_potentials(std::vector<double> log_pot) {
+  // p_i = pot_i/Z
+  //     = exp(log(pot_i/Z))
+  //     = exp(log pot_i - log Z),
+  // where Z is the normalization constant sum_i exp(log pot_i),
+  // so log Z = log(sum_i exp(log pot_i)) = log_sum_exp(log pot).
+  auto logZ = log_sum_exp(log_pot);
+  std::vector<double> probs;
+  probs.reserve(log_pot.size());
+  for (size_t i = 0; i != log_pot.size(); i++) {
+    probs.push_back(std::exp(log_pot[i] - logZ));
+  }
+  return probs;
+}
+
 double polygamma(int n, double x) {
   return boost::math::polygamma(n, x);
 }
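
A quick, self-contained numeric check of the two utilities touched above (the max-subtraction trick is re-implemented inline here because this diff does not show util.h's namespace; expected outputs are in the comments):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Same max-subtraction trick as util.cpp's log_sum_exp.
double log_sum_exp_demo(const std::vector<double>& values) {
  double max = *std::max_element(values.begin(), values.end());
  double sum = 0;
  for (double v : values) {
    sum += std::exp(v - max);
  }
  return std::log(sum) + max;
}

int main() {
  // exp(1000) overflows a double, but after subtracting max = 1001 we
  // compute log(exp(-1) + exp(0)) + 1001, i.e. about 1001.313262.
  std::printf("%f\n", log_sum_exp_demo({1000.0, 1001.0}));

  // probs_given_log_potentials on potentials {1, 3}: Z = 1 + 3 = 4,
  // so the normalized probabilities are 0.25 and 0.75.
  std::vector<double> log_pot = {std::log(1.0), std::log(3.0)};
  double logZ = log_sum_exp_demo(log_pot);
  for (double lp : log_pot) {
    std::printf("%f ", std::exp(lp - logZ)); // 0.250000 0.750000
  }
  std::printf("\n");
  return 0;
}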

src/beanmachine/graph/util.h

Lines changed: 11 additions & 1 deletion
@@ -89,13 +89,23 @@ std::vector<T> percentiles(
 }

 /*
-Compute log of the sum of the exponentiation of all the values in the vector
+Equivalent to log of sum of exponentiations of values,
+but more numerically stable.
 :param values: vector of log values
 :returns: log sum exp of values
 */
 double log_sum_exp(const std::vector<double>& values);
 double log_sum_exp(double a, double b);

+/*
+Given log potentials log pot_i,
+where potentials pot_i are an unnormalized probability distribution,
+return the normalized probability distribution p_i:
+p_i = pot_i/Z,
+where Z is the normalization constant sum_i exp(log pot_i).
+*/
+std::vector<double> probs_given_log_potentials(std::vector<double> log_pot);
+
 struct BinaryLogSumExp {
   double operator()(double a, double b) const {
     return log_sum_exp(a, b);
