@@ -184,7 +184,7 @@ mca_coll_han_allreduce_intra(const void *sbuf,
184
184
mca_coll_task_t * t_next_seg = OBJ_NEW (mca_coll_task_t );
185
185
/* Setup up t_next_seg task arguments */
186
186
t -> cur_task = t_next_seg ;
187
- t -> sbuf = (char * ) t -> sbuf + extent * t -> seg_count ;
187
+ t -> sbuf = (t -> sbuf == MPI_IN_PLACE ) ? MPI_IN_PLACE : ( char * ) t -> sbuf + extent * t -> seg_count ;
188
188
t -> rbuf = (char * ) t -> rbuf + extent * t -> seg_count ;
189
189
t -> cur_seg = t -> cur_seg + 1 ;
190
190
/* Init t_next_seg task */
@@ -262,11 +262,26 @@ int mca_coll_han_allreduce_t1_task(void *task_args)
262
262
if (t -> cur_seg == t -> num_segments - 2 && t -> last_seg_count != t -> seg_count ) {
263
263
tmp_count = t -> last_seg_count ;
264
264
}
265
- t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + extent * t -> seg_count ,
266
- (char * ) t -> rbuf + extent * t -> seg_count , tmp_count ,
267
- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
268
- t -> low_comm -> c_coll -> coll_reduce_module );
269
265
266
+ if (t -> sbuf == MPI_IN_PLACE ) {
267
+ if (!t -> noop ) {
268
+ t -> low_comm -> c_coll -> coll_reduce (MPI_IN_PLACE ,
269
+ (char * ) t -> rbuf + extent * t -> seg_count , tmp_count ,
270
+ t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
271
+ t -> low_comm -> c_coll -> coll_reduce_module );
272
+ } else {
273
+ t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> rbuf + extent * t -> seg_count ,
274
+ NULL , tmp_count ,
275
+ t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
276
+ t -> low_comm -> c_coll -> coll_reduce_module );
277
+
278
+ }
279
+ } else {
280
+ t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + extent * t -> seg_count ,
281
+ (char * ) t -> rbuf + extent * t -> seg_count , tmp_count ,
282
+ t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
283
+ t -> low_comm -> c_coll -> coll_reduce_module );
284
+ }
270
285
}
271
286
if (!t -> noop ) {
272
287
ompi_request_wait (& ireduce_req , MPI_STATUS_IGNORE );
@@ -321,10 +336,26 @@ int mca_coll_han_allreduce_t2_task(void *task_args)
321
336
if (t -> cur_seg == t -> num_segments - 3 && t -> last_seg_count != t -> seg_count ) {
322
337
tmp_count = t -> last_seg_count ;
323
338
}
324
- t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + 2 * extent * t -> seg_count ,
325
- (char * ) t -> rbuf + 2 * extent * t -> seg_count , tmp_count ,
326
- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
327
- t -> low_comm -> c_coll -> coll_reduce_module );
339
+
340
+ if (t -> sbuf == MPI_IN_PLACE ) {
341
+ if (!t -> noop ) {
342
+ t -> low_comm -> c_coll -> coll_reduce (MPI_IN_PLACE ,
343
+ (char * ) t -> rbuf + 2 * extent * t -> seg_count , tmp_count ,
344
+ t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
345
+ t -> low_comm -> c_coll -> coll_reduce_module );
346
+ } else {
347
+ t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> rbuf + 2 * extent * t -> seg_count ,
348
+ NULL , tmp_count ,
349
+ t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
350
+ t -> low_comm -> c_coll -> coll_reduce_module );
351
+
352
+ }
353
+ } else {
354
+ t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + 2 * extent * t -> seg_count ,
355
+ (char * ) t -> rbuf + 2 * extent * t -> seg_count , tmp_count ,
356
+ t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
357
+ t -> low_comm -> c_coll -> coll_reduce_module );
358
+ }
328
359
}
329
360
if (!t -> noop && req_count > 0 ) {
330
361
ompi_request_wait_all (req_count , reqs , MPI_STATUSES_IGNORE );
@@ -385,10 +416,25 @@ int mca_coll_han_allreduce_t3_task(void *task_args)
385
416
if (t -> cur_seg == t -> num_segments - 4 && t -> last_seg_count != t -> seg_count ) {
386
417
tmp_count = t -> last_seg_count ;
387
418
}
388
- t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + 3 * extent * t -> seg_count ,
389
- (char * ) t -> rbuf + 3 * extent * t -> seg_count , tmp_count ,
390
- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
391
- t -> low_comm -> c_coll -> coll_reduce_module );
419
+
420
+ if (t -> sbuf == MPI_IN_PLACE ) {
421
+ if (!t -> noop ) {
422
+ t -> low_comm -> c_coll -> coll_reduce (MPI_IN_PLACE ,
423
+ (char * ) t -> rbuf + 3 * extent * t -> seg_count , tmp_count ,
424
+ t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
425
+ t -> low_comm -> c_coll -> coll_reduce_module );
426
+ } else {
427
+ t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> rbuf + 3 * extent * t -> seg_count ,
428
+ NULL , tmp_count ,
429
+ t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
430
+ t -> low_comm -> c_coll -> coll_reduce_module );
431
+ }
432
+ } else {
433
+ t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + 3 * extent * t -> seg_count ,
434
+ (char * ) t -> rbuf + 3 * extent * t -> seg_count , tmp_count ,
435
+ t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
436
+ t -> low_comm -> c_coll -> coll_reduce_module );
437
+ }
392
438
}
393
439
/* lb of cur_seg */
394
440
if (t -> cur_seg == t -> num_segments - 1 && t -> last_seg_count != t -> seg_count ) {
0 commit comments