Skip to content

Commit 6d7235b

Browse files
authored
[Java] Exposing SessionOptions.SetDeterministicCompute (microsoft#18998)
### Description Exposes `SetDeterministicCompute` in Java, added to the C API by microsoft#18944. ### Motivation and Context Parity between C and Java APIs.
1 parent 02e00dc commit 6d7235b

File tree

3 files changed

+31
-0
lines changed

3 files changed

+31
-0
lines changed

java/src/main/java/ai/onnxruntime/OrtSession.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -942,6 +942,20 @@ public void setSymbolicDimensionValue(String dimensionName, long dimensionValue)
942942
OnnxRuntime.ortApiHandle, nativeHandle, dimensionName, dimensionValue);
943943
}
944944

945+
/**
946+
* Set whether to use deterministic compute.
947+
*
948+
* <p>Default is false. If set to true, this will enable deterministic compute for GPU kernels
949+
* where possible. Note that this most likely will have a performance cost.
950+
*
951+
* @param value Should the compute be deterministic?
952+
* @throws OrtException If there was an error in native code.
953+
*/
954+
public void setDeterministicCompute(boolean value) throws OrtException {
955+
checkClosed();
956+
setDeterministicCompute(OnnxRuntime.ortApiHandle, nativeHandle, value);
957+
}
958+
945959
/**
946960
* Disables the per session thread pools. Must be used in conjunction with an environment
947961
* containing global thread pools.
@@ -1327,6 +1341,9 @@ private native void registerCustomOpsUsingFunction(
13271341

13281342
private native void closeOptions(long apiHandle, long nativeHandle);
13291343

1344+
private native void setDeterministicCompute(
1345+
long apiHandle, long nativeHandle, boolean isDeterministic) throws OrtException;
1346+
13301347
private native void addFreeDimensionOverrideByName(
13311348
long apiHandle, long nativeHandle, String dimensionName, long dimensionValue)
13321349
throws OrtException;

java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,19 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setSes
259259
checkOrtStatus(jniEnv,api,api->SetSessionLogVerbosityLevel(options,logLevel));
260260
}
261261

262+
/*
263+
* Class: ai_onnxruntime_OrtSession_SessionOptions
264+
* Method: setDeterministicCompute
265+
* Signature: (JJZ)V
266+
*/
267+
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setDeterministicCompute
268+
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jboolean isDeterministic) {
269+
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
270+
const OrtApi* api = (const OrtApi*)apiHandle;
271+
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
272+
checkOrtStatus(jniEnv,api,api->SetDeterministicCompute(options, isDeterministic));
273+
}
274+
262275
/*
263276
* Class: ai_onnxruntime_OrtSession_SessionOptions
264277
* Method: registerCustomOpLibrary

java/src/test/java/ai/onnxruntime/InferenceTest.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,6 +1263,7 @@ public void testExtraSessionOptions() throws OrtException, IOException {
12631263
options.setLoggerId("monkeys");
12641264
options.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL);
12651265
options.setSessionLogVerbosityLevel(5);
1266+
options.setDeterministicCompute(true);
12661267
Map<String, String> configEntries = options.getConfigEntries();
12671268
assertTrue(configEntries.isEmpty());
12681269
options.addConfigEntry("key", "value");

0 commit comments

Comments
 (0)