Skip to content

Commit 8449c99

Browse files
committed
change hard_copy to deep_copy, fix async logic in setZero() for matrix_cl, and add test for nan var matrix behaviors
1 parent b1eb578 commit 8449c99

File tree

4 files changed

+36
-9
lines changed

4 files changed

+36
-9
lines changed

stan/math/opencl/matrix_cl.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,6 @@ class matrix_cl : public matrix_cl_base {
509509
return;
510510
}
511511
cl::Event zero_event;
512-
this->wait_for_read_write_events();
513512
const std::size_t write_events_size = this->write_events().size();
514513
const std::size_t read_events_size = this->read_events().size();
515514
const std::size_t read_write_size = write_events_size + read_events_size;

stan/math/rev/core/arena_matrix.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class arena_matrix : public Eigen::Map<MatrixType> {
134134
* @param x the values to write to `this`
135135
*/
136136
template <typename T>
137-
void hard_copy(const T& x) {
137+
void deep_copy(const T& x) {
138138
Base::operator=(x);
139139
}
140140
};

stan/math/rev/core/var.hpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,20 +1052,20 @@ class var_value<T, internal::require_matrix_var_value<T>> {
10521052
return *this;
10531053
}
10541054
arena_t<plain_type_t<T>> prev_val(vi_->val_.rows(), vi_->val_.cols());
1055-
prev_val.hard_copy(vi_->val_);
1056-
vi_->val_.hard_copy(other.val());
1055+
prev_val.deep_copy(vi_->val_);
1056+
vi_->val_.deep_copy(other.val());
10571057
// no need to change any adjoints - these are just zeros before the reverse
10581058
// pass
10591059

10601060
reverse_pass_callback(
10611061
[this_vi = this->vi_, other_vi = other.vi_, prev_val]() mutable {
1062-
this_vi->val_.hard_copy(prev_val);
1062+
this_vi->val_.deep_copy(prev_val);
10631063

10641064
// we have no way of detecting aliasing between this->vi_->adj_ and
10651065
// other.vi_->adj_, so we must copy adjoint before resetting to zero
10661066

10671067
// we can reuse prev_val instead of allocating a new matrix
1068-
prev_val.hard_copy(this_vi->adj_);
1068+
prev_val.deep_copy(this_vi->adj_);
10691069
this_vi->adj_.setZero();
10701070
other_vi->adj_ += prev_val;
10711071
});
@@ -1074,13 +1074,15 @@ class var_value<T, internal::require_matrix_var_value<T>> {
10741074
/**
10751075
* Assignment of another var value, when either `this` or `other` does not
10761076
* contain a plain type.
1077+
* @note Here we do not need to use `deep_copy` as the `var_value`'s
1078+
* inner `vari_type` holds a view which will call the assignment operator
1079+
* that does not perform a placement new.
10771080
* @tparam S type of the value in the `var_value` to assign
10781081
* @param other the value to assign
10791082
* @return this
10801083
*/
10811084
template <typename S, typename T_ = T,
10821085
require_assignable_t<value_type, S>* = nullptr,
1083-
require_any_not_plain_type_t<T_, S>* = nullptr,
10841086
require_not_plain_type_t<T_>* = nullptr>
10851087
inline var_value<T>& operator=(const var_value<S>& other) {
10861088
// If vi_ is nullptr then the var needs to be initialized via copy constructor
@@ -1105,7 +1107,7 @@ class var_value<T, internal::require_matrix_var_value<T>> {
11051107
// other.vi_->adj_, so we must copy adjoint before resetting to zero
11061108

11071109
// we can reuse prev_val instead of allocating a new matrix
1108-
prev_val.hard_copy(this_vi->adj_);
1110+
prev_val = this_vi->adj_;
11091111
this_vi->adj_.setZero();
11101112
other_vi->adj_ += prev_val;
11111113
});

test/unit/math/rev/core/var_test.cpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -911,7 +911,7 @@ TEST_F(AgradRev, matrix_compile_time_conversions) {
911911
EXPECT_MATRIX_FLOAT_EQ(x11.val(), rowvec.val());
912912
}
913913

914-
TEST_F(AgradRev, assign_nan) {
914+
TEST_F(AgradRev, assign_nan_varmat) {
915915
using stan::math::var_value;
916916
using var_vector = var_value<Eigen::Matrix<double, -1, 1>>;
917917
using stan::math::var;
@@ -935,6 +935,32 @@ TEST_F(AgradRev, assign_nan) {
935935
EXPECT_MATRIX_EQ(y_ans_adj, y.adj());
936936
}
937937

938+
TEST_F(AgradRev, assign_nan_matvar) {
939+
using stan::math::var;
940+
using var_vector = Eigen::Matrix<var, -1, 1>;
941+
Eigen::VectorXd x_val(10);
942+
for (int i = 0; i < 10; ++i) {
943+
x_val(i) = i + 0.1;
944+
}
945+
var_vector x(x_val);
946+
var_vector y = var_vector(Eigen::Matrix<double, -1, 1>::Constant(
947+
10, std::numeric_limits<double>::quiet_NaN()));
948+
// need to store y's previous vari pointers
949+
var_vector z = y;
950+
y = stan::math::head(x, 10);
951+
var sigma = 1.0;
952+
var lp = stan::math::normal_lpdf<false>(y, 0, sigma);
953+
lp.grad();
954+
Eigen::VectorXd x_ans_adj(10);
955+
for (int i = 0; i < 10; ++i) {
956+
x_ans_adj(i) = -(i + 0.1);
957+
}
958+
EXPECT_MATRIX_EQ(x.adj(), x_ans_adj);
959+
Eigen::VectorXd z_ans_adj = Eigen::VectorXd::Zero(10);
960+
EXPECT_MATRIX_EQ(z_ans_adj, z.adj());
961+
}
962+
963+
938964
TEST_F(AgradRev, assign_nullptr_vari) {
939965
using stan::math::var_value;
940966
using var_vector = var_value<Eigen::Matrix<double, -1, 1>>;

0 commit comments

Comments
 (0)