Skip to content

Commit 02e00dc

Browse files
authored
[java] Adding ability to load a model from a memory mapped byte buffer (microsoft#20062)
### Description Adds support for constructing an `OrtSession` from a `java.nio.ByteBuffer`. These buffers can be memory-mapped from files, which means no copies of the model protobuf need to be held in Java, reducing peak memory usage during session construction. ### Motivation and Context Reduces memory usage during session construction by requiring fewer copies of the model on the Java side. Should help with microsoft#19599.
1 parent c63dd02 commit 02e00dc

File tree

4 files changed

+138
-2
lines changed

4 files changed

+138
-2
lines changed

java/src/main/java/ai/onnxruntime/OrtEnvironment.java

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
/*
2-
* Copyright (c) 2019, 2023 Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved.
33
* Licensed under the MIT License.
44
*/
55
package ai.onnxruntime;
66

77
import ai.onnxruntime.OrtSession.SessionOptions;
88
import ai.onnxruntime.OrtTrainingSession.OrtCheckpointState;
99
import java.io.IOException;
10+
import java.nio.ByteBuffer;
1011
import java.util.EnumSet;
1112
import java.util.Objects;
1213
import java.util.logging.Logger;
@@ -236,6 +237,52 @@ OrtSession createSession(String modelPath, OrtAllocator allocator, SessionOption
236237
return new OrtSession(this, modelPath, allocator, options);
237238
}
238239

240+
/**
241+
* Create a session using the specified {@link SessionOptions}, model and the default memory
242+
* allocator.
243+
*
244+
* @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer.
245+
* @param options The session options.
246+
* @return An {@link OrtSession} with the specified model.
247+
* @throws OrtException If the model failed to parse, wasn't compatible or caused an error.
248+
*/
249+
public OrtSession createSession(ByteBuffer modelBuffer, SessionOptions options)
250+
throws OrtException {
251+
return createSession(modelBuffer, defaultAllocator, options);
252+
}
253+
254+
/**
255+
* Create a session using the default {@link SessionOptions}, model and the default memory
256+
* allocator.
257+
*
258+
* @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer.
259+
* @return An {@link OrtSession} with the specified model.
260+
* @throws OrtException If the model failed to parse, wasn't compatible or caused an error.
261+
*/
262+
public OrtSession createSession(ByteBuffer modelBuffer) throws OrtException {
263+
return createSession(modelBuffer, new OrtSession.SessionOptions());
264+
}
265+
266+
/**
267+
* Create a session using the specified {@link SessionOptions} and model buffer.
268+
*
269+
* @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer.
270+
* @param allocator The memory allocator to use.
271+
* @param options The session options.
272+
* @return An {@link OrtSession} with the specified model.
273+
* @throws OrtException If the model failed to parse, wasn't compatible or caused an error.
274+
*/
275+
OrtSession createSession(ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options)
276+
throws OrtException {
277+
Objects.requireNonNull(modelBuffer, "model array must not be null");
278+
if (modelBuffer.remaining() == 0) {
279+
throw new OrtException("Invalid model buffer, no elements remaining.");
280+
} else if (!modelBuffer.isDirect()) {
281+
throw new OrtException("ByteBuffer is not direct.");
282+
}
283+
return new OrtSession(this, modelBuffer, allocator, options);
284+
}
285+
239286
/**
240287
* Create a session using the specified {@link SessionOptions}, model and the default memory
241288
* allocator.

java/src/main/java/ai/onnxruntime/OrtSession.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import ai.onnxruntime.providers.OrtFlags;
1212
import ai.onnxruntime.providers.OrtTensorRTProviderOptions;
1313
import java.io.IOException;
14+
import java.nio.ByteBuffer;
1415
import java.util.ArrayList;
1516
import java.util.Arrays;
1617
import java.util.Collections;
@@ -94,6 +95,31 @@ public class OrtSession implements AutoCloseable {
9495
allocator);
9596
}
9697

98+
/**
 * Builds a session whose model is read from the supplied byte buffer.
 *
 * <p>The buffer must be direct. The native session is created from the buffer together with its
 * current position and its number of remaining bytes.
 *
 * @param env The environment.
 * @param modelBuffer The model protobuf as a byte buffer.
 * @param allocator The allocator to use.
 * @param options Session configuration options.
 * @throws OrtException If the model was corrupted or some other error occurred in native code.
 */
OrtSession(
    OrtEnvironment env, ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options)
    throws OrtException {
  this(
      createSession(
          OnnxRuntime.ortApiHandle,
          env.getNativeHandle(),
          modelBuffer,
          modelBuffer.position(), // offset of the first model byte
          modelBuffer.remaining(), // number of model bytes
          options.getNativeHandle()),
      allocator);
}
122+
97123
/**
98124
* Private constructor to build the Java object wrapped around a native session.
99125
*
@@ -514,6 +540,15 @@ private static native long createSession(
514540
private static native long createSession(
515541
long apiHandle, long envHandle, byte[] modelArray, long optsHandle) throws OrtException;
516542

543+
private static native long createSession(
544+
long apiHandle,
545+
long envHandle,
546+
ByteBuffer modelBuffer,
547+
int bufferPos,
548+
int bufferSize,
549+
long optsHandle)
550+
throws OrtException;
551+
517552
private native long getNumInputs(long apiHandle, long nativeHandle) throws OrtException;
518553

519554
private native String[] getInputNames(long apiHandle, long nativeHandle, long allocatorHandle)

java/src/main/native/ai_onnxruntime_OrtSession.c

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2019, 2020, 2022 Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved.
33
* Licensed under the MIT License.
44
*/
55
#include <jni.h>
@@ -48,6 +48,29 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_la
4848
return (jlong)session;
4949
}
5050

51+
/*
52+
* Class: ai_onnxruntime_OrtSession
53+
* Method: createSession
54+
* Signature: (JJLjava/nio/ByteBuffer;IIJ)J
55+
*/
56+
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_nio_ByteBuffer_2IIJ(JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jobject buffer, jint bufferPos, jint bufferSize, jlong optsHandle) {
57+
(void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
58+
const OrtApi* api = (const OrtApi*)apiHandle;
59+
OrtEnv* env = (OrtEnv*)envHandle;
60+
OrtSessionOptions* opts = (OrtSessionOptions*)optsHandle;
61+
OrtSession* session = NULL;
62+
63+
// Extract the buffer
64+
char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer);
65+
// Increment by bufferPos bytes
66+
bufferArr = bufferArr + bufferPos;
67+
68+
// Create the session
69+
checkOrtStatus(jniEnv, api, api->CreateSessionFromArray(env, bufferArr, bufferSize, opts, &session));
70+
71+
return (jlong)session;
72+
}
73+
5174
/*
5275
* Class: ai_onnxruntime_OrtSession
5376
* Method: createSession

java/src/test/java/ai/onnxruntime/InferenceTest.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,14 @@
2020
import ai.onnxruntime.OrtSession.SessionOptions.OptLevel;
2121
import java.io.File;
2222
import java.io.IOException;
23+
import java.io.RandomAccessFile;
2324
import java.nio.ByteBuffer;
2425
import java.nio.ByteOrder;
2526
import java.nio.FloatBuffer;
2627
import java.nio.LongBuffer;
28+
import java.nio.MappedByteBuffer;
29+
import java.nio.channels.FileChannel;
30+
import java.nio.channels.FileChannel.MapMode;
2731
import java.nio.file.Files;
2832
import java.nio.file.Path;
2933
import java.nio.file.Paths;
@@ -338,6 +342,33 @@ public void partialInputsTest() throws OrtException {
338342
}
339343
}
340344

345+
@Test
346+
public void createSessionFromByteBuffer() throws IOException, OrtException {
347+
Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx");
348+
try (RandomAccessFile file = new RandomAccessFile(modelPath.toFile(), "r");
349+
FileChannel channel = file.getChannel()) {
350+
MappedByteBuffer modelBuffer = channel.map(MapMode.READ_ONLY, 0, channel.size());
351+
try (OrtSession.SessionOptions options = new SessionOptions();
352+
OrtSession session = env.createSession(modelBuffer, options)) {
353+
assertNotNull(session);
354+
assertEquals(1, session.getNumInputs()); // 1 input node
355+
Map<String, NodeInfo> inputInfoList = session.getInputInfo();
356+
assertNotNull(inputInfoList);
357+
assertEquals(1, inputInfoList.size());
358+
NodeInfo input = inputInfoList.get("data_0");
359+
assertEquals("data_0", input.getName()); // input node name
360+
assertTrue(input.getInfo() instanceof TensorInfo);
361+
TensorInfo inputInfo = (TensorInfo) input.getInfo();
362+
assertEquals(OnnxJavaType.FLOAT, inputInfo.type);
363+
int[] expectedInputDimensions = new int[] {1, 3, 224, 224};
364+
assertEquals(expectedInputDimensions.length, inputInfo.shape.length);
365+
for (int i = 0; i < expectedInputDimensions.length; i++) {
366+
assertEquals(expectedInputDimensions[i], inputInfo.shape[i]);
367+
}
368+
}
369+
}
370+
}
371+
341372
@Test
342373
public void createSessionFromByteArray() throws IOException, OrtException {
343374
Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx");

0 commit comments

Comments
 (0)