@@ -354,7 +354,7 @@ int mca_coll_han_allreduce_t3_task(void *task_args)
354
354
if (t -> cur_seg == t -> num_segments - 2 && t -> last_seg_count != t -> seg_count ) {
355
355
tmp_count = t -> last_seg_count ;
356
356
}
357
- t -> up_comm -> c_coll -> coll_ibcast ((char * ) t -> rbuf + extent * t -> seg_count , t -> seg_count ,
357
+ t -> up_comm -> c_coll -> coll_ibcast ((char * ) t -> rbuf + extent * t -> seg_count , tmp_count ,
358
358
t -> dtype , t -> root_up_rank , t -> up_comm , & (reqs [0 ]),
359
359
t -> up_comm -> c_coll -> coll_ibcast_module );
360
360
req_count ++ ;
@@ -391,7 +391,13 @@ int mca_coll_han_allreduce_t3_task(void *task_args)
391
391
t -> low_comm -> c_coll -> coll_reduce_module );
392
392
}
393
393
/* lb of cur_seg */
394
- t -> low_comm -> c_coll -> coll_bcast ((char * ) t -> rbuf , t -> seg_count , t -> dtype , t -> root_low_rank ,
394
+ if (t -> cur_seg == t -> num_segments - 1 && t -> last_seg_count != t -> seg_count ) {
395
+ tmp_count = t -> last_seg_count ;
396
+ } else {
397
+ tmp_count = t -> seg_count ;
398
+ }
399
+
400
+ t -> low_comm -> c_coll -> coll_bcast ((char * ) t -> rbuf , tmp_count , t -> dtype , t -> root_low_rank ,
395
401
t -> low_comm , t -> low_comm -> c_coll -> coll_bcast_module );
396
402
if (!t -> noop && req_count > 0 ) {
397
403
ompi_request_wait_all (req_count , reqs , MPI_STATUSES_IGNORE );
0 commit comments