|
| 1 | + |
| 2 | +/*! \file thrust/zip_function.h |
| 3 | + * \brief Adaptor type that turns an N-ary function object into one that takes |
| 4 | + * a tuple of size N so it can easily be used with algorithms taking zip |
| 5 | + * iterators |
| 6 | + */ |
| 7 | + |
| 8 | +#pragma once |
| 9 | + |
| 10 | +#include <thrust/detail/config.h> |
| 11 | +#include <thrust/detail/cpp11_required.h> |
| 12 | +#include <thrust/detail/modern_gcc_required.h> |
| 13 | + |
| 14 | +#if THRUST_CPP_DIALECT >= 2011 && !defined(THRUST_LEGACY_GCC) |
| 15 | + |
| 16 | +#include <thrust/type_traits/integer_sequence.h> |
| 17 | +#include <thrust/detail/type_deduction.h> |
| 18 | + |
| 19 | +THRUST_BEGIN_NS |
| 20 | + |
| 21 | +/*! \addtogroup function_objects Function Objects |
| 22 | + * \{ |
| 23 | + */ |
| 24 | + |
| 25 | +/*! \addtogroup function_object_adaptors Function Object Adaptors |
| 26 | + * \ingroup function_objects |
| 27 | + * \{ |
| 28 | + */ |
| 29 | + |
| 30 | +namespace detail { |
| 31 | +namespace zip_detail { |
| 32 | + |
| 33 | +// Add workaround for decltype(auto) on C++11-only compilers: |
| 34 | +#if THRUST_CPP_DIALECT >= 2014 |
| 35 | + |
| 36 | +template <typename Function, typename Tuple, std::size_t... Is> |
| 37 | +__host__ __device__ |
| 38 | +decltype(auto) apply_impl(Function&& func, Tuple&& args, index_sequence<Is...>) |
| 39 | +{ |
| 40 | + return func(thrust::get<Is>(THRUST_FWD(args))...); |
| 41 | +} |
| 42 | + |
| 43 | +template <typename Function, typename Tuple> |
| 44 | +__host__ __device__ |
| 45 | +decltype(auto) apply(Function&& func, Tuple&& args) |
| 46 | +{ |
| 47 | + constexpr auto tuple_size = thrust::tuple_size<typename std::decay<Tuple>::type>::value; |
| 48 | + return apply_impl(THRUST_FWD(func), THRUST_FWD(args), make_index_sequence<tuple_size>{}); |
| 49 | +} |
| 50 | + |
| 51 | +#else // THRUST_CPP_DIALECT |
| 52 | + |
| 53 | +template <typename Function, typename Tuple, std::size_t... Is> |
| 54 | +__host__ __device__ |
| 55 | +auto apply_impl(Function&& func, Tuple&& args, index_sequence<Is...>) |
| 56 | +THRUST_DECLTYPE_RETURNS(func(thrust::get<Is>(THRUST_FWD(args))...)) |
| 57 | + |
| 58 | +template <typename Function, typename Tuple> |
| 59 | +__host__ __device__ |
| 60 | +auto apply(Function&& func, Tuple&& args) |
| 61 | +THRUST_DECLTYPE_RETURNS( |
| 62 | + apply_impl( |
| 63 | + THRUST_FWD(func), |
| 64 | + THRUST_FWD(args), |
| 65 | + make_index_sequence< |
| 66 | + thrust::tuple_size<typename std::decay<Tuple>::type>::value>{}) |
| 67 | +) |
| 68 | + |
| 69 | +#endif // THRUST_CPP_DIALECT |
| 70 | + |
| 71 | +} // namespace zip_detail |
| 72 | +} // namespace detail |
| 73 | + |
| 74 | +/*! \p zip_function is a function object that allows the easy use of N-ary |
| 75 | + * function objects with \p zip_iterators without redefining them to take a |
| 76 | + * \p tuple instead of N arguments. |
| 77 | + * |
| 78 | + * This means that if a functor that takes 2 arguments which could be used with |
| 79 | + * the \p transform function and \p device_iterators can be extended to take 3 |
| 80 | + * arguments and \p zip_iterators without rewriting the functor in terms of |
| 81 | + * \p tuple. |
| 82 | + * |
| 83 | + * The \p make_zip_function convenience function is provided to avoid having |
| 84 | + * to explicitely define the type of the functor when creating a \p zip_function, |
| 85 | + * whic is especially helpful when using lambdas as the functor. |
| 86 | + * |
| 87 | + * \code |
| 88 | + * #include <thrust/iterator/zip_iterator.h> |
| 89 | + * #include <thrust/device_vector.h> |
| 90 | + * #include <thrust/transform.h> |
| 91 | + * #include <thrust/zip_function.h> |
| 92 | + * |
| 93 | + * struct SumTuple { |
| 94 | + * float operator()(Tuple tup) { |
| 95 | + * return std::get<0>(tup) + std::get<1>(tup) + std::get<2>(tup); |
| 96 | + * } |
| 97 | + * }; |
| 98 | + * struct SumArgs { |
| 99 | + * float operator()(float a, float b, float c) { |
| 100 | + * return a + b + c; |
| 101 | + * } |
| 102 | + * }; |
| 103 | + * |
| 104 | + * int main() { |
| 105 | + * thrust::device_vector<float> A(3); |
| 106 | + * thrust::device_vector<float> B(3); |
| 107 | + * thrust::device_vector<float> C(3); |
| 108 | + * thrust::device_vector<float> D(3); |
| 109 | + * A[0] = 0.f; A[1] = 1.f; A[2] = 2.f; |
| 110 | + * B[0] = 1.f; B[1] = 2.f; B[2] = 3.f; |
| 111 | + * C[0] = 2.f; C[1] = 3.f; C[2] = 4.f; |
| 112 | + * |
| 113 | + * // The following four invocations of transform are equivalent |
| 114 | + * // Transform with 3-tuple |
| 115 | + * thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())), |
| 116 | + * thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())), |
| 117 | + * D.begin(), |
| 118 | + * SumTuple{}); |
| 119 | + * |
| 120 | + * // Transform with 3 parameters |
| 121 | + * thrust::zip_function<SumArgs> adapted{}; |
| 122 | + * thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())), |
| 123 | + * thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())), |
| 124 | + * D.begin(), |
| 125 | + * adapted); |
| 126 | + * |
| 127 | + * // Transform with 3 parameters with convenience function |
| 128 | + * thrust::zip_function<SumArgs> adapted{}; |
| 129 | + * thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())), |
| 130 | + * thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())), |
| 131 | + * D.begin(), |
| 132 | + * thrust::make_zip_function(SumArgs{})); |
| 133 | + * |
| 134 | + * // Transform with 3 parameters with convenience function and lambda |
| 135 | + * thrust::zip_function<SumArgs> adapted{}; |
| 136 | + * thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())), |
| 137 | + * thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())), |
| 138 | + * D.begin(), |
| 139 | + * thrust::make_zip_function([] (float a, float b, float c) { |
| 140 | + * return a + b + c; |
| 141 | + * })); |
| 142 | + * return 0; |
| 143 | + * } |
| 144 | + * \endcode |
| 145 | + * |
| 146 | + * \see make_zip_function |
| 147 | + * \see zip_iterator |
| 148 | + */ |
| 149 | +template <typename Function> |
| 150 | +class zip_function |
| 151 | +{ |
| 152 | + public: |
| 153 | + __host__ __device__ |
| 154 | + zip_function(Function func) : func(std::move(func)) {} |
| 155 | + |
| 156 | +// Add workaround for decltype(auto) on C++11-only compilers: |
| 157 | +#if THRUST_CPP_DIALECT >= 2014 |
| 158 | + |
| 159 | + template <typename Tuple> |
| 160 | + __host__ __device__ |
| 161 | + decltype(auto) operator()(Tuple&& args) const |
| 162 | + { |
| 163 | + return detail::zip_detail::apply(func, THRUST_FWD(args)); |
| 164 | + } |
| 165 | + |
| 166 | +#else // THRUST_CPP_DIALECT |
| 167 | + |
| 168 | + // Can't just use THRUST_DECLTYPE_RETURNS here since we need to use |
| 169 | + // std::declval for the signature components: |
| 170 | + template <typename Tuple> |
| 171 | + __host__ __device__ |
| 172 | + auto operator()(Tuple&& args) const |
| 173 | + noexcept(noexcept(detail::zip_detail::apply(std::declval<Function>(), THRUST_FWD(args)))) |
| 174 | + -> decltype(detail::zip_detail::apply(std::declval<Function>(), THRUST_FWD(args))) |
| 175 | + |
| 176 | + { |
| 177 | + return detail::zip_detail::apply(func, THRUST_FWD(args)); |
| 178 | + } |
| 179 | + |
| 180 | +#endif // THRUST_CPP_DIALECT |
| 181 | + |
| 182 | + private: |
| 183 | + mutable Function func; |
| 184 | +}; |
| 185 | + |
| 186 | +/*! \p make_zip_function creates a \p zip_function from a function object. |
| 187 | + * |
| 188 | + * \param fun The N-ary function object. |
| 189 | + * \return A \p zip_function that takes a N-tuple. |
| 190 | + * |
| 191 | + * \see zip_function |
| 192 | + */ |
| 193 | +template <typename Function> |
| 194 | +__host__ __device__ |
| 195 | +auto make_zip_function(Function&& fun) -> zip_function<typename std::decay<Function>::type> |
| 196 | +{ |
| 197 | + using func_t = typename std::decay<Function>::type; |
| 198 | + return zip_function<func_t>(THRUST_FWD(fun)); |
| 199 | +} |
| 200 | + |
| 201 | +/*! \} // end function_object_adaptors |
| 202 | + */ |
| 203 | + |
| 204 | +/*! \} // end function_objects |
| 205 | + */ |
| 206 | + |
| 207 | +THRUST_END_NS |
| 208 | + |
| 209 | +#endif |
0 commit comments