Skip to content

Commit aa7d6b0

Browse files
committed
python changes
1 parent b4cb6cd commit aa7d6b0

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

python/pyarrow/tests/test_extension_type.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import pytest
2727
try:
2828
import numpy as np
29-
from numpy.lib.stride_tricks import as_strided
3029
except ImportError:
3130
np = None
3231
import pyarrow as pa
@@ -1733,13 +1732,21 @@ def test_variable_shape_tensor_array_from_numpy(value_type):
17331732
pa.VariableShapeTensorArray.from_numpy_ndarray([arr.astype(np.int32()), arr])
17341733

17351734
flat_arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=value_type)
1735+
bw = value_type.itemsize
17361736

17371737
arr = flat_arr.reshape(1, 3, 4)
17381738
tensor_array_from_numpy = pa.VariableShapeTensorArray.from_numpy_ndarray([arr])
17391739
assert tensor_array_from_numpy.type.ndim == 3
17401740
assert tensor_array_from_numpy.type.permutation == [0, 1, 2]
17411741
assert tensor_array_from_numpy[0].to_tensor() == pa.Tensor.from_numpy(arr)
17421742

1743+
arr = as_strided(flat_arr, shape=(1, 2, 3, 2),
1744+
strides=(bw * 12, bw * 6, bw, bw * 3))
1745+
tensor_array_from_numpy = pa.VariableShapeTensorArray.from_numpy_ndarray([arr])
1746+
assert tensor_array_from_numpy.type.ndim == 4
1747+
assert tensor_array_from_numpy.type.permutation == [0, 1, 3, 2]
1748+
assert tensor_array_from_numpy[0].to_tensor() == pa.Tensor.from_numpy(arr)
1749+
17431750
arr = flat_arr.reshape(1, 2, 3, 2)
17441751
result = pa.VariableShapeTensorArray.from_numpy_ndarray([arr])
17451752
expected = np.array(

0 commit comments

Comments
 (0)