@@ -891,3 +891,134 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
891
891
cleanup_and_return :
892
892
return err ;
893
893
}
894
+
895
+ /*
896
+ * ompi_coll_base_bcast_intra_scatter_allgather_ring
897
+ *
898
+ * Function: Bcast using a binomial tree scatter followed by a ring allgather.
899
+ * Accepts: Same arguments as MPI_Bcast
900
+ * Returns: MPI_SUCCESS or error code
901
+ *
902
+ * Limitations: count >= comm_size
903
+ * Time complexity: O(\alpha(\log(p) + p) + \beta*m((p-1)/p))
904
+ * Binomial tree scatter: \alpha\log(p) + \beta*m((p-1)/p)
905
+ * Ring allgather: 2(p-1)(\alpha + m/p\beta)
906
+ *
907
+ * Example, p=8, count=8, root=0
908
+ * Binomial tree scatter Ring allgather: p - 1 steps
909
+ * 0: --+ --+ --+ [0*******] [0******7] [0*****67] [0****567] ... [01234567]
910
+ * 1: | 2| <-+ [*1******] [01******] [01*****7] [01****67] ... [01234567]
911
+ * 2: 4| <-+ --+ [**2*****] [*12*****] [012*****] [012****7] ... [01234567]
912
+ * 3: | <-+ [***3****] [**23****] [*123****] [0123****] ... [01234567]
913
+ * 4: <-+ --+ --+ [****4***] [***34***] [**234***] [*1234***] ... [01234567]
914
+ * 5: 2| <-+ [*****5**] [****45**] [***345**] [**2345**] ... [01234567]
915
+ * 6: <-+ --+ [******6*] [*****56*] [****456*] [***3456*] ... [01234567]
916
+ * 7: <-+ [*******7] [******67] [*****567] [****4567] ... [01234567]
917
+ */
918
+ int ompi_coll_base_bcast_intra_scatter_allgather_ring (
919
+ void * buf , int count , struct ompi_datatype_t * datatype , int root ,
920
+ struct ompi_communicator_t * comm , mca_coll_base_module_t * module ,
921
+ uint32_t segsize )
922
+ {
923
+ int err = MPI_SUCCESS ;
924
+ ptrdiff_t lb , extent ;
925
+ size_t datatype_size ;
926
+ MPI_Status status ;
927
+ ompi_datatype_get_extent (datatype , & lb , & extent );
928
+ ompi_datatype_type_size (datatype , & datatype_size );
929
+ int comm_size = ompi_comm_size (comm );
930
+ int rank = ompi_comm_rank (comm );
931
+
932
+ OPAL_OUTPUT ((ompi_coll_base_framework .framework_output ,
933
+ "coll:base:bcast_intra_scatter_allgather_ring: rank %d/%d" ,
934
+ rank , comm_size ));
935
+ if (comm_size < 2 || datatype_size == 0 )
936
+ return MPI_SUCCESS ;
937
+
938
+ if (count < comm_size ) {
939
+ OPAL_OUTPUT ((ompi_coll_base_framework .framework_output ,
940
+ "coll:base:bcast_intra_scatter_allgather_ring: rank %d/%d "
941
+ "count %d switching to basic linear bcast" ,
942
+ rank , comm_size , count ));
943
+ return ompi_coll_base_bcast_intra_basic_linear (buf , count , datatype ,
944
+ root , comm , module );
945
+ }
946
+
947
+ int vrank = (rank - root + comm_size ) % comm_size ;
948
+ int recv_count = 0 , send_count = 0 ;
949
+ int scatter_count = (count + comm_size - 1 ) / comm_size ; /* ceil(count / comm_size) */
950
+ int curr_count = (rank == root ) ? count : 0 ;
951
+
952
+ /* Scatter by binomial tree: receive data from parent */
953
+ int mask = 1 ;
954
+ while (mask < comm_size ) {
955
+ if (vrank & mask ) {
956
+ int parent = (rank - mask + comm_size ) % comm_size ;
957
+ /* Compute an upper bound on recv block size */
958
+ recv_count = count - vrank * scatter_count ;
959
+ if (recv_count <= 0 ) {
960
+ curr_count = 0 ;
961
+ } else {
962
+ /* Recv data from parent */
963
+ err = MCA_PML_CALL (recv ((char * )buf + (ptrdiff_t )vrank * scatter_count * extent ,
964
+ recv_count , datatype , parent ,
965
+ MCA_COLL_BASE_TAG_BCAST , comm , & status ));
966
+ if (MPI_SUCCESS != err ) { goto cleanup_and_return ; }
967
+ /* Get received count */
968
+ curr_count = (int )(status ._ucount / datatype_size );
969
+ }
970
+ break ;
971
+ }
972
+ mask <<= 1 ;
973
+ }
974
+
975
+ /* Scatter by binomial tree: send data to child processes */
976
+ mask >>= 1 ;
977
+ while (mask > 0 ) {
978
+ if (vrank + mask < comm_size ) {
979
+ send_count = curr_count - scatter_count * mask ;
980
+ if (send_count > 0 ) {
981
+ int child = (rank + mask ) % comm_size ;
982
+ err = MCA_PML_CALL (send ((char * )buf + (ptrdiff_t )scatter_count * (vrank + mask ) * extent ,
983
+ send_count , datatype , child ,
984
+ MCA_COLL_BASE_TAG_BCAST ,
985
+ MCA_PML_BASE_SEND_STANDARD , comm ));
986
+ if (MPI_SUCCESS != err ) { goto cleanup_and_return ; }
987
+ curr_count -= send_count ;
988
+ }
989
+ }
990
+ mask >>= 1 ;
991
+ }
992
+
993
+ /* Allgather by a ring algorithm */
994
+ int left = (rank - 1 + comm_size ) % comm_size ;
995
+ int right = (rank + 1 ) % comm_size ;
996
+ int send_block = vrank ;
997
+ int recv_block = (vrank - 1 + comm_size ) % comm_size ;
998
+
999
+ for (int i = 1 ; i < comm_size ; i ++ ) {
1000
+ recv_count = (scatter_count < count - recv_block * scatter_count ) ?
1001
+ scatter_count : count - recv_block * scatter_count ;
1002
+ if (recv_count < 0 )
1003
+ recv_count = 0 ;
1004
+ ptrdiff_t recv_offset = recv_block * scatter_count * extent ;
1005
+
1006
+ send_count = (scatter_count < count - send_block * scatter_count ) ?
1007
+ scatter_count : count - send_block * scatter_count ;
1008
+ if (send_count < 0 )
1009
+ send_count = 0 ;
1010
+ ptrdiff_t send_offset = send_block * scatter_count * extent ;
1011
+
1012
+ err = ompi_coll_base_sendrecv ((char * )buf + send_offset , send_count ,
1013
+ datatype , right , MCA_COLL_BASE_TAG_BCAST ,
1014
+ (char * )buf + recv_offset , recv_count ,
1015
+ datatype , left , MCA_COLL_BASE_TAG_BCAST ,
1016
+ comm , MPI_STATUS_IGNORE , rank );
1017
+ if (MPI_SUCCESS != err ) { goto cleanup_and_return ; }
1018
+ send_block = recv_block ;
1019
+ recv_block = (recv_block - 1 + comm_size ) % comm_size ;
1020
+ }
1021
+
1022
+ cleanup_and_return :
1023
+ return err ;
1024
+ }
0 commit comments