@@ -501,6 +501,37 @@ class matrix_cl : public matrix_cl_base {
501
501
*/
502
502
~matrix_cl () { wait_for_read_write_events (); }
503
503
504
+ /* *
505
+ * Set the values of a `matrix_cl` to zero.
506
+ */
507
+ void setZero () {
508
+ if (this ->size () == 0 ) {
509
+ return ;
510
+ }
511
+ cl::Event zero_event;
512
+ this ->wait_for_read_write_events ();
513
+ const std::size_t write_events_size = this ->write_events ().size ();
514
+ const std::size_t read_events_size = this ->read_events ().size ();
515
+ const std::size_t read_write_size = write_events_size + read_events_size;
516
+ std::vector<cl::Event> read_write_events (read_write_size, cl::Event{});
517
+ auto && read_events_vec = this ->read_events ();
518
+ auto && write_events_vec = this ->write_events ();
519
+ for (std::size_t i = 0 ; i < write_events_size; ++i) {
520
+ read_write_events[i] = read_events_vec[i];
521
+ }
522
+ for (std::size_t i = write_events_size, j = 0 ; i < read_write_size; ++i, ++j) {
523
+ read_write_events[i] = write_events_vec[j];
524
+ }
525
+ try {
526
+ opencl_context.queue ().enqueueFillBuffer (
527
+ buffer_cl_, static_cast <T>(0 ), 0 , sizeof (T) * this ->size (),
528
+ &read_write_events, &zero_event);
529
+ } catch (const cl::Error& e) {
530
+ check_opencl_error (" setZero" , e);
531
+ }
532
+ this ->add_write_event (zero_event);
533
+ }
534
+
504
535
private:
505
536
/* *
506
537
* Initializes the OpenCL buffer of this matrix by copying the data from given
0 commit comments