Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions java/src/main/java/ai/onnxruntime/TensorInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -493,21 +493,22 @@ public static <T extends Buffer> TensorInfo constructFromSparseTensor(
* @throws OrtException If the array has a zero dimension, or is ragged.
*/
private static void extractShape(long[] shape, int curDim, Object obj) throws OrtException {
if (shape.length != curDim) {
int curLength = Array.getLength(obj);
if (curLength == 0) {
throw new OrtException(
"Supplied array has a zero dimension at "
+ curDim
+ ", all dimensions must be positive");
} else if (shape[curDim] == 0L) {
shape[curDim] = curLength;
} else if (shape[curDim] != curLength) {
throw new OrtException(
"Supplied array is ragged, expected " + shape[curDim] + ", found " + curLength);
}
int curLength = Array.getLength(obj);
if (curLength == 0) {
throw new OrtException(
"Supplied array has a zero dimension at " + curDim + ", all dimensions must be positive");
} else if (shape[curDim] == 0L) {
shape[curDim] = curLength;
} else if (shape[curDim] != curLength) {
throw new OrtException(
"Supplied array is ragged, expected " + shape[curDim] + ", found " + curLength);
}
int nextDim = curDim + 1;
// Avoid traversing the entire array (autoboxing its values) when the next dimension is equal
// to the shape's length
if (shape.length != nextDim) {
for (int i = 0; i < curLength; i++) {
extractShape(shape, curDim + 1, Array.get(obj, i));
extractShape(shape, nextDim, Array.get(obj, i));
}
}
}
Expand Down
63 changes: 63 additions & 0 deletions java/src/test/java/ai/onnxruntime/TensorInfoTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package ai.onnxruntime;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

public class TensorInfoTest {
@Test
public void testConstructFromJavaArray_UnexpectedType() {
Object obj = new Object();
Throwable t =
Assertions.assertThrows(OrtException.class, () -> TensorInfo.constructFromJavaArray(obj));
Assertions.assertEquals(
"Cannot convert class java.lang.Object to a OnnxTensor.", t.getMessage());
}

@Test
public void testConstructFromJavaArray_ScalarType() throws OrtException {
float obj = 1.0f;
TensorInfo tensorInfo = TensorInfo.constructFromJavaArray(obj);
Assertions.assertArrayEquals(new long[0], tensorInfo.shape);
Assertions.assertEquals(OnnxJavaType.FLOAT, tensorInfo.type);
Assertions.assertEquals(
TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensorInfo.onnxType);
}

@Test
public void testConstructFromJavaArray_1DArrayOfNonPrimitiveNorString() {
Object[] obj = new Object[] {new Object(), new Object()};
Throwable t =
Assertions.assertThrows(OrtException.class, () -> TensorInfo.constructFromJavaArray(obj));
Assertions.assertEquals(
"Cannot create an OnnxTensor from a base type of class java.lang.Object", t.getMessage());
}

@Test
public void testConstructFromJavaArray_NineDimensions() {
float[][][][][][][][][] obj = new float[1][1][1][1][1][1][1][1][1];
Throwable t =
Assertions.assertThrows(OrtException.class, () -> TensorInfo.constructFromJavaArray(obj));
Assertions.assertEquals(
"Cannot create an OnnxTensor with more than 8 dimensions. Found 9 dimensions.",
t.getMessage());
}

@Test
public void testConstructFromJavaArray_RaggedArray() {
float[][] obj = new float[][] {new float[1], new float[2]};
Throwable t =
Assertions.assertThrows(OrtException.class, () -> TensorInfo.constructFromJavaArray(obj));
Assertions.assertEquals("Supplied array is ragged, expected 1, found 2", t.getMessage());
}

@Test
public void testConstructFromJavaArray_ExtractRecursive() throws OrtException {
float[][][] obj = new float[3][2][3];
TensorInfo tensorInfo = TensorInfo.constructFromJavaArray(obj);

Assertions.assertArrayEquals(new long[] {3, 2, 3}, tensorInfo.shape);
Assertions.assertEquals(OnnxJavaType.FLOAT, tensorInfo.type);
Assertions.assertEquals(
TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensorInfo.onnxType);
}
}