Skip to content

Commit 2b85657

Browse files
authored
Merge pull request #6699 from t-kurita/pr/java-alltoallw-arrays
java: Fix compilation error in allToAllw using Java arrays
2 parents e3eb6b5 + 7ece564 commit 2b85657

File tree

4 files changed

+314
-36
lines changed

4 files changed

+314
-36
lines changed

ompi/mpi/java/c/mpiJava.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
* All rights reserved.
1212
* Copyright (c) 2015 Los Alamos National Security, LLC. All rights
1313
* reserved.
14+
* Copyright (c) 2019 FUJITSU LIMITED. All rights reserved.
1415
* $COPYRIGHT$
1516
*
1617
* Additional copyrights may follow
@@ -94,6 +95,15 @@ void ompi_java_getReadPtrv(
9495
jobject buf, jboolean db, int off, int *counts, int *displs,
9596
int size, int rank, MPI_Datatype type, int baseType);
9697

98+
/* Gets a buffer pointer for reading, but it
99+
* 'size' is the number of processes.
100+
* if rank == -1 it copies all data from Java.
101+
* if rank != -1 it only copies from Java the rank data. */
102+
void ompi_java_getReadPtrw(
103+
void **ptr, ompi_java_buffer_t **item, JNIEnv *env,
104+
jobject buf, jboolean db, int *offs, int *counts, int *displs,
105+
int size, int rank, MPI_Datatype *types, int *baseTypes);
106+
97107
/* Releases a buffer used for reading. */
98108
void ompi_java_releaseReadPtr(
99109
void *ptr, ompi_java_buffer_t *item, jobject buf, jboolean db);
@@ -109,6 +119,12 @@ void ompi_java_getWritePtrv(
109119
void **ptr, ompi_java_buffer_t **item, JNIEnv *env, jobject buf,
110120
jboolean db, int *counts, int *displs, int size, MPI_Datatype type);
111121

122+
/* Gets a buffer pointer for writing.
123+
* 'size' is the number of processes. */
124+
void ompi_java_getWritePtrw(
125+
void **ptr, ompi_java_buffer_t **item, JNIEnv *env, jobject buf,
126+
jboolean db, int *counts, int *displs, int size, MPI_Datatype *types);
127+
112128
/* Releases a buffer used for writing.
113129
* It copies data to Java. */
114130
void ompi_java_releaseWritePtr(
@@ -123,6 +139,14 @@ void ompi_java_releaseWritePtrv(
123139
jobject buf, jboolean db, int off, int *counts, int *displs,
124140
int size, MPI_Datatype type, int baseType);
125141

142+
/* Releases a buffer used for writing.
143+
* It copies data to Java.
144+
* 'size' is the number of processes. */
145+
void ompi_java_releaseWritePtrw(
146+
void *ptr, ompi_java_buffer_t *item, JNIEnv *env,
147+
jobject buf, jboolean db, int *offs, int *counts, int *displs,
148+
int size, MPI_Datatype *types, int *baseTypes);
149+
126150
void ompi_java_setStaticLongField(JNIEnv *env, jclass c,
127151
char *field, jlong value);
128152

ompi/mpi/java/c/mpi_Comm.c

Lines changed: 55 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
* and Technology (RIST). All rights reserved.
1414
* Copyright (c) 2016 Los Alamos National Security, LLC. All rights
1515
* reserved.
16-
* Copyright (c) 2017 FUJITSU LIMITED. All rights reserved.
16+
* Copyright (c) 2017-2019 FUJITSU LIMITED. All rights reserved.
1717
* $COPYRIGHT$
1818
*
1919
* Additional copyrights may follow
@@ -1612,39 +1612,67 @@ JNIEXPORT jlong JNICALL Java_mpi_Comm_iAllToAllv(
16121612
}
16131613

16141614
JNIEXPORT void JNICALL Java_mpi_Comm_allToAllw(
1615-
JNIEnv *env, jobject jthis, jlong jComm,
1616-
jobject sendBuf, jintArray sCount, jintArray sDispls, jlongArray sTypes,
1617-
jobject recvBuf, jintArray rCount, jintArray rDispls, jlongArray rTypes)
1615+
JNIEnv *env, jobject jthis, jlong jComm,
1616+
jobject sBuf, jboolean sdb, jintArray sOffs, jintArray sCount,
1617+
jintArray sDispls, jlongArray sTypes, jintArray sBtypes,
1618+
jobject rBuf, jboolean rdb, jintArray rOffs, jintArray rCount,
1619+
jintArray rDispls, jlongArray rTypes, jintArray rBtypes)
16181620
{
1619-
MPI_Comm comm = (MPI_Comm)jComm;
1621+
MPI_Comm comm = (MPI_Comm)jComm;
16201622

1621-
jlong* jSTypes, *jRTypes;
1622-
MPI_Datatype *cSTypes, *cRTypes;
1623+
int inter = isInter(env, comm),
1624+
size = getSize(env, comm, inter);
16231625

1624-
ompi_java_getDatatypeArray(env, sTypes, &jSTypes, &cSTypes);
1625-
ompi_java_getDatatypeArray(env, rTypes, &jRTypes, &cRTypes);
1626+
jlong* jSTypes, *jRTypes;
1627+
MPI_Datatype *cSTypes, *cRTypes;
16261628

1627-
jint *jSCount, *jRCount, *jSDispls, *jRDispls;
1628-
int *cSCount, *cRCount, *cSDispls, *cRDispls;
1629-
ompi_java_getIntArray(env, sCount, &jSCount, &cSCount);
1630-
ompi_java_getIntArray(env, rCount, &jRCount, &cRCount);
1631-
ompi_java_getIntArray(env, sDispls, &jSDispls, &cSDispls);
1632-
ompi_java_getIntArray(env, rDispls, &jRDispls, &cRDispls);
1629+
ompi_java_getDatatypeArray(env, sTypes, &jSTypes, &cSTypes);
1630+
ompi_java_getDatatypeArray(env, rTypes, &jRTypes, &cRTypes);
16331631

1634-
void *sPtr = ompi_java_getDirectBufferAddress(env, sendBuf),
1635-
*rPtr = ompi_java_getDirectBufferAddress(env, recvBuf);
1632+
jint *jSCount, *jRCount, *jSDispls, *jRDispls;
1633+
int *cSCount, *cRCount, *cSDispls, *cRDispls;
1634+
jint *jSBtypes, *jRBtypes;
1635+
int *cSBtypes, *cRBtypes;
1636+
jint *jSOffs, *jROffs;
1637+
int *cSOffs, *cROffs;
16361638

1637-
int rc = MPI_Alltoallw(
1638-
sPtr, cSCount, cSDispls, cSTypes,
1639-
rPtr, cRCount, cRDispls, cRTypes, comm);
1639+
ompi_java_getIntArray(env, sCount, &jSCount, &cSCount);
1640+
ompi_java_getIntArray(env, rCount, &jRCount, &cRCount);
1641+
ompi_java_getIntArray(env, sDispls, &jSDispls, &cSDispls);
1642+
ompi_java_getIntArray(env, rDispls, &jRDispls, &cRDispls);
1643+
ompi_java_getIntArray(env, sBtypes, &jSBtypes, &cSBtypes);
1644+
ompi_java_getIntArray(env, rBtypes, &jRBtypes, &cRBtypes);
1645+
ompi_java_getIntArray(env, sOffs, &jSOffs, &cSOffs);
1646+
ompi_java_getIntArray(env, rOffs, &jROffs, &cROffs);
16401647

1641-
ompi_java_exceptionCheck(env, rc);
1642-
ompi_java_forgetIntArray(env, sCount, jSCount, cSCount);
1643-
ompi_java_forgetIntArray(env, rCount, jRCount, cRCount);
1644-
ompi_java_forgetIntArray(env, sDispls, jSDispls, cSDispls);
1645-
ompi_java_forgetIntArray(env, rDispls, jRDispls, cRDispls);
1646-
ompi_java_forgetDatatypeArray(env, sTypes, jSTypes, cSTypes);
1647-
ompi_java_forgetDatatypeArray(env, rTypes, jRTypes, cRTypes);
1648+
void *sPtr, *rPtr;
1649+
ompi_java_buffer_t *sItem, *rItem;
1650+
1651+
ompi_java_getReadPtrw(&sPtr, &sItem, env, sBuf, sdb, cSOffs,
1652+
cSCount, cSDispls, size, -1, cSTypes, cSBtypes);
1653+
ompi_java_getWritePtrw(&rPtr, &rItem, env, rBuf, rdb,
1654+
cRCount, cRDispls, size, cRTypes);
1655+
1656+
int rc = MPI_Alltoallw(sPtr, cSCount, cSDispls, cSTypes,
1657+
rPtr, cRCount, cRDispls, cRTypes, comm);
1658+
1659+
ompi_java_exceptionCheck(env, rc);
1660+
ompi_java_releaseReadPtr(sPtr, sItem, sBuf, sdb);
1661+
1662+
ompi_java_releaseWritePtrw(rPtr, rItem, env, rBuf, rdb, cROffs,
1663+
cRCount, cRDispls, size, cRTypes, cRBtypes);
1664+
1665+
ompi_java_exceptionCheck(env, rc);
1666+
ompi_java_forgetIntArray(env, sCount, jSCount, cSCount);
1667+
ompi_java_forgetIntArray(env, rCount, jRCount, cRCount);
1668+
ompi_java_forgetIntArray(env, sDispls, jSDispls, cSDispls);
1669+
ompi_java_forgetIntArray(env, rDispls, jRDispls, cRDispls);
1670+
ompi_java_forgetIntArray(env, sBtypes, jSBtypes, cSBtypes);
1671+
ompi_java_forgetIntArray(env, rBtypes, jRBtypes, cRBtypes);
1672+
ompi_java_forgetIntArray(env, sOffs, jSOffs, cSOffs);
1673+
ompi_java_forgetIntArray(env, rOffs, jROffs, cROffs);
1674+
ompi_java_forgetDatatypeArray(env, sTypes, jSTypes, cSTypes);
1675+
ompi_java_forgetDatatypeArray(env, rTypes, jRTypes, cRTypes);
16481676
}
16491677

16501678
JNIEXPORT jlong JNICALL Java_mpi_Comm_iAllToAllw(

ompi/mpi/java/c/mpi_MPI.c

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
* Copyright (c) 2015 Research Organization for Information Science
1818
* and Technology (RIST). All rights reserved.
1919
* Copyright (c) 2016-2017 IBM Corporation. All rights reserved.
20+
* Copyright (c) 2019 FUJITSU LIMITED. All rights reserved.
2021
* $COPYRIGHT$
2122
*
2223
* Additional copyrights may follow
@@ -672,6 +673,39 @@ static void* getReadPtrvRank(
672673
return ptr;
673674
}
674675

676+
static void* getReadPtrwRank(
677+
ompi_java_buffer_t **item, JNIEnv *env, jobject buf,
678+
int *offsets, int *counts, int *displs, int size,
679+
int rank, MPI_Datatype *types, int *baseTypes)
680+
{
681+
int extent = getTypeExtent(env, types[rank]),
682+
length = getCountv(counts, displs, size);
683+
void *ptr = getBuffer(env, item, length);
684+
int rootOff = offsets[rank] + displs[rank];
685+
686+
if(opal_datatype_is_contiguous_memory_layout(&types[rank]->super, counts[rank]))
687+
{
688+
int rootLength = extent * counts[rank];
689+
void *rootPtr = (char*)ptr + displs[rank];
690+
getArrayRegion(env, buf, baseTypes[rank], rootOff, rootLength, rootPtr);
691+
}
692+
else
693+
{
694+
void *inBuf, *inBase;
695+
inBuf = ompi_java_getArrayCritical(&inBase, env, buf, rootOff);
696+
697+
int rc = opal_datatype_copy_content_same_ddt(
698+
&types[rank]->super, counts[rank], ptr, inBuf);
699+
700+
ompi_java_exceptionCheck(env,
701+
rc==OPAL_SUCCESS ? OMPI_SUCCESS : OMPI_ERROR);
702+
703+
(*env)->ReleasePrimitiveArrayCritical(env, buf, inBase, JNI_ABORT);
704+
}
705+
706+
return ptr;
707+
}
708+
675709
static void* getReadPtrvAll(
676710
ompi_java_buffer_t **item, JNIEnv *env, jobject buf,
677711
int offset, int *counts, int *displs, int size,
@@ -716,6 +750,49 @@ static void* getReadPtrvAll(
716750
return ptr;
717751
}
718752

753+
static void* getReadPtrwAll(
754+
ompi_java_buffer_t **item, JNIEnv *env, jobject buf,
755+
int *offsets, int *counts, int *displs, int size,
756+
MPI_Datatype *types, int *baseTypes)
757+
{
758+
759+
int length = getCountv(counts, displs, size);
760+
void *ptr = getBuffer(env, item, length);
761+
762+
for(int i = 0; i < size; i++)
763+
{
764+
int extent = getTypeExtent(env, types[i]);
765+
766+
if(opal_datatype_is_contiguous_memory_layout(&types[i]->super, 2))
767+
{
768+
int iOff = offsets[i] + displs[i],
769+
iLen = extent * counts[i];
770+
void *iPtr = (char*)ptr + displs[i];
771+
getArrayRegion(env, buf, baseTypes[i], iOff, iLen, iPtr);
772+
}
773+
else
774+
{
775+
void *bufPtr, *bufBase;
776+
bufPtr = ompi_java_getArrayCritical(&bufBase, env, buf, offsets[i]);
777+
778+
int iOff = displs[i];
779+
char *iBuf = iOff + (char*)bufPtr,
780+
*iPtr = iOff + (char*)ptr;
781+
782+
int rc = opal_datatype_copy_content_same_ddt(
783+
&types[i]->super, counts[i], iPtr, iBuf);
784+
785+
ompi_java_exceptionCheck(env,
786+
rc==OPAL_SUCCESS ? OMPI_SUCCESS : OMPI_ERROR);
787+
788+
(*env)->ReleasePrimitiveArrayCritical(env, buf, bufBase, JNI_ABORT);
789+
}
790+
791+
}
792+
793+
return ptr;
794+
}
795+
719796
static void* getWritePtr(ompi_java_buffer_t **item, JNIEnv *env,
720797
int count, MPI_Datatype type)
721798
{
@@ -735,6 +812,14 @@ static void* getWritePtrv(ompi_java_buffer_t **item, JNIEnv *env,
735812
return getBuffer(env, item, length);
736813
}
737814

815+
static void* getWritePtrw(ompi_java_buffer_t **item, JNIEnv *env,
816+
int *counts, int *displs, int size, MPI_Datatype *types)
817+
{
818+
int length = getCountv(counts, displs, size);
819+
820+
return getBuffer(env, item, length);
821+
}
822+
738823
void ompi_java_getReadPtr(
739824
void **ptr, ompi_java_buffer_t **item, JNIEnv *env, jobject buf,
740825
jboolean db, int offset, int count, MPI_Datatype type, int baseType)
@@ -810,6 +895,39 @@ void ompi_java_getReadPtrv(
810895
}
811896
}
812897

898+
void ompi_java_getReadPtrw(
899+
void **ptr, ompi_java_buffer_t **item, JNIEnv *env,
900+
jobject buf, jboolean db, int *offsets, int *counts, int *displs,
901+
int size, int rank, MPI_Datatype *types, int *baseTypes)
902+
{
903+
int i;
904+
905+
if(buf == NULL)
906+
{
907+
/* Allow NULL buffers to send/recv 0 items as control messages. */
908+
*ptr = NULL;
909+
*item = NULL;
910+
}
911+
else if(db)
912+
{
913+
for(i = 0; i < size; i++){
914+
assert(offsets[i] == 0);
915+
}
916+
*ptr = (*env)->GetDirectBufferAddress(env, buf);
917+
*item = NULL;
918+
}
919+
else if(rank == -1)
920+
{
921+
*ptr = getReadPtrwAll(item, env, buf, offsets, counts,
922+
displs, size, types, baseTypes);
923+
}
924+
else
925+
{
926+
*ptr = getReadPtrwRank(item, env, buf, offsets, counts,
927+
displs, size, rank, types, baseTypes);
928+
}
929+
}
930+
813931
void ompi_java_releaseReadPtr(
814932
void *ptr, ompi_java_buffer_t *item, jobject buf, jboolean db)
815933
{
@@ -859,6 +977,27 @@ void ompi_java_getWritePtrv(
859977
}
860978
}
861979

980+
void ompi_java_getWritePtrw(
981+
void **ptr, ompi_java_buffer_t **item, JNIEnv *env, jobject buf,
982+
jboolean db, int *counts, int *displs, int size, MPI_Datatype *types)
983+
{
984+
if(buf == NULL)
985+
{
986+
/* Allow NULL buffers to send/recv 0 items as control messages. */
987+
*ptr = NULL;
988+
*item = NULL;
989+
}
990+
else if(db)
991+
{
992+
*ptr = (*env)->GetDirectBufferAddress(env, buf);
993+
*item = NULL;
994+
}
995+
else
996+
{
997+
*ptr = getWritePtrw(item, env, counts, displs, size, types);
998+
}
999+
}
1000+
8621001
void ompi_java_releaseWritePtr(
8631002
void *ptr, ompi_java_buffer_t *item, JNIEnv *env, jobject buf,
8641003
jboolean db, int offset, int count, MPI_Datatype type, int baseType)
@@ -933,6 +1072,49 @@ void ompi_java_releaseWritePtrv(
9331072
releaseBuffer(ptr, item);
9341073
}
9351074

1075+
void ompi_java_releaseWritePtrw(
1076+
void *ptr, ompi_java_buffer_t *item, JNIEnv *env,
1077+
jobject buf, jboolean db, int *offsets, int *counts, int *displs,
1078+
int size, MPI_Datatype *types, int *baseTypes)
1079+
{
1080+
if(db || !buf || !ptr)
1081+
return;
1082+
1083+
int i;
1084+
1085+
for(i = 0; i < size; i++)
1086+
{
1087+
int extent = getTypeExtent(env, types[i]);
1088+
1089+
if(opal_datatype_is_contiguous_memory_layout(&types[i]->super, 2))
1090+
{
1091+
int iOff = offsets[i] + displs[i],
1092+
iLen = extent * counts[i];
1093+
void *iPtr = (char*)ptr + displs[i];
1094+
setArrayRegion(env, buf, baseTypes[i], iOff, iLen, iPtr);
1095+
}
1096+
else
1097+
{
1098+
void *bufPtr, *bufBase;
1099+
1100+
bufPtr = ompi_java_getArrayCritical(&bufBase, env, buf, offsets[i]);
1101+
int iOff = displs[i];
1102+
char *iBuf = iOff + (char*)bufPtr,
1103+
*iPtr = iOff + (char*)ptr;
1104+
1105+
int rc = opal_datatype_copy_content_same_ddt(
1106+
&types[i]->super, counts[i], iBuf, iPtr);
1107+
1108+
ompi_java_exceptionCheck(env,
1109+
rc==OPAL_SUCCESS ? OMPI_SUCCESS : OMPI_ERROR);
1110+
1111+
(*env)->ReleasePrimitiveArrayCritical(env, buf, bufBase, 0);
1112+
}
1113+
1114+
}
1115+
releaseBuffer(ptr, item);
1116+
}
1117+
9361118
jobject ompi_java_Integer_valueOf(JNIEnv *env, jint i)
9371119
{
9381120
return (*env)->CallStaticObjectMethod(env,

0 commit comments

Comments
 (0)