Skip to content

Commit 22437b5

Browse files
authored
[java] Fix for OnnxTensor creation when passing in a ByteBuffer containing elements of a different type (microsoft#21774)
### Description Fixes a bug where the buffer offset and position was incorrectly computed if the user supplied a `ByteBuffer` to `createTensor` but set the type of the tensor to something other than `INT8`. This would be more common if the user was trying to load the initializers from a serialized representation and didn't want to bother with the type information (which is the case in microsoft#21321). ### Motivation and Context Partial fix for microsoft#21321. The remainder of the fix is to add a helper which allows users to load initializers out of an `onnx_data` file, but that will require adding protobuf as a dependency for the Java API to allow the parsing of an ONNX file separately from the native code. It might be nicer to put that functionality into ORT's C API so it can return the lengths & offsets of the initializers when provided with an ONNX file containing external initializers. We hit this kind of thing in Java more often than other languages as in Java models can be supplied as classpath resources which we can easily read, but not materialize on disk for the ORT native library to read.
1 parent f7bf5a1 commit 22437b5

File tree

2 files changed

+33
-6
lines changed

2 files changed

+33
-6
lines changed

java/src/main/java/ai/onnxruntime/OrtUtil.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved.
33
* Copyright (c) Microsoft Corporation. All rights reserved.
44
* Licensed under the MIT License.
55
*/
@@ -483,9 +483,12 @@ static BufferTuple prepareBuffer(Buffer data, OnnxJavaType type) {
483483
if (type == OnnxJavaType.STRING || type == OnnxJavaType.UNKNOWN) {
484484
throw new IllegalStateException("Cannot create a " + type + " tensor from a buffer");
485485
}
486+
// This buffer could be a ByteBuffer which is being used to carry data of another type, if so,
487+
// it's type.size should be 1 to compute the correct buffer size and offset.
488+
int elementSize = data instanceof ByteBuffer ? 1 : type.size;
486489
int bufferPos;
487-
long bufferSizeLong = data.remaining() * (long) type.size;
488-
if (bufferSizeLong > (Integer.MAX_VALUE - (8 * type.size))) {
490+
long bufferSizeLong = data.remaining() * (long) elementSize;
491+
if (bufferSizeLong > (Integer.MAX_VALUE - (8L * elementSize))) {
489492
// The maximum direct byte buffer size is a little below Integer.MAX_VALUE depending
490493
// on the JVM, so we check for something 8 elements below the maximum size which
491494
// should be allocatable (assuming there is enough memory) on all 64-bit JVMs.
@@ -496,11 +499,11 @@ static BufferTuple prepareBuffer(Buffer data, OnnxJavaType type) {
496499
+ type);
497500
}
498501
// Now we know we're in range
499-
int bufferSize = data.remaining() * type.size;
502+
int bufferSize = data.remaining() * elementSize;
500503
Buffer tmp;
501504
if (data.isDirect()) {
502505
tmp = data;
503-
bufferPos = data.position() * type.size;
506+
bufferPos = data.position() * elementSize;
504507
} else {
505508
// Copy the data to a new direct buffer, then restore the state of the input.
506509
int origPosition = data.position();

java/src/test/java/ai/onnxruntime/OnnxTensorTest.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved.
33
* Licensed under the MIT License.
44
*/
55
package ai.onnxruntime;
@@ -218,6 +218,30 @@ public void testUint8Creation() throws OrtException {
218218
}
219219
}
220220

221+
@Test
222+
public void testByteBufferCreation() throws OrtException {
223+
OrtEnvironment env = OrtEnvironment.getEnvironment();
224+
ByteBuffer byteBuf = ByteBuffer.allocateDirect(Float.BYTES * 5).order(ByteOrder.nativeOrder());
225+
FloatBuffer floatBuf = byteBuf.asFloatBuffer();
226+
floatBuf.put(1.0f);
227+
floatBuf.put(2.0f);
228+
floatBuf.put(3.0f);
229+
floatBuf.put(4.0f);
230+
floatBuf.put(5.0f);
231+
floatBuf.position(1);
232+
float[] expected = new float[floatBuf.remaining()];
233+
floatBuf.get(expected);
234+
floatBuf.position(1);
235+
byteBuf.position(4);
236+
try (OnnxTensor t =
237+
OnnxTensor.createTensor(
238+
env, byteBuf, new long[] {floatBuf.remaining()}, OnnxJavaType.FLOAT)) {
239+
Assertions.assertNotNull(t);
240+
float[] actual = (float[]) t.getValue();
241+
Assertions.assertArrayEquals(expected, actual);
242+
}
243+
}
244+
221245
@Test
222246
public void testEmptyTensor() throws OrtException {
223247
OrtEnvironment env = OrtEnvironment.getEnvironment();

0 commit comments

Comments
 (0)