@@ -760,3 +760,88 @@ int mca_coll_hcoll_ialltoallv(const void *sbuf, int *scounts, int *sdisps,
760
760
return rc ;
761
761
}
762
762
#endif
763
+
764
+ #if HCOLL_API > HCOLL_VERSION (4 ,5 )
765
+ int mca_coll_hcoll_reduce_scatter_block (const void * sbuf , void * rbuf , int rcount ,
766
+ struct ompi_datatype_t * dtype ,
767
+ struct ompi_op_t * op ,
768
+ struct ompi_communicator_t * comm ,
769
+ mca_coll_base_module_t * module ) {
770
+ dte_data_representation_t Dtype ;
771
+ hcoll_dte_op_t * Op ;
772
+ int rc ;
773
+ HCOL_VERBOSE (20 ,"RUNNING HCOL REDUCE SCATTER BLOCK" );
774
+ mca_coll_hcoll_module_t * hcoll_module = (mca_coll_hcoll_module_t * )module ;
775
+ Dtype = ompi_dtype_2_hcoll_dtype (dtype , NO_DERIVED );
776
+ if (OPAL_UNLIKELY (HCOL_DTE_IS_ZERO (Dtype ))){
777
+ /*If we are here then datatype is not simple predefined datatype */
778
+ /*In future we need to add more complex mapping to the dte_data_representation_t */
779
+ /* Now use fallback */
780
+ HCOL_VERBOSE (20 ,"Ompi_datatype is not supported: dtype = %s; calling fallback allreduce;" ,
781
+ dtype -> super .name );
782
+ goto fallback ;
783
+ }
784
+
785
+ Op = ompi_op_2_hcolrte_op (op );
786
+ if (OPAL_UNLIKELY (HCOL_DTE_OP_NULL == Op -> id )){
787
+ /*If we are here then datatype is not simple predefined datatype */
788
+ /*In future we need to add more complex mapping to the dte_data_representation_t */
789
+ /* Now use fallback */
790
+ HCOL_VERBOSE (20 ,"ompi_op_t is not supported: op = %s; calling fallback allreduce;" ,
791
+ op -> o_name );
792
+ goto fallback ;
793
+ }
794
+
795
+ rc = hcoll_collectives .coll_reduce_scatter_block ((void * )sbuf ,rbuf ,rcount ,Dtype ,Op ,hcoll_module -> hcoll_context );
796
+ if (HCOLL_SUCCESS != rc ){
797
+ fallback :
798
+ HCOL_VERBOSE (20 ,"RUNNING FALLBACK ALLREDUCE" );
799
+ rc = hcoll_module -> previous_reduce_scatter_block (sbuf ,rbuf ,
800
+ rcount ,dtype ,op ,
801
+ comm , hcoll_module -> previous_allreduce_module );
802
+ }
803
+ return rc ;
804
+ }
805
+
806
+ int mca_coll_hcoll_reduce_scatter (const void * sbuf , void * rbuf , const int * rcounts ,
807
+ struct ompi_datatype_t * dtype ,
808
+ struct ompi_op_t * op ,
809
+ struct ompi_communicator_t * comm ,
810
+ mca_coll_base_module_t * module ) {
811
+ dte_data_representation_t Dtype ;
812
+ hcoll_dte_op_t * Op ;
813
+ int rc ;
814
+ HCOL_VERBOSE (20 ,"RUNNING HCOL REDUCE SCATTER" );
815
+ mca_coll_hcoll_module_t * hcoll_module = (mca_coll_hcoll_module_t * )module ;
816
+ Dtype = ompi_dtype_2_hcoll_dtype (dtype , NO_DERIVED );
817
+ if (OPAL_UNLIKELY (HCOL_DTE_IS_ZERO (Dtype ))){
818
+ /*If we are here then datatype is not simple predefined datatype */
819
+ /*In future we need to add more complex mapping to the dte_data_representation_t */
820
+ /* Now use fallback */
821
+ HCOL_VERBOSE (20 ,"Ompi_datatype is not supported: dtype = %s; calling fallback allreduce;" ,
822
+ dtype -> super .name );
823
+ goto fallback ;
824
+ }
825
+
826
+ Op = ompi_op_2_hcolrte_op (op );
827
+ if (OPAL_UNLIKELY (HCOL_DTE_OP_NULL == Op -> id )){
828
+ /*If we are here then datatype is not simple predefined datatype */
829
+ /*In future we need to add more complex mapping to the dte_data_representation_t */
830
+ /* Now use fallback */
831
+ HCOL_VERBOSE (20 ,"ompi_op_t is not supported: op = %s; calling fallback allreduce;" ,
832
+ op -> o_name );
833
+ goto fallback ;
834
+ }
835
+
836
+ rc = hcoll_collectives .coll_reduce_scatter ((void * )sbuf , rbuf , (int * )rcounts ,
837
+ Dtype , Op , hcoll_module -> hcoll_context );
838
+ if (HCOLL_SUCCESS != rc ){
839
+ fallback :
840
+ HCOL_VERBOSE (20 ,"RUNNING FALLBACK ALLREDUCE" );
841
+ rc = hcoll_module -> previous_reduce_scatter (sbuf ,rbuf ,
842
+ rcounts ,dtype ,op ,
843
+ comm , hcoll_module -> previous_allreduce_module );
844
+ }
845
+ return rc ;
846
+ }
847
+ #endif
0 commit comments