Skip to content

Commit aff876a

Browse files
authored
GH-34796: [C++] Add FromTensor, ToTensor and strides methods to FixedShapeTensorArray (#34797)
### Rationale for this change We want to enable converting Tensors to FixedShapeTensorArrays and the other way around. ### What changes are included in this PR? This adds FromTensor, ToTensor to FixedShapeTensorArrays and strides method to FixedShapeTensorType. ### Are these changes tested? Yes. ### Are there any user-facing changes? This adds FromTensor, ToTensor and strides are user facing methods. * Closes: #34796 Authored-by: Rok Mihevc <rok@mihevc.org> Signed-off-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
1 parent c40e658 commit aff876a

File tree

3 files changed

+430
-0
lines changed

3 files changed

+430
-0
lines changed

cpp/src/arrow/extension/fixed_shape_tensor.cc

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "arrow/array/array_nested.h"
2424
#include "arrow/array/array_primitive.h"
2525
#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep
26+
#include "arrow/tensor.h"
2627
#include "arrow/util/int_util_overflow.h"
2728
#include "arrow/util/logging.h"
2829
#include "arrow/util/sort.h"
@@ -33,8 +34,52 @@
3334
namespace rj = arrow::rapidjson;
3435

3536
namespace arrow {
37+
3638
namespace extension {
3739

40+
namespace {
41+
42+
Status ComputeStrides(const FixedWidthType& type, const std::vector<int64_t>& shape,
43+
const std::vector<int64_t>& permutation,
44+
std::vector<int64_t>* strides) {
45+
if (permutation.empty()) {
46+
return internal::ComputeRowMajorStrides(type, shape, strides);
47+
}
48+
49+
const int byte_width = type.byte_width();
50+
51+
int64_t remaining = 0;
52+
if (!shape.empty() && shape.front() > 0) {
53+
remaining = byte_width;
54+
for (auto i : permutation) {
55+
if (i > 0) {
56+
if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) {
57+
return Status::Invalid(
58+
"Strides computed from shape would not fit in 64-bit integer");
59+
}
60+
}
61+
}
62+
}
63+
64+
if (remaining == 0) {
65+
strides->assign(shape.size(), byte_width);
66+
return Status::OK();
67+
}
68+
69+
strides->push_back(remaining);
70+
for (auto i : permutation) {
71+
if (i > 0) {
72+
remaining /= shape[i];
73+
strides->push_back(remaining);
74+
}
75+
}
76+
internal::Permute(permutation, strides);
77+
78+
return Status::OK();
79+
}
80+
81+
} // namespace
82+
3883
bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const {
3984
if (extension_name() != other.extension_name()) {
4085
return false;
@@ -140,6 +185,132 @@ std::shared_ptr<Array> FixedShapeTensorType::MakeArray(
140185
return std::make_shared<ExtensionArray>(data);
141186
}
142187

188+
Result<std::shared_ptr<FixedShapeTensorArray>> FixedShapeTensorArray::FromTensor(
189+
const std::shared_ptr<Tensor>& tensor) {
190+
auto permutation = internal::ArgSort(tensor->strides(), std::greater<>());
191+
if (permutation[0] != 0) {
192+
return Status::Invalid(
193+
"Only first-major tensors can be zero-copy converted to arrays");
194+
}
195+
permutation.erase(permutation.begin());
196+
197+
std::vector<int64_t> cell_shape;
198+
for (auto i : permutation) {
199+
cell_shape.emplace_back(tensor->shape()[i]);
200+
}
201+
202+
std::vector<std::string> dim_names;
203+
if (!tensor->dim_names().empty()) {
204+
for (auto i : permutation) {
205+
dim_names.emplace_back(tensor->dim_names()[i]);
206+
}
207+
}
208+
209+
for (int64_t& i : permutation) {
210+
--i;
211+
}
212+
213+
auto ext_type = internal::checked_pointer_cast<ExtensionType>(
214+
fixed_shape_tensor(tensor->type(), cell_shape, permutation, dim_names));
215+
216+
std::shared_ptr<Array> value_array;
217+
switch (tensor->type_id()) {
218+
case Type::UINT8: {
219+
value_array = std::make_shared<UInt8Array>(tensor->size(), tensor->data());
220+
break;
221+
}
222+
case Type::INT8: {
223+
value_array = std::make_shared<Int8Array>(tensor->size(), tensor->data());
224+
break;
225+
}
226+
case Type::UINT16: {
227+
value_array = std::make_shared<UInt16Array>(tensor->size(), tensor->data());
228+
break;
229+
}
230+
case Type::INT16: {
231+
value_array = std::make_shared<Int16Array>(tensor->size(), tensor->data());
232+
break;
233+
}
234+
case Type::UINT32: {
235+
value_array = std::make_shared<UInt32Array>(tensor->size(), tensor->data());
236+
break;
237+
}
238+
case Type::INT32: {
239+
value_array = std::make_shared<Int32Array>(tensor->size(), tensor->data());
240+
break;
241+
}
242+
case Type::UINT64: {
243+
value_array = std::make_shared<Int64Array>(tensor->size(), tensor->data());
244+
break;
245+
}
246+
case Type::INT64: {
247+
value_array = std::make_shared<Int64Array>(tensor->size(), tensor->data());
248+
break;
249+
}
250+
case Type::HALF_FLOAT: {
251+
value_array = std::make_shared<HalfFloatArray>(tensor->size(), tensor->data());
252+
break;
253+
}
254+
case Type::FLOAT: {
255+
value_array = std::make_shared<FloatArray>(tensor->size(), tensor->data());
256+
break;
257+
}
258+
case Type::DOUBLE: {
259+
value_array = std::make_shared<DoubleArray>(tensor->size(), tensor->data());
260+
break;
261+
}
262+
default: {
263+
return Status::NotImplemented("Unsupported tensor type: ",
264+
tensor->type()->ToString());
265+
}
266+
}
267+
auto cell_size = static_cast<int32_t>(tensor->size() / tensor->shape()[0]);
268+
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> arr,
269+
FixedSizeListArray::FromArrays(value_array, cell_size));
270+
std::shared_ptr<Array> ext_arr = ExtensionType::WrapArray(ext_type, arr);
271+
return std::reinterpret_pointer_cast<FixedShapeTensorArray>(ext_arr);
272+
}
273+
274+
const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
275+
// To convert an array of n dimensional tensors to a n+1 dimensional tensor we
276+
// interpret the array's length as the first dimension the new tensor.
277+
278+
auto ext_arr = internal::checked_pointer_cast<FixedSizeListArray>(this->storage());
279+
auto ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(this->type());
280+
ARROW_RETURN_IF(!is_fixed_width(*ext_arr->value_type()),
281+
Status::Invalid(ext_arr->value_type()->ToString(),
282+
" is not valid data type for a tensor"));
283+
auto permutation = ext_type->permutation();
284+
285+
std::vector<std::string> dim_names;
286+
if (!ext_type->dim_names().empty()) {
287+
for (auto i : permutation) {
288+
dim_names.emplace_back(ext_type->dim_names()[i]);
289+
}
290+
dim_names.insert(dim_names.begin(), 1, "");
291+
} else {
292+
dim_names = {};
293+
}
294+
295+
std::vector<int64_t> shape;
296+
for (int64_t& i : permutation) {
297+
shape.emplace_back(ext_type->shape()[i]);
298+
++i;
299+
}
300+
shape.insert(shape.begin(), 1, this->length());
301+
permutation.insert(permutation.begin(), 1, 0);
302+
303+
std::vector<int64_t> tensor_strides;
304+
auto value_type = internal::checked_pointer_cast<FixedWidthType>(ext_arr->value_type());
305+
ARROW_RETURN_NOT_OK(
306+
ComputeStrides(*value_type.get(), shape, permutation, &tensor_strides));
307+
ARROW_ASSIGN_OR_RAISE(auto buffers, ext_arr->Flatten());
308+
ARROW_ASSIGN_OR_RAISE(
309+
auto tensor, Tensor::Make(ext_arr->value_type(), buffers->data()->buffers[1], shape,
310+
tensor_strides, dim_names));
311+
return tensor;
312+
}
313+
143314
Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
144315
const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape,
145316
const std::vector<int64_t>& permutation, const std::vector<std::string>& dim_names) {
@@ -157,6 +328,17 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
157328
shape, permutation, dim_names);
158329
}
159330

331+
const std::vector<int64_t>& FixedShapeTensorType::strides() {
332+
if (strides_.empty()) {
333+
auto value_type = internal::checked_pointer_cast<FixedWidthType>(this->value_type_);
334+
std::vector<int64_t> tensor_strides;
335+
ARROW_CHECK_OK(ComputeStrides(*value_type.get(), this->shape(), this->permutation(),
336+
&tensor_strides));
337+
strides_ = tensor_strides;
338+
}
339+
return strides_;
340+
}
341+
160342
std::shared_ptr<DataType> fixed_shape_tensor(const std::shared_ptr<DataType>& value_type,
161343
const std::vector<int64_t>& shape,
162344
const std::vector<int64_t>& permutation,

cpp/src/arrow/extension/fixed_shape_tensor.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,26 @@ namespace extension {
2323
class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray {
2424
public:
2525
using ExtensionArray::ExtensionArray;
26+
27+
/// \brief Create a FixedShapeTensorArray from a Tensor
28+
///
29+
/// This method will create a FixedShapeTensorArray from a Tensor, taking its first
30+
/// dimension as the number of elements in the resulting array and the remaining
31+
/// dimensions as the shape of the individual tensors. If Tensor provides strides,
32+
/// they will be used to determine dimension permutation. Otherwise, row-major layout
33+
/// (i.e. no permutation) will be assumed.
34+
///
35+
/// \param[in] tensor The Tensor to convert to a FixedShapeTensorArray
36+
static Result<std::shared_ptr<FixedShapeTensorArray>> FromTensor(
37+
const std::shared_ptr<Tensor>& tensor);
38+
39+
/// \brief Create a Tensor from FixedShapeTensorArray
40+
///
41+
/// This method will create a Tensor from a FixedShapeTensorArray, setting its first
42+
/// dimension as length equal to the FixedShapeTensorArray's length and the remaining
43+
/// dimensions as the FixedShapeTensorType's shape. Shape and dim_names will be
44+
/// permuted according to permutation stored in the FixedShapeTensorType metadata.
45+
const Result<std::shared_ptr<Tensor>> ToTensor() const;
2646
};
2747

2848
/// \brief Concrete type class for constant-size Tensor data.
@@ -51,6 +71,11 @@ class ARROW_EXPORT FixedShapeTensorType : public ExtensionType {
5171
/// Value type of tensor elements
5272
const std::shared_ptr<DataType> value_type() const { return value_type_; }
5373

74+
/// Strides of tensor elements. Strides state offset in bytes between adjacent
75+
/// elements along each dimension. In case permutation is non-empty strides are
76+
/// computed from permuted tensor element's shape.
77+
const std::vector<int64_t>& strides();
78+
5479
/// Permutation mapping from logical to physical memory layout of tensor elements
5580
const std::vector<int64_t>& permutation() const { return permutation_; }
5681

@@ -78,6 +103,7 @@ class ARROW_EXPORT FixedShapeTensorType : public ExtensionType {
78103
std::shared_ptr<DataType> storage_type_;
79104
std::shared_ptr<DataType> value_type_;
80105
std::vector<int64_t> shape_;
106+
std::vector<int64_t> strides_;
81107
std::vector<int64_t> permutation_;
82108
std::vector<std::string> dim_names_;
83109
};

0 commit comments

Comments
 (0)