Skip to content

Commit 82370c2

Browse files
bjudealliepiper
authored andcommitted
Add zip_function to adapt N-ary functions to take a tuple
Eases the use general function objects with zip iterators without modifying them or hand writing a wrapping class Test for zip_function Based on the zip iterator transform test zip_function: Move details into thrust::detal::zip_detail zip_function: make operator() const and make stored function mutable CMake: Add filter for test that require c++11 Only add zip_function for now, making the list exhaustive can be another PR zip_function: Add example to arbitrary_transformation zip_function: Add c++11 guard zip_function: Documentation Zip Function: newline at end of file Allison rewrote some bits to support C++11 compilers. Reviewed-by: Allison Vacanti <[email protected]>
1 parent 3482631 commit 82370c2

File tree

3 files changed

+320
-6
lines changed

3 files changed

+320
-6
lines changed

examples/arbitrary_transformation.cu

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33
#include <thrust/iterator/zip_iterator.h>
44
#include <iostream>
55

6+
#include <thrust/detail/config.h>
7+
8+
#if THRUST_CPP_DIALECT >= 2011 && !defined(THRUST_LEGACY_GCC)
9+
#include <thrust/zip_function.h>
10+
#endif // >= C++11
11+
612
// This example shows how to implement an arbitrary transformation of
713
// the form output[i] = F(first[i], second[i], third[i], ... ).
814
// In this example, we use a function with 3 inputs and 1 output.
@@ -22,6 +28,10 @@
2228
// D[i] = A[i] + B[i] * C[i];
2329
// by invoking arbitrary_functor() on each of the tuples using for_each.
2430
//
31+
// If we are using a functor that is not designed for zip iterators by taking a
32+
// tuple instead of individual arguments we can adapt this function using the
33+
// zip_function adaptor (C++11 only).
34+
//
2535
// Note that we could extend this example to implement functions with an
2636
// arbitrary number of input arguments by zipping more sequence together.
2737
// With the same approach we can have multiple *output* sequences, if we
@@ -31,7 +41,7 @@
3141
//
3242
// The possibilities are endless! :)
3343

34-
struct arbitrary_functor
44+
struct arbitrary_functor1
3545
{
3646
template <typename Tuple>
3747
__host__ __device__
@@ -42,14 +52,25 @@ struct arbitrary_functor
4252
}
4353
};
4454

55+
#if THRUST_CPP_DIALECT >= 2011 && !defined(THRUST_LEGACY_GCC)
56+
struct arbitrary_functor2
57+
{
58+
__host__ __device__
59+
void operator()(const float& a, const float& b, const float& c, float& d)
60+
{
61+
// D[i] = A[i] + B[i] * C[i];
62+
d = a + b * c;
63+
}
64+
};
65+
#endif // >= C++11
4566

4667
int main(void)
4768
{
4869
// allocate storage
4970
thrust::device_vector<float> A(5);
5071
thrust::device_vector<float> B(5);
5172
thrust::device_vector<float> C(5);
52-
thrust::device_vector<float> D(5);
73+
thrust::device_vector<float> D1(5);
5374

5475
// initialize input vectors
5576
A[0] = 3; B[0] = 6; C[0] = 2;
@@ -59,12 +80,26 @@ int main(void)
5980
A[4] = 2; B[4] = 8; C[4] = 3;
6081

6182
// apply the transformation
62-
thrust::for_each(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin(), D.begin())),
63-
thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end(), D.end())),
64-
arbitrary_functor());
83+
thrust::for_each(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin(), D1.begin())),
84+
thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end(), D1.end())),
85+
arbitrary_functor1());
86+
87+
// print the output
88+
std::cout << "Tuple functor" << std::endl;
89+
for(int i = 0; i < 5; i++)
90+
std::cout << A[i] << " + " << B[i] << " * " << C[i] << " = " << D1[i] << std::endl;
91+
92+
// apply the transformation using zip_function
93+
#if THRUST_CPP_DIALECT >= 2011 && !defined(THRUST_LEGACY_GCC)
94+
thrust::device_vector<float> D2(5);
95+
thrust::for_each(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin(), D2.begin())),
96+
thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end(), D2.end())),
97+
thrust::make_zip_function(arbitrary_functor2()));
6598

6699
// print the output
100+
std::cout << "N-ary functor" << std::endl;
67101
for(int i = 0; i < 5; i++)
68-
std::cout << A[i] << " + " << B[i] << " * " << C[i] << " = " << D[i] << std::endl;
102+
std::cout << A[i] << " + " << B[i] << " * " << C[i] << " = " << D2[i] << std::endl;
103+
#endif // >= C++11
69104
}
70105

testing/zip_function.cu

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#include <thrust/detail/config.h>
2+
3+
#if THRUST_CPP_DIALECT >= 2011 && !defined(THRUST_LEGACY_GCC)
4+
5+
#include <unittest/unittest.h>
6+
#include <thrust/iterator/zip_iterator.h>
7+
#include <thrust/transform.h>
8+
#include <thrust/zip_function.h>
9+
10+
#include <iostream>
11+
12+
using namespace unittest;
13+
14+
struct SumThree
15+
{
16+
template <typename T1, typename T2, typename T3>
17+
__host__ __device__
18+
auto operator()(T1 x, T2 y, T3 z) const
19+
THRUST_DECLTYPE_RETURNS(x + y + z)
20+
}; // end SumThree
21+
22+
struct SumThreeTuple
23+
{
24+
template <typename Tuple>
25+
__host__ __device__
26+
auto operator()(Tuple x) const
27+
THRUST_DECLTYPE_RETURNS(thrust::get<0>(x) + thrust::get<1>(x) + thrust::get<2>(x))
28+
}; // end SumThreeTuple
29+
30+
template <typename T>
31+
struct TestZipFunctionTransform
32+
{
33+
void operator()(const size_t n)
34+
{
35+
using namespace thrust;
36+
37+
host_vector<T> h_data0 = unittest::random_samples<T>(n);
38+
host_vector<T> h_data1 = unittest::random_samples<T>(n);
39+
host_vector<T> h_data2 = unittest::random_samples<T>(n);
40+
41+
device_vector<T> d_data0 = h_data0;
42+
device_vector<T> d_data1 = h_data1;
43+
device_vector<T> d_data2 = h_data2;
44+
45+
host_vector<T> h_result_tuple(n);
46+
host_vector<T> h_result_zip(n);
47+
device_vector<T> d_result_zip(n);
48+
49+
// Tuple base case
50+
transform(make_zip_iterator(make_tuple(h_data0.begin(), h_data1.begin(), h_data2.begin())),
51+
make_zip_iterator(make_tuple(h_data0.end(), h_data1.end(), h_data2.end())),
52+
h_result_tuple.begin(),
53+
SumThreeTuple{});
54+
// Zip Function
55+
transform(make_zip_iterator(make_tuple(h_data0.begin(), h_data1.begin(), h_data2.begin())),
56+
make_zip_iterator(make_tuple(h_data0.end(), h_data1.end(), h_data2.end())),
57+
h_result_zip.begin(),
58+
make_zip_function(SumThree{}));
59+
transform(make_zip_iterator(make_tuple(d_data0.begin(), d_data1.begin(), d_data2.begin())),
60+
make_zip_iterator(make_tuple(d_data0.end(), d_data1.end(), d_data2.end())),
61+
d_result_zip.begin(),
62+
make_zip_function(SumThree{}));
63+
64+
ASSERT_EQUAL(h_result_tuple, h_result_zip);
65+
ASSERT_EQUAL(h_result_tuple, d_result_zip);
66+
}
67+
};
68+
VariableUnitTest<TestZipFunctionTransform, ThirtyTwoBitTypes> TestZipFunctionTransformInstance;
69+
70+
#endif // THRUST_CPP_DIALECT

thrust/zip_function.h

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
2+
/*! \file thrust/zip_function.h
3+
* \brief Adaptor type that turns an N-ary function object into one that takes
4+
* a tuple of size N so it can easily be used with algorithms taking zip
5+
* iterators
6+
*/
7+
8+
#pragma once
9+
10+
#include <thrust/detail/config.h>
11+
#include <thrust/detail/cpp11_required.h>
12+
#include <thrust/detail/modern_gcc_required.h>
13+
14+
#if THRUST_CPP_DIALECT >= 2011 && !defined(THRUST_LEGACY_GCC)
15+
16+
#include <thrust/type_traits/integer_sequence.h>
17+
#include <thrust/detail/type_deduction.h>
18+
19+
THRUST_BEGIN_NS
20+
21+
/*! \addtogroup function_objects Function Objects
22+
* \{
23+
*/
24+
25+
/*! \addtogroup function_object_adaptors Function Object Adaptors
26+
* \ingroup function_objects
27+
* \{
28+
*/
29+
30+
namespace detail {
31+
namespace zip_detail {
32+
33+
// Add workaround for decltype(auto) on C++11-only compilers:
34+
#if THRUST_CPP_DIALECT >= 2014
35+
36+
template <typename Function, typename Tuple, std::size_t... Is>
37+
__host__ __device__
38+
decltype(auto) apply_impl(Function&& func, Tuple&& args, index_sequence<Is...>)
39+
{
40+
return func(thrust::get<Is>(THRUST_FWD(args))...);
41+
}
42+
43+
template <typename Function, typename Tuple>
44+
__host__ __device__
45+
decltype(auto) apply(Function&& func, Tuple&& args)
46+
{
47+
constexpr auto tuple_size = thrust::tuple_size<typename std::decay<Tuple>::type>::value;
48+
return apply_impl(THRUST_FWD(func), THRUST_FWD(args), make_index_sequence<tuple_size>{});
49+
}
50+
51+
#else // THRUST_CPP_DIALECT
52+
53+
template <typename Function, typename Tuple, std::size_t... Is>
54+
__host__ __device__
55+
auto apply_impl(Function&& func, Tuple&& args, index_sequence<Is...>)
56+
THRUST_DECLTYPE_RETURNS(func(thrust::get<Is>(THRUST_FWD(args))...))
57+
58+
template <typename Function, typename Tuple>
59+
__host__ __device__
60+
auto apply(Function&& func, Tuple&& args)
61+
THRUST_DECLTYPE_RETURNS(
62+
apply_impl(
63+
THRUST_FWD(func),
64+
THRUST_FWD(args),
65+
make_index_sequence<
66+
thrust::tuple_size<typename std::decay<Tuple>::type>::value>{})
67+
)
68+
69+
#endif // THRUST_CPP_DIALECT
70+
71+
} // namespace zip_detail
72+
} // namespace detail
73+
74+
/*! \p zip_function is a function object that allows the easy use of N-ary
75+
* function objects with \p zip_iterators without redefining them to take a
76+
* \p tuple instead of N arguments.
77+
*
78+
* This means that if a functor that takes 2 arguments which could be used with
79+
* the \p transform function and \p device_iterators can be extended to take 3
80+
* arguments and \p zip_iterators without rewriting the functor in terms of
81+
* \p tuple.
82+
*
83+
* The \p make_zip_function convenience function is provided to avoid having
84+
* to explicitely define the type of the functor when creating a \p zip_function,
85+
* whic is especially helpful when using lambdas as the functor.
86+
*
87+
* \code
88+
* #include <thrust/iterator/zip_iterator.h>
89+
* #include <thrust/device_vector.h>
90+
* #include <thrust/transform.h>
91+
* #include <thrust/zip_function.h>
92+
*
93+
* struct SumTuple {
94+
* float operator()(Tuple tup) {
95+
* return std::get<0>(tup) + std::get<1>(tup) + std::get<2>(tup);
96+
* }
97+
* };
98+
* struct SumArgs {
99+
* float operator()(float a, float b, float c) {
100+
* return a + b + c;
101+
* }
102+
* };
103+
*
104+
* int main() {
105+
* thrust::device_vector<float> A(3);
106+
* thrust::device_vector<float> B(3);
107+
* thrust::device_vector<float> C(3);
108+
* thrust::device_vector<float> D(3);
109+
* A[0] = 0.f; A[1] = 1.f; A[2] = 2.f;
110+
* B[0] = 1.f; B[1] = 2.f; B[2] = 3.f;
111+
* C[0] = 2.f; C[1] = 3.f; C[2] = 4.f;
112+
*
113+
* // The following four invocations of transform are equivalent
114+
* // Transform with 3-tuple
115+
* thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())),
116+
* thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())),
117+
* D.begin(),
118+
* SumTuple{});
119+
*
120+
* // Transform with 3 parameters
121+
* thrust::zip_function<SumArgs> adapted{};
122+
* thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())),
123+
* thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())),
124+
* D.begin(),
125+
* adapted);
126+
*
127+
* // Transform with 3 parameters with convenience function
128+
* thrust::zip_function<SumArgs> adapted{};
129+
* thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())),
130+
* thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())),
131+
* D.begin(),
132+
* thrust::make_zip_function(SumArgs{}));
133+
*
134+
* // Transform with 3 parameters with convenience function and lambda
135+
* thrust::zip_function<SumArgs> adapted{};
136+
* thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())),
137+
* thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())),
138+
* D.begin(),
139+
* thrust::make_zip_function([] (float a, float b, float c) {
140+
* return a + b + c;
141+
* }));
142+
* return 0;
143+
* }
144+
* \endcode
145+
*
146+
* \see make_zip_function
147+
* \see zip_iterator
148+
*/
149+
template <typename Function>
150+
class zip_function
151+
{
152+
public:
153+
__host__ __device__
154+
zip_function(Function func) : func(std::move(func)) {}
155+
156+
// Add workaround for decltype(auto) on C++11-only compilers:
157+
#if THRUST_CPP_DIALECT >= 2014
158+
159+
template <typename Tuple>
160+
__host__ __device__
161+
decltype(auto) operator()(Tuple&& args) const
162+
{
163+
return detail::zip_detail::apply(func, THRUST_FWD(args));
164+
}
165+
166+
#else // THRUST_CPP_DIALECT
167+
168+
// Can't just use THRUST_DECLTYPE_RETURNS here since we need to use
169+
// std::declval for the signature components:
170+
template <typename Tuple>
171+
__host__ __device__
172+
auto operator()(Tuple&& args) const
173+
noexcept(noexcept(detail::zip_detail::apply(std::declval<Function>(), THRUST_FWD(args))))
174+
-> decltype(detail::zip_detail::apply(std::declval<Function>(), THRUST_FWD(args)))
175+
176+
{
177+
return detail::zip_detail::apply(func, THRUST_FWD(args));
178+
}
179+
180+
#endif // THRUST_CPP_DIALECT
181+
182+
private:
183+
mutable Function func;
184+
};
185+
186+
/*! \p make_zip_function creates a \p zip_function from a function object.
187+
*
188+
* \param fun The N-ary function object.
189+
* \return A \p zip_function that takes a N-tuple.
190+
*
191+
* \see zip_function
192+
*/
193+
template <typename Function>
194+
__host__ __device__
195+
auto make_zip_function(Function&& fun) -> zip_function<typename std::decay<Function>::type>
196+
{
197+
using func_t = typename std::decay<Function>::type;
198+
return zip_function<func_t>(THRUST_FWD(fun));
199+
}
200+
201+
/*! \} // end function_object_adaptors
202+
*/
203+
204+
/*! \} // end function_objects
205+
*/
206+
207+
THRUST_END_NS
208+
209+
#endif

0 commit comments

Comments
 (0)