@@ -57,18 +57,16 @@ struct PacketizeJointMatrix {
57
57
58
58
/* ! @brief Performs a coalesced non-vectorized load when the current block is
59
59
* not internal.
60
- * @tparam trans Whether the source matrix is transposed or not.
61
60
* @tparam internal True if the current block is internal and no bounds
62
61
* checking is required.
63
- * @tparam ld The leading dimension of the destination memory.
64
62
*/
65
63
66
- template <bool trans, bool internal, int ld , typename SrcPointerType ,
67
- typename DestPointerType, typename EdgePredicate>
64
+ template <bool internal, typename SrcPointerType , typename DestPointerType ,
65
+ typename EdgePredicate>
68
66
static PORTBLAS_INLINE typename std::enable_if<!internal>::type load (
69
67
const bool in_range, SrcPointerType src, DestPointerType dest,
70
68
EdgePredicate) {
71
- value_t val = in_range ? *( src) : value_t {0 };
69
+ value_t val = in_range ? *src : value_t {0 };
72
70
using address_t = cl::sycl::access::address_space;
73
71
if constexpr (std::is_same<cl::sycl::multi_ptr<cl::sycl::half,
74
72
address_t ::local_space>,
@@ -79,93 +77,96 @@ struct PacketizeJointMatrix {
79
77
cl::sycl::ext::oneapi::bfloat16,
80
78
address_t ::local_space>,
81
79
DestPointerType>::value) {
82
- using dtype = cl::sycl::ext::oneapi::bfloat16 ;
83
- *dest = static_cast <dtype> (val);
80
+ using namespace cl ::sycl::ext::oneapi;
81
+ *dest = bfloat16 (val);
84
82
} else {
85
83
using namespace cl ::sycl::ext::oneapi::experimental::matrix;
86
84
*dest = round_to_tf32 (val);
87
85
}
88
86
}
87
+
89
88
/* ! @brief Performs a vectorised load using sycl::vec::load when the current
90
89
* block is internal. In the case where k < the
91
90
* number of elements being loaded then edge loads will be element wise with
92
91
* additional bounds checking.
93
- * @tparam trans Whether the source matrix is transposed or not.
94
92
* @tparam internal True if the current block is internal and no bounds
95
93
* checking is required.
96
- * @tparam ld The leading dimension of the destination memory. * /
97
- template <bool trans, bool internal, index_t ld , typename SrcPointerType ,
98
- typename DestPointerType, typename EdgePredicate>
94
+ */
95
+ template <bool internal, typename SrcPointerType , typename DestPointerType ,
96
+ typename EdgePredicate>
99
97
static PORTBLAS_INLINE typename std::enable_if<internal>::type load (
100
98
const bool in_range, SrcPointerType src, DestPointerType dest,
101
99
EdgePredicate edge_in_range) {
102
100
PacketType packet{};
103
101
102
+ using address_t = cl::sycl::access::address_space;
104
103
if (in_range) {
105
- using address_t = cl::sycl::access::address_space;
106
104
packet.template load <address_t ::global_space>(
107
105
0 , cl::sycl::multi_ptr<const value_t , address_t ::global_space>(src));
106
+ store (packet, dest);
108
107
} else {
108
+ // avoid writing to variable, instead directly write to
109
+ // shared local memory to avoid race condition experienced
110
+ // with release compiler.
109
111
#pragma unroll
110
- for (index_t i = 0 ; i < packet_size; i++) {
111
- reinterpret_cast <value_t *>(&packet)[i] =
112
- edge_in_range (i) ? *(src + i) : value_t {0 };
113
- }
114
- }
115
- store<trans, ld>(packet, dest);
116
- }
117
- /* ! @brief Store a vector packet into local memory when the source is
118
- * transposed. This will untranspose the elements individually when storing so
119
- * the data in local memory is always consistent.
120
- * @tparam trans Whether the source matrix is transposed or not.
121
- * @tparam ld The leading dimension of the destination memory.*/
122
- template <bool trans, index_t ld, typename DestPointerType>
123
- static PORTBLAS_INLINE typename std::enable_if<trans>::type store (
124
- PacketType &packet, DestPointerType dest) {
125
- using address_t = cl::sycl::access::address_space;
126
- #pragma unroll
127
- for (index_t i = 0 ; i < packet_size; i++) {
128
- value_t val = reinterpret_cast <value_t *>(&packet)[i];
129
- if constexpr (std::is_same<cl::sycl::multi_ptr<cl::sycl::half,
130
- address_t ::local_space>,
131
- DestPointerType>::value) {
132
- using dtype = cl::sycl::half;
133
- *(dest + ld * i) = static_cast <dtype>(val);
134
- } else if constexpr (std::is_same<cl::sycl::multi_ptr<
135
- cl::sycl::ext::oneapi::bfloat16,
136
- address_t ::local_space>,
137
- DestPointerType>::value) {
138
- using dtype = cl::sycl::ext::oneapi::bfloat16;
139
- *(dest + ld * i) = static_cast <dtype>(val);
140
- } else {
141
- using namespace cl ::sycl::ext::oneapi::experimental::matrix;
142
- *(dest + ld * i) = round_to_tf32 (val);
112
+ for (index_t i = 0 ; i < packet_size; i++, dest++, src++) {
113
+ if constexpr (std::is_same<cl::sycl::multi_ptr<cl::sycl::half,
114
+ address_t ::local_space>,
115
+ DestPointerType>::value) {
116
+ using dtype = cl::sycl::half;
117
+ *dest = static_cast <dtype>(edge_in_range (i) ? *src : 0 );
118
+ } else if constexpr (std::is_same<cl::sycl::multi_ptr<
119
+ cl::sycl::ext::oneapi::bfloat16,
120
+ address_t ::local_space>,
121
+ DestPointerType>::value) {
122
+ using namespace cl ::sycl::ext::oneapi;
123
+ *dest = bfloat16 (edge_in_range (i) ? *src : 0 .f );
124
+ } else {
125
+ using namespace cl ::sycl::ext::oneapi::experimental::matrix;
126
+ *dest = edge_in_range (i) ? round_to_tf32 (*src) : 0 .f ;
127
+ }
143
128
}
144
129
}
145
130
}
146
131
147
- /* ! @brief Store a vector packet into local memory when the source is not
148
- * transposed. This will use sycl::vec::store function.
149
- * @tparam trans Whether the source matrix is transposed or not.
150
- * @tparam ld The leading dimension of the destination memory.*/
151
- template <bool trans, int ld, typename DestPointerType>
152
- static PORTBLAS_INLINE typename std::enable_if<!trans>::type store (
153
- PacketType &packet, DestPointerType dest) {
132
+ /* ! @brief Store a vector packet into local memory. This will use
133
+ * sycl::vec::store function.
134
+ */
135
+ template <typename DestPointerType>
136
+ static PORTBLAS_INLINE void store (PacketType &packet, DestPointerType dest) {
154
137
using address_t = cl::sycl::access::address_space;
155
138
if constexpr (std::is_same<cl::sycl::multi_ptr<cl::sycl::half,
156
139
address_t ::local_space>,
157
140
DestPointerType>::value) {
158
141
using dtype = cl::sycl::half;
159
- *dest = static_cast <dtype>(packet[0 ]);
142
+ cl::sycl::vec<dtype, vector_size> new_vec{};
143
+ for (index_t i = 0 ; i < packet_size; i++) {
144
+ reinterpret_cast <dtype *>(&new_vec)[i] =
145
+ static_cast <dtype>(reinterpret_cast <value_t *>(&packet)[i]);
146
+ }
147
+ new_vec.template store <address_t ::local_space>(
148
+ 0 , cl::sycl::multi_ptr<dtype, address_t ::local_space>(dest));
160
149
} else if constexpr (std::is_same<cl::sycl::multi_ptr<
161
150
cl::sycl::ext::oneapi::bfloat16,
162
151
address_t ::local_space>,
163
152
DestPointerType>::value) {
164
- using dtype = cl::sycl::ext::oneapi::bfloat16;
165
- *dest = static_cast <dtype>(packet[0 ]);
153
+ // sycl::vec doesn't accept bfloat16 as a valid input type
154
+ // so we need to write the packet elements individually to
155
+ // the shared memory.
156
+ using namespace cl ::sycl::ext::oneapi;
157
+ for (index_t i = 0 ; i < packet_size; i++, dest++) {
158
+ *dest = bfloat16 (reinterpret_cast <value_t *>(&packet)[i]);
159
+ }
166
160
} else {
167
161
using namespace cl ::sycl::ext::oneapi::experimental::matrix;
168
- *dest = round_to_tf32 (packet[0 ]);
162
+ using dtype = float ;
163
+ cl::sycl::vec<dtype, vector_size> new_vec;
164
+ for (index_t i = 0 ; i < packet_size; i++) {
165
+ reinterpret_cast <dtype *>(&new_vec)[i] =
166
+ round_to_tf32 (reinterpret_cast <value_t *>(&packet)[i]);
167
+ }
168
+ new_vec.template store <address_t ::local_space>(
169
+ 0 , cl::sycl::multi_ptr<dtype, address_t ::local_space>(dest));
169
170
}
170
171
}
171
172
};
0 commit comments