Skip to content

Commit d5a548d

Browse files
committed
GH-38007: [Python] Add VariableShapeTensor Python bindings
Add PyArrow bindings for the VariableShapeTensor extension type, including VariableShapeTensorType, VariableShapeTensorArray, and VariableShapeTensorScalar with support for converting to/from NumPy tensors.
1 parent 0124d5b commit d5a548d

File tree

12 files changed

+1551
-20
lines changed

12 files changed

+1551
-20
lines changed

cpp/src/arrow/extension/tensor_extension_array_test.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,29 @@ TEST_F(TestVariableShapeTensorType, ComputeStrides) {
897897
ASSERT_TRUE(tensor->Equals(*t));
898898
}
899899

900+
TEST_F(TestVariableShapeTensorType, ComputeStridesWithNonTrivialPermutation) {
901+
auto permuted_ext_type = internal::checked_pointer_cast<VariableShapeTensorType>(
902+
variable_shape_tensor(value_type_, ndim_, {1, 0, 2}, dim_names_, uniform_shape_));
903+
904+
auto shape = ArrayFromJSON(shape_type_, "[[2,3,1]]");
905+
auto data = ArrayFromJSON(data_type_, "[[1,1,2,3,4,5]]");
906+
std::vector<std::shared_ptr<Field>> fields = {field("data", data_type_),
907+
field("shape", shape_type_)};
908+
ASSERT_OK_AND_ASSIGN(auto storage_arr, StructArray::Make({data, shape}, fields));
909+
auto ext_arr = ExtensionType::WrapArray(permuted_ext_type, storage_arr);
910+
auto ext_array = std::static_pointer_cast<VariableShapeTensorArray>(ext_arr);
911+
912+
ASSERT_OK_AND_ASSIGN(auto scalar, ext_array->GetScalar(0));
913+
ASSERT_OK_AND_ASSIGN(auto tensor,
914+
permuted_ext_type->MakeTensor(
915+
internal::checked_pointer_cast<ExtensionScalar>(scalar)));
916+
917+
ASSERT_EQ(tensor->shape(), (std::vector<int64_t>{3, 2, 1}));
918+
ASSERT_EQ(tensor->strides(), (std::vector<int64_t>{sizeof(int64_t), sizeof(int64_t) * 2,
919+
sizeof(int64_t)}));
920+
ASSERT_EQ(tensor->dim_names(), (std::vector<std::string>{"y", "x", "z"}));
921+
}
922+
900923
TEST_F(TestVariableShapeTensorType, ToString) {
901924
auto exact_ext_type =
902925
internal::checked_pointer_cast<VariableShapeTensorType>(ext_type_);

cpp/src/arrow/extension/variable_shape_tensor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,9 @@ Result<std::shared_ptr<Tensor>> VariableShapeTensorType::MakeTensor(
262262
internal::Permute<std::string>(permutation, &dim_names);
263263
}
264264

265+
internal::Permute<int64_t>(permutation, &shape);
265266
ARROW_ASSIGN_OR_RAISE(
266267
auto strides, internal::ComputeStrides(ext_type.value_type(), shape, permutation));
267-
internal::Permute<int64_t>(permutation, &shape);
268268

269269
ARROW_ASSIGN_OR_RAISE(const auto buffer,
270270
internal::SliceTensorBuffer(*data_array, value_type, shape));

docs/source/python/api/arrays.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ may expose data type-specific methods or properties.
101101
JsonArray
102102
UuidArray
103103
Bool8Array
104+
VariableShapeTensorArray
104105

105106
.. _api.scalar:
106107

@@ -165,6 +166,7 @@ classes may expose data type-specific methods or properties.
165166
UnionScalar
166167
ExtensionScalar
167168
FixedShapeTensorScalar
169+
VariableShapeTensorScalar
168170
OpaqueScalar
169171
JsonScalar
170172
UuidScalar

docs/source/python/api/datatypes.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ These should be used to create Arrow data types and schemas.
6868
dictionary
6969
run_end_encoded
7070
fixed_shape_tensor
71+
variable_shape_tensor
7172
union
7273
dense_union
7374
sparse_union
@@ -142,6 +143,7 @@ implemented by PyArrow.
142143
:toctree: ../generated/
143144

144145
FixedShapeTensorType
146+
VariableShapeTensorType
145147
OpaqueType
146148
JsonType
147149
UuidType

python/pyarrow/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def print_entry(label, value):
165165
dictionary,
166166
run_end_encoded,
167167
bool8, fixed_shape_tensor, json_, opaque, uuid,
168+
variable_shape_tensor,
168169
field,
169170
type_for_alias,
170171
DataType, DictionaryType, StructType,
@@ -178,6 +179,7 @@ def print_entry(label, value):
178179
RunEndEncodedType, Bool8Type, FixedShapeTensorType,
179180
JsonType, OpaqueType, UuidType,
180181
UnknownExtensionType,
182+
VariableShapeTensorType,
181183
register_extension_type, unregister_extension_type,
182184
DictionaryMemo,
183185
KeyValueMetadata,
@@ -214,6 +216,7 @@ def print_entry(label, value):
214216
StructArray, ExtensionArray,
215217
RunEndEncodedArray, Bool8Array, FixedShapeTensorArray,
216218
JsonArray, OpaqueArray, UuidArray,
219+
VariableShapeTensorArray,
217220
scalar, NA, _NULL as NULL, Scalar,
218221
NullScalar, BooleanScalar,
219222
Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar,
@@ -231,7 +234,8 @@ def print_entry(label, value):
231234
FixedSizeBinaryScalar, DictionaryScalar,
232235
MapScalar, StructScalar, UnionScalar,
233236
RunEndEncodedScalar, Bool8Scalar, ExtensionScalar,
234-
FixedShapeTensorScalar, JsonScalar, OpaqueScalar, UuidScalar)
237+
FixedShapeTensorScalar, JsonScalar, OpaqueScalar, UuidScalar,
238+
VariableShapeTensorScalar)
235239

236240
# Buffers, allocation
237241
from pyarrow.lib import (DeviceAllocationType, Device, MemoryManager,

0 commit comments

Comments
 (0)