2323#include " arrow/array/array_nested.h"
2424#include " arrow/array/array_primitive.h"
2525#include " arrow/json/rapidjson_defs.h" // IWYU pragma: keep
26+ #include " arrow/tensor.h"
2627#include " arrow/util/int_util_overflow.h"
2728#include " arrow/util/logging.h"
2829#include " arrow/util/sort.h"
3334namespace rj = arrow::rapidjson;
3435
3536namespace arrow {
37+
3638namespace extension {
3739
40+ namespace {
41+
// Computes byte strides for a tensor of the given fixed-width `type` and
// `shape`.
//
// If `permutation` is empty, strides are plain row-major (delegated to the
// shared helper). Otherwise `permutation` describes the physical storage
// order of the logical dimensions, and the returned `strides` (indexed by
// logical dimension) are consistent with that order.
//
// Returns Status::Invalid if the stride product would overflow int64_t.
Status ComputeStrides(const FixedWidthType& type, const std::vector<int64_t>& shape,
                      const std::vector<int64_t>& permutation,
                      std::vector<int64_t>* strides) {
  if (permutation.empty()) {
    // No permutation: ordinary row-major strides.
    return internal::ComputeRowMajorStrides(type, shape, strides);
  }

  const int byte_width = type.byte_width();

  // `remaining` accumulates byte_width times the product of all dimension
  // extents except the one with logical index 0 — i.e. the largest stride.
  // NOTE(review): the `i > 0` skips assume logical dimension 0 is the
  // outermost stored dimension — confirm against how callers build
  // `permutation` (see FromTensor, which guarantees a leading 0 before
  // erasing/decrementing).
  int64_t remaining = 0;
  if (!shape.empty() && shape.front() > 0) {
    remaining = byte_width;
    for (auto i : permutation) {
      if (i > 0) {
        if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) {
          return Status::Invalid(
              "Strides computed from shape would not fit in 64-bit integer");
        }
      }
    }
  }

  if (remaining == 0) {
    // Degenerate tensor: empty shape, or some extent is zero (shape.front()
    // directly, or another dimension zeroing out the product above). Every
    // stride collapses to the element width.
    strides->assign(shape.size(), byte_width);
    return Status::OK();
  }

  // Walk the storage order from outermost to innermost, dividing out each
  // extent to get successively smaller strides, then permute the result
  // back into logical-dimension order.
  strides->push_back(remaining);
  for (auto i : permutation) {
    if (i > 0) {
      remaining /= shape[i];
      strides->push_back(remaining);
    }
  }
  internal::Permute(permutation, strides);

  return Status::OK();
}
80+
81+ } // namespace
82+
3883bool FixedShapeTensorType::ExtensionEquals (const ExtensionType& other) const {
3984 if (extension_name () != other.extension_name ()) {
4085 return false ;
@@ -140,6 +185,132 @@ std::shared_ptr<Array> FixedShapeTensorType::MakeArray(
140185 return std::make_shared<ExtensionArray>(data);
141186}
142187
188+ Result<std::shared_ptr<FixedShapeTensorArray>> FixedShapeTensorArray::FromTensor (
189+ const std::shared_ptr<Tensor>& tensor) {
190+ auto permutation = internal::ArgSort (tensor->strides (), std::greater<>());
191+ if (permutation[0 ] != 0 ) {
192+ return Status::Invalid (
193+ " Only first-major tensors can be zero-copy converted to arrays" );
194+ }
195+ permutation.erase (permutation.begin ());
196+
197+ std::vector<int64_t > cell_shape;
198+ for (auto i : permutation) {
199+ cell_shape.emplace_back (tensor->shape ()[i]);
200+ }
201+
202+ std::vector<std::string> dim_names;
203+ if (!tensor->dim_names ().empty ()) {
204+ for (auto i : permutation) {
205+ dim_names.emplace_back (tensor->dim_names ()[i]);
206+ }
207+ }
208+
209+ for (int64_t & i : permutation) {
210+ --i;
211+ }
212+
213+ auto ext_type = internal::checked_pointer_cast<ExtensionType>(
214+ fixed_shape_tensor (tensor->type (), cell_shape, permutation, dim_names));
215+
216+ std::shared_ptr<Array> value_array;
217+ switch (tensor->type_id ()) {
218+ case Type::UINT8: {
219+ value_array = std::make_shared<UInt8Array>(tensor->size (), tensor->data ());
220+ break ;
221+ }
222+ case Type::INT8: {
223+ value_array = std::make_shared<Int8Array>(tensor->size (), tensor->data ());
224+ break ;
225+ }
226+ case Type::UINT16: {
227+ value_array = std::make_shared<UInt16Array>(tensor->size (), tensor->data ());
228+ break ;
229+ }
230+ case Type::INT16: {
231+ value_array = std::make_shared<Int16Array>(tensor->size (), tensor->data ());
232+ break ;
233+ }
234+ case Type::UINT32: {
235+ value_array = std::make_shared<UInt32Array>(tensor->size (), tensor->data ());
236+ break ;
237+ }
238+ case Type::INT32: {
239+ value_array = std::make_shared<Int32Array>(tensor->size (), tensor->data ());
240+ break ;
241+ }
242+ case Type::UINT64: {
243+ value_array = std::make_shared<Int64Array>(tensor->size (), tensor->data ());
244+ break ;
245+ }
246+ case Type::INT64: {
247+ value_array = std::make_shared<Int64Array>(tensor->size (), tensor->data ());
248+ break ;
249+ }
250+ case Type::HALF_FLOAT: {
251+ value_array = std::make_shared<HalfFloatArray>(tensor->size (), tensor->data ());
252+ break ;
253+ }
254+ case Type::FLOAT: {
255+ value_array = std::make_shared<FloatArray>(tensor->size (), tensor->data ());
256+ break ;
257+ }
258+ case Type::DOUBLE: {
259+ value_array = std::make_shared<DoubleArray>(tensor->size (), tensor->data ());
260+ break ;
261+ }
262+ default : {
263+ return Status::NotImplemented (" Unsupported tensor type: " ,
264+ tensor->type ()->ToString ());
265+ }
266+ }
267+ auto cell_size = static_cast <int32_t >(tensor->size () / tensor->shape ()[0 ]);
268+ ARROW_ASSIGN_OR_RAISE (std::shared_ptr<Array> arr,
269+ FixedSizeListArray::FromArrays (value_array, cell_size));
270+ std::shared_ptr<Array> ext_arr = ExtensionType::WrapArray (ext_type, arr);
271+ return std::reinterpret_pointer_cast<FixedShapeTensorArray>(ext_arr);
272+ }
273+
// Converts this array of n-dimensional tensors into a single (n+1)-dimensional
// Tensor, using the array's length as the new outermost dimension.
//
// Returns Status::Invalid if the element type is not fixed-width.
const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
  // To convert an array of n dimensional tensors to a n+1 dimensional tensor we
  // interpret the array's length as the first dimension the new tensor.

  auto ext_arr = internal::checked_pointer_cast<FixedSizeListArray>(this->storage());
  auto ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(this->type());
  // Tensors can only hold fixed-width primitive values.
  ARROW_RETURN_IF(!is_fixed_width(*ext_arr->value_type()),
                  Status::Invalid(ext_arr->value_type()->ToString(),
                                  " is not valid data type for a tensor"));
  // Local copy of the type's permutation; mutated below to account for the
  // new leading dimension.
  auto permutation = ext_type->permutation();

  std::vector<std::string> dim_names;
  if (!ext_type->dim_names().empty()) {
    // Reorder the cell dim names by the (still un-shifted) permutation, then
    // prepend an unnamed entry for the new outermost dimension.
    for (auto i : permutation) {
      dim_names.emplace_back(ext_type->dim_names()[i]);
    }
    dim_names.insert(dim_names.begin(), 1, "");
  } else {
    dim_names = {};
  }

  // Build the permuted cell shape and simultaneously shift every permutation
  // entry up by one (making room for the prepended dimension 0).
  // NOTE(review): `shape` is built in permuted order while strides below are
  // computed against it with the same permutation applied again — confirm
  // this round-trips with FromTensor for non-trivial permutations.
  std::vector<int64_t> shape;
  for (int64_t& i : permutation) {
    shape.emplace_back(ext_type->shape()[i]);
    ++i;
  }
  // The array length becomes dimension 0; it is always outermost, so it is
  // position 0 of the permutation as well.
  shape.insert(shape.begin(), 1, this->length());
  permutation.insert(permutation.begin(), 1, 0);

  std::vector<int64_t> tensor_strides;
  auto value_type = internal::checked_pointer_cast<FixedWidthType>(ext_arr->value_type());
  ARROW_RETURN_NOT_OK(
      ComputeStrides(*value_type.get(), shape, permutation, &tensor_strides));
  // Flatten to reach the underlying values buffer (buffers[1] of the
  // flattened child data) and wrap it in a Tensor, zero-copy.
  ARROW_ASSIGN_OR_RAISE(auto buffers, ext_arr->Flatten());
  ARROW_ASSIGN_OR_RAISE(
      auto tensor, Tensor::Make(ext_arr->value_type(), buffers->data()->buffers[1], shape,
                                tensor_strides, dim_names));
  return tensor;
}
313+
143314Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make (
144315 const std::shared_ptr<DataType>& value_type, const std::vector<int64_t >& shape,
145316 const std::vector<int64_t >& permutation, const std::vector<std::string>& dim_names) {
@@ -157,6 +328,17 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
157328 shape, permutation, dim_names);
158329}
159330
331+ const std::vector<int64_t >& FixedShapeTensorType::strides () {
332+ if (strides_.empty ()) {
333+ auto value_type = internal::checked_pointer_cast<FixedWidthType>(this ->value_type_ );
334+ std::vector<int64_t > tensor_strides;
335+ ARROW_CHECK_OK (ComputeStrides (*value_type.get (), this ->shape (), this ->permutation (),
336+ &tensor_strides));
337+ strides_ = tensor_strides;
338+ }
339+ return strides_;
340+ }
341+
160342std::shared_ptr<DataType> fixed_shape_tensor (const std::shared_ptr<DataType>& value_type,
161343 const std::vector<int64_t >& shape,
162344 const std::vector<int64_t >& permutation,
0 commit comments