Skip to content

Commit c71c905

Browse files
authored
add batch_query_array which flattens nested vec to two numpy arrays before returning to python (#35)
* add batch_query_array for numpy output instead of nested lists * add batch_query_array for numpy output instead of nested lists
1 parent 2721ebe commit c71c905

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

cpp/main.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ PYBIND11_MODULE(PRTree, m)
3333
.def("batch_query", &PRTree<T, B, 2>::find_all, R"pbdoc(
3434
parallel query with multi-thread
3535
)pbdoc")
36+
.def("batch_query_array", &PRTree<T, B, 2>::find_all_array, R"pbdoc(
37+
parallel query with multi-thread with array output
38+
)pbdoc")
3639
.def("erase", &PRTree<T, B, 2>::erase, R"pbdoc(
3740
Delete from prtree
3841
)pbdoc")
@@ -74,6 +77,9 @@ PYBIND11_MODULE(PRTree, m)
7477
.def("batch_query", &PRTree<T, B, 3>::find_all, R"pbdoc(
7578
parallel query with multi-thread
7679
)pbdoc")
80+
.def("batch_query_array", &PRTree<T, B, 3>::find_all_array, R"pbdoc(
81+
parallel query with multi-thread with array output
82+
)pbdoc")
7783
.def("erase", &PRTree<T, B, 3>::erase, R"pbdoc(
7884
Delete from prtree
7985
)pbdoc")
@@ -115,6 +121,9 @@ PYBIND11_MODULE(PRTree, m)
115121
.def("batch_query", &PRTree<T, B, 4>::find_all, R"pbdoc(
116122
parallel query with multi-thread
117123
)pbdoc")
124+
.def("batch_query_array", &PRTree<T, B, 4>::find_all_array, R"pbdoc(
125+
parallel query with multi-thread with array output
126+
)pbdoc")
118127
.def("erase", &PRTree<T, B, 4>::erase, R"pbdoc(
119128
Delete from prtree
120129
)pbdoc")

cpp/prtree.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,38 @@ namespace py = pybind11;
5151
template <class T>
5252
using vec = std::vector<T>;
5353

54+
template <typename Sequence >
55+
inline py::array_t<typename Sequence::value_type> as_pyarray(Sequence& seq) {
56+
57+
auto size = seq.size();
58+
auto data = seq.data();
59+
std::unique_ptr<Sequence> seq_ptr = std::make_unique<Sequence>(std::move(seq));
60+
auto capsule = py::capsule(seq_ptr.get(), [](void *p) { std::unique_ptr<Sequence>(reinterpret_cast<Sequence*>(p)); });
61+
seq_ptr.release();
62+
return py::array(size, data, capsule);
63+
}
64+
65+
template <typename T>
66+
auto list_list_to_arrays(vec<vec<T>> out_ll){
67+
vec<T> out_s;
68+
out_s.reserve(out_ll.size());
69+
std::size_t sum = 0;
70+
for (auto &&i : out_ll) {
71+
out_s.push_back(i.size());
72+
sum += i.size();
73+
}
74+
vec<T> out;
75+
out.reserve(sum);
76+
for(const auto &v: out_ll)
77+
out.insert(out.end(), v.begin(), v.end());
78+
79+
return make_tuple(
80+
std::move(as_pyarray(out_s))
81+
,
82+
std::move(as_pyarray(out))
83+
);
84+
}
85+
5486
template <class T, size_t StaticCapacity>
5587
using svec = itlib::small_vector<T, StaticCapacity>;
5688

@@ -1259,6 +1291,10 @@ class PRTree
12591291
return out;
12601292
}
12611293

1294+
auto find_all_array(const py::array_t<float> &x){
1295+
return list_list_to_arrays(std::move(find_all(x)));
1296+
}
1297+
12621298
auto find_one(const vec<float> &x)
12631299
{
12641300
bool is_point = false;

0 commit comments

Comments
 (0)