Skip to content

Commit f627912

Browse files
authored
Merge pull request #2978 from stan-dev/fix/hard-copy-var-assign
fix assignment for nullptr var_value<matrix> and for assigning expressions
2 parents e43fc08 + a18614f commit f627912

File tree

5 files changed

+194
-6
lines changed

5 files changed

+194
-6
lines changed

stan/math/opencl/matrix_cl.hpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,37 @@ class matrix_cl : public matrix_cl_base {
501501
*/
502502
~matrix_cl() { wait_for_read_write_events(); }
503503

504+
/**
505+
* Set the values of a `matrix_cl` to zero.
*
* Returns immediately when the matrix has size zero. Otherwise a buffer fill
* is enqueued on `opencl_context.queue()`; the read and write events
* currently recorded for this matrix are passed along with the call, and the
* event produced by the fill is registered as a new write event.
506+
*/
507+
void setZero() {
508+
// Nothing to fill for an empty matrix.
if (this->size() == 0) {
509+
return;
510+
}
511+
// Receives the event associated with the fill enqueued below.
cl::Event zero_event;
512+
const std::size_t write_events_size = this->write_events().size();
513+
const std::size_t read_events_size = this->read_events().size();
514+
const std::size_t read_write_size = write_events_size + read_events_size;
515+
// Combined event list: read events first, followed by write events.
std::vector<cl::Event> read_write_events(read_write_size, cl::Event{});
516+
auto&& read_events_vec = this->read_events();
517+
auto&& write_events_vec = this->write_events();
518+
for (std::size_t i = 0; i < read_events_size; ++i) {
519+
read_write_events[i] = read_events_vec[i];
520+
}
521+
// `i` continues after the read events while `j` walks the write events.
for (std::size_t i = read_events_size, j = 0; j < write_events_size;
522+
++i, ++j) {
523+
read_write_events[i] = write_events_vec[j];
524+
}
525+
try {
526+
// Fill `sizeof(T) * size()` bytes of `buffer_cl_`, starting at offset 0,
// with a zero of type `T`.
opencl_context.queue().enqueueFillBuffer(buffer_cl_, static_cast<T>(0), 0,
527+
sizeof(T) * this->size(),
528+
&read_write_events, &zero_event);
529+
} catch (const cl::Error& e) {
530+
// OpenCL errors are handed to `check_opencl_error` under the name "setZero".
check_opencl_error("setZero", e);
531+
}
532+
// Record the fill so it is tracked among this matrix's write events.
this->add_write_event(zero_event);
533+
}
534+
504535
private:
505536
/**
506537
* Initializes the OpenCL buffer of this matrix by copying the data from given

stan/math/rev/core/arena_matrix.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,15 @@ class arena_matrix : public Eigen::Map<MatrixType> {
128128
Base::operator=(a);
129129
return *this;
130130
}
131+
/**
132+
* Forces hard copying matrices into an arena matrix
*
* The call is forwarded directly to `Base::operator=`.
* NOTE(review): `Base` is presumably the `Eigen::Map<MatrixType>` this class
* derives from, in which case the values are written into the memory the
* map already refers to - confirm against the class definition.
133+
* @tparam T Any type assignable to `Base`
134+
* @param x the values to write to `this`
135+
*/
136+
template <typename T>
137+
void deep_copy(const T& x) {
138+
// Use the base class assignment explicitly rather than any `operator=`
// declared by `arena_matrix` itself.
Base::operator=(x);
139+
}
131140
};
132141

133142
} // namespace math

stan/math/rev/core/var.hpp

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,9 +1020,10 @@ class var_value<T, internal::require_matrix_var_value<T>> {
10201020
* @param other the value to assign
10211021
* @return this
10221022
*/
1023-
template <typename S, require_assignable_t<value_type, S>* = nullptr,
1024-
require_all_plain_type_t<T, S>* = nullptr,
1025-
require_not_same_t<plain_type_t<T>, plain_type_t<S>>* = nullptr>
1023+
template <typename S, typename T_ = T,
1024+
require_assignable_t<value_type, S>* = nullptr,
1025+
require_all_plain_type_t<T_, S>* = nullptr,
1026+
require_not_same_t<plain_type_t<T_>, plain_type_t<S>>* = nullptr>
10261027
inline var_value<T>& operator=(const var_value<S>& other) {
10271028
static_assert(
10281029
EIGEN_PREDICATE_SAME_MATRIX_SIZE(T, S),
@@ -1032,16 +1033,65 @@ class var_value<T, internal::require_matrix_var_value<T>> {
10321033
}
10331034

10341035
/**
1035-
* Assignment of another var value, when either this or the other one does not
1036+
* Assignment of another var value, when `this` is plain but `other` does not
10361037
* contain a plain type.
1037-
* @tparam S type of the value in the `var_value` to assing
1038+
* @tparam S type of the value in the `var_value` to assign
1039+
* @param other the value to assign
1040+
* @return this
1041+
*/
1042+
// Overload selected when `T_` (defaulting to `T`) is a plain type and `S` is
// not a plain type.
template <typename S, typename T_ = T,
1043+
require_assignable_t<value_type, S>* = nullptr,
1044+
require_not_plain_type_t<S>* = nullptr,
1045+
require_plain_type_t<T_>* = nullptr>
1046+
inline var_value<T>& operator=(const var_value<S>& other) {
1047+
// If vi_ is nullptr then the var needs to be initialized: a var_value<T> is
// constructed from `other` and assigned to `*this`.
1048+
if (!(this->vi_)) {
1049+
*this = var_value<T>(other);
1050+
return *this;
1051+
}
1052+
// Keep a copy of the current values so the reverse pass can restore them.
arena_t<plain_type_t<T>> prev_val(vi_->val_.rows(), vi_->val_.cols());
1053+
prev_val.deep_copy(vi_->val_);
1054+
// Overwrite the values of this var with the values of `other`.
vi_->val_.deep_copy(other.val());
1055+
// no need to change any adjoints - these are just zeros before the reverse
1056+
// pass
1057+
1058+
// `prev_val` is captured by value and the lambda is `mutable` so the capture
// can be overwritten inside the callback.
reverse_pass_callback(
1059+
[this_vi = this->vi_, other_vi = other.vi_, prev_val]() mutable {
1060+
// restore the values held before the assignment
this_vi->val_.deep_copy(prev_val);
1061+
1062+
// we have no way of detecting aliasing between this->vi_->adj_ and
1063+
// other.vi_->adj_, so we must copy adjoint before resetting to zero
1064+
1065+
// we can reuse prev_val instead of allocating a new matrix
1066+
prev_val.deep_copy(this_vi->adj_);
1067+
this_vi->adj_.setZero();
1068+
// pass the saved adjoint on to the var that was assigned from
other_vi->adj_ += prev_val;
1069+
});
1070+
return *this;
1071+
}
1072+
/**
1073+
* Assignment of another var value, when `this` does not
1074+
* contain a plain type.
1075+
* @note Here we do not need to use `deep_copy` as the `var_value`'s
1076+
* inner `vari_type` holds a view which will call the assignment operator
1077+
* that does not perform a placement new.
1078+
* @tparam S type of the value in the `var_value` to assign
10381079
* @param other the value to assign
10391080
* @return this
10401081
*/
10411082
template <typename S, typename T_ = T,
10421083
require_assignable_t<value_type, S>* = nullptr,
1043-
require_any_not_plain_type_t<T_, S>* = nullptr>
1084+
require_not_plain_type_t<T_>* = nullptr>
10441085
inline var_value<T>& operator=(const var_value<S>& other) {
1086+
// A nullptr vi_ cannot be assigned through here; it is reported as a bug
1087+
if (!(this->vi_)) {
1088+
[]() STAN_COLD_PATH {
1089+
throw std::domain_error(
1090+
"var_value<matrix>::operator=(var_value<expression>):"
1091+
" Internal Bug! Please report this with an example"
1092+
" of your model to the Stan math github repository.");
1093+
}();
1094+
}
10451095
arena_t<plain_type_t<T>> prev_val = vi_->val_;
10461096
vi_->val_ = other.val();
10471097
// no need to change any adjoints - these are just zeros before the reverse

test/unit/math/opencl/matrix_cl_test.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,18 @@ TEST(MathMatrixCL, assignment) {
7777
EXPECT_EQ(nullptr, mat1_cl.buffer()());
7878
}
7979

80+
// Checks that `setZero()` leaves all four entries of a 2x2 `matrix_cl<double>`
// equal to zero once the matrix is copied back with `from_matrix_cl`.
TEST(MathMatrixCL, setZeroFun) {
81+
using stan::math::matrix_cl;
82+
Eigen::Matrix<double, 2, 2> mat_1;
83+
// Start from non-zero entries so the zeroing is observable.
mat_1 << 1, 2, 3, 4;
84+
matrix_cl<double> mat1_cl(mat_1);
85+
mat1_cl.setZero();
86+
Eigen::Matrix<double, 2, 2> mat_1_fromcl
87+
= stan::math::from_matrix_cl(mat1_cl);
88+
// Linear indices 0..3 cover every entry of the 2x2 matrix.
EXPECT_EQ(mat_1_fromcl(0), 0);
89+
EXPECT_EQ(mat_1_fromcl(1), 0);
90+
EXPECT_EQ(mat_1_fromcl(2), 0);
91+
EXPECT_EQ(mat_1_fromcl(3), 0);
92+
}
93+
8094
#endif

test/unit/math/rev/core/var_test.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,3 +910,87 @@ TEST_F(AgradRev, matrix_compile_time_conversions) {
910910
EXPECT_MATRIX_FLOAT_EQ(colvec.val(), rowvec.val());
911911
EXPECT_MATRIX_FLOAT_EQ(x11.val(), rowvec.val());
912912
}
913+
914+
// A `var_value` vector `y` that initially holds NaN values is overwritten by
// assigning `head(x, 10)` to it. After `lp.grad()` the test expects the
// adjoint of `x` to be `-x` and the adjoint of `y` to be all zeros.
TEST_F(AgradRev, assign_nan_varmat) {
915+
using stan::math::var_value;
916+
using var_vector = var_value<Eigen::Matrix<double, -1, 1>>;
917+
using stan::math::var;
918+
Eigen::VectorXd x_val(10);
919+
// x holds the values 0.1, 1.1, ..., 9.1
for (int i = 0; i < 10; ++i) {
920+
x_val(i) = i + 0.1;
921+
}
922+
var_vector x(x_val);
923+
// y is already initialized (with NaNs) before the assignment under test.
var_vector y = var_vector(Eigen::Matrix<double, -1, 1>::Constant(
924+
10, std::numeric_limits<double>::quiet_NaN()));
925+
y = stan::math::head(x, 10);
926+
var sigma = 1.0;
927+
var lp = stan::math::normal_lpdf<false>(y, 0, sigma);
928+
lp.grad();
929+
Eigen::VectorXd x_ans_adj(10);
930+
// expected adjoint of x: the negated values of x
for (int i = 0; i < 10; ++i) {
931+
x_ans_adj(i) = -(i + 0.1);
932+
}
933+
EXPECT_MATRIX_EQ(x.adj(), x_ans_adj);
934+
Eigen::VectorXd y_ans_adj = Eigen::VectorXd::Zero(10);
935+
EXPECT_MATRIX_EQ(y_ans_adj, y.adj());
936+
}
937+
938+
// Counterpart of the NaN-assignment test for an Eigen vector of `var`: `y`
// starts as NaNs, a copy `z` of it is kept, and `y` is then assigned
// `head(x, 10)`. After `lp.grad()` the test expects the adjoint of `x` to be
// `-x` and the adjoint of `z` to be all zeros.
TEST_F(AgradRev, assign_nan_matvar) {
939+
using stan::math::var;
940+
using var_vector = Eigen::Matrix<var, -1, 1>;
941+
Eigen::VectorXd x_val(10);
942+
// x holds the values 0.1, 1.1, ..., 9.1
for (int i = 0; i < 10; ++i) {
943+
x_val(i) = i + 0.1;
944+
}
945+
var_vector x(x_val);
946+
var_vector y = var_vector(Eigen::Matrix<double, -1, 1>::Constant(
947+
10, std::numeric_limits<double>::quiet_NaN()));
948+
// need to store y's previous vari pointers
949+
var_vector z = y;
950+
y = stan::math::head(x, 10);
951+
var sigma = 1.0;
952+
var lp = stan::math::normal_lpdf<false>(y, 0, sigma);
953+
lp.grad();
954+
Eigen::VectorXd x_ans_adj(10);
955+
// expected adjoint of x: the negated values of x
for (int i = 0; i < 10; ++i) {
956+
x_ans_adj(i) = -(i + 0.1);
957+
}
958+
EXPECT_MATRIX_EQ(x.adj(), x_ans_adj);
959+
Eigen::VectorXd z_ans_adj = Eigen::VectorXd::Zero(10);
960+
EXPECT_MATRIX_EQ(z_ans_adj, z.adj());
961+
}
962+
963+
/**
964+
* For var<Matrix> and Matrix<var>, we need to make sure
965+
* the tape, when going through reverse mode, leads to the same outcomes.
966+
* In the case where we declare a var<Matrix> without initializing it, aka
967+
* `var_value<Eigen::MatrixXd>`, we need to think about what the equivalent
968+
* behavior is for `Eigen::Matrix<var, -1, -1>`.
969+
* When default constructing `Eigen::Matrix<var, -1, -1>` we would have an array
970+
* of `var` types with `nullptr` as the vari. The first assignment to that array
971+
* would then just copy the vari pointer from the other array. This is the
972+
* behavior we want to mimic for `var_value<Eigen::MatrixXd>`. So in this test
973+
* we show that for uninitialized `var_value<Eigen::MatrixXd>`, we can assign it
974+
* and the adjoints are the same as x.
975+
*/
976+
TEST_F(AgradRev, assign_nullptr_var) {
977+
using stan::math::var_value;
978+
using var_vector = var_value<Eigen::Matrix<double, -1, 1>>;
979+
using stan::math::var;
980+
Eigen::VectorXd x_val(10);
981+
// x holds the values 0.1, 1.1, ..., 9.1
for (int i = 0; i < 10; ++i) {
982+
x_val(i) = i + 0.1;
983+
}
984+
var_vector x(x_val);
985+
// y is default constructed and only receives a value through the assignment
// on the next statement.
var_vector y;
986+
y = stan::math::head(x, 10);
987+
var sigma = 1.0;
988+
var lp = stan::math::normal_lpdf<false>(y, 0, sigma);
989+
lp.grad();
990+
Eigen::VectorXd x_ans_adj(10);
991+
// expected adjoint: the negated values of x
for (int i = 0; i < 10; ++i) {
992+
x_ans_adj(i) = -(i + 0.1);
993+
}
994+
EXPECT_MATRIX_EQ(x.adj(), x_ans_adj);
995+
// y is expected to report the same adjoint as x.
EXPECT_MATRIX_EQ(x_ans_adj, y.adj());
996+
}

0 commit comments

Comments
 (0)