Skip to content

Commit b1eb578

Browse files
committed
add setZero() to matrix_cl types
1 parent 0d2a46a commit b1eb578

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

stan/math/opencl/matrix_cl.hpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,37 @@ class matrix_cl : public matrix_cl_base {
501501
*/
502502
~matrix_cl() { wait_for_read_write_events(); }
503503

504+
/**
505+
* Set the values of a `matrix_cl` to zero.
506+
*/
507+
void setZero() {
508+
if (this->size() == 0) {
509+
return;
510+
}
511+
cl::Event zero_event;
512+
this->wait_for_read_write_events();
513+
const std::size_t write_events_size = this->write_events().size();
514+
const std::size_t read_events_size = this->read_events().size();
515+
const std::size_t read_write_size = write_events_size + read_events_size;
516+
std::vector<cl::Event> read_write_events(read_write_size, cl::Event{});
517+
auto&& read_events_vec = this->read_events();
518+
auto&& write_events_vec = this->write_events();
519+
for (std::size_t i = 0; i < write_events_size; ++i) {
520+
read_write_events[i] = read_events_vec[i];
521+
}
522+
for (std::size_t i = write_events_size, j = 0; i < read_write_size; ++i, ++j) {
523+
read_write_events[i] = write_events_vec[j];
524+
}
525+
try {
526+
opencl_context.queue().enqueueFillBuffer(
527+
buffer_cl_, static_cast<T>(0), 0, sizeof(T) * this->size(),
528+
&read_write_events, &zero_event);
529+
} catch (const cl::Error& e) {
530+
check_opencl_error("setZero", e);
531+
}
532+
this->add_write_event(zero_event);
533+
}
534+
504535
private:
505536
/**
506537
* Initializes the OpenCL buffer of this matrix by copying the data from given

0 commit comments

Comments
 (0)