Skip to content
This repository was archived by the owner on Jan 13, 2025. It is now read-only.

Commit 861b310

Browse files
Fix joint_matrix implementation to match latest api (#491)
* Added tests for joint_matrix implementation as well --------- Co-authored-by: pgorlani <[email protected]>
1 parent 2f149cb commit 861b310

24 files changed

+2051
-306
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ export(EXPORT portblas
212212

213213
option(BLAS_ENABLE_TESTING "Whether to enable testing" ON)
214214
option(ENABLE_EXPRESSION_TESTS "Whether to build expression tree fusion tests" OFF)
215+
option(ENABLE_JOINTMATRIX_TESTS "Whether to build joint_matrix GEMM tests" OFF)
215216
if (INSTALL_HEADER_ONLY AND BLAS_ENABLE_TESTING)
216217
message(STATUS "Tests are disabled when installing portBLAS in header only mode")
217218
set(BLAS_ENABLE_TESTING OFF)

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,7 @@ Some of the supported options are:
458458
| `CMAKE_INSTALL_PREFIX` | path | Specify the install location, used when invoking `ninja install` |
459459
| `BUILD_SHARED_LIBS` | `ON`/`OFF` | Build as shared library (`ON` by default) |
460460
| `ENABLE_EXPRESSION_TESTS` | `ON`/`OFF` | Build additional tests that use the header-only framework (e.g. to test expression trees); `OFF` by default |
461+
| `ENABLE_JOINTMATRIX_TESTS` | `ON`/`OFF` | Build additional tests that use the joint_matrix extension; `OFF` by default |
461462
| `BLAS_VERIFY_BENCHMARK` | `ON`/`OFF` | Verify the results of the benchmarks instead of only measuring the performance. See the documentation of the benchmarks for more details. `ON` by default |
462463
| `BLAS_MEMPOOL_BENCHMARK` | `ON`/`OFF` | Determines whether to enable the scratchpad memory pool for benchmark execution. `OFF` by default |
463464
| `BLAS_ENABLE_CONST_INPUT` | `ON`/`OFF` | Determines whether to enable kernel instantiation with const input buffer (`ON` by default) |

benchmark/portblas/blas3/trsm.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,13 @@ void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, char side,
9797
}
9898

9999
std::ostringstream err_stream;
100-
if (!utils::compare_vectors(b_temp, x_ref, err_stream, "")) {
100+
const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX");
101+
if (!utils::compare_vectors(b_temp, x_ref, err_stream, "",
102+
(en_joint_matrix != NULL) &&
103+
(std::is_same<scalar_t, float>::value) &&
104+
(*en_joint_matrix == '1')
105+
? 2
106+
: 1)) {
101107
const std::string& err_str = err_stream.str();
102108
state.SkipWithError(err_str.c_str());
103109
*success = false;
@@ -181,8 +187,8 @@ void register_benchmark(blas::SB_Handle* sb_handle_ptr, bool* success,
181187
};
182188
benchmark::RegisterBenchmark(
183189
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
184-
side, uplo, trans, diag, m, n,
185-
mem_type).c_str(),
190+
side, uplo, trans, diag, m, n, mem_type)
191+
.c_str(),
186192
BM_lambda, sb_handle_ptr, side, uplo, trans, diag, m, n, alpha, success)
187193
->UseRealTime();
188194
}
@@ -193,16 +199,17 @@ void register_benchmark(blas_benchmark::Args& args,
193199
blas::SB_Handle* sb_handle_ptr, bool* success) {
194200
auto trsm_params = blas_benchmark::utils::get_trsm_params<scalar_t>(args);
195201
register_benchmark<scalar_t, blas::helper::AllocType::buffer>(
196-
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, trsm_params);
202+
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER,
203+
trsm_params);
197204
#ifdef SB_ENABLE_USM
198205
register_benchmark<scalar_t, blas::helper::AllocType::usm>(
199206
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, trsm_params);
200207
#endif
201208
}
202209

203210
namespace blas_benchmark {
204-
void create_benchmark(blas_benchmark::Args& args, blas::SB_Handle* sb_handle_ptr,
205-
bool* success) {
211+
void create_benchmark(blas_benchmark::Args& args,
212+
blas::SB_Handle* sb_handle_ptr, bool* success) {
206213
BLAS_REGISTER_BENCHMARK(args, sb_handle_ptr, success);
207214
}
208215
} // namespace blas_benchmark

common/include/common/float_comparison.hpp

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -115,17 +115,20 @@ scalar_t clamp_to_limits(scalar_t v) {
115115
* Indicates the tolerated margin for relative differences
116116
*/
117117
template <typename scalar_t>
118-
inline scalar_t getRelativeErrorMargin() {
118+
inline scalar_t getRelativeErrorMargin(const int32_t margin_multiplier = 1) {
119119
/* Measured empirically with gemm. The dimensions of the matrices (even k)
120120
* don't seem to have an impact on the observed relative differences
121121
* In the cases where the relative error is relevant (not close to zero),
122122
* relative differences of up to 0.002 were observed for float
123123
*/
124-
return static_cast<scalar_t>(0.005);
124+
scalar_t margin = 0.005;
125+
// Increase the error margin for mixed-precision calculations
126+
// in the trsm operator.
127+
return margin * margin_multiplier;
125128
}
126129

127130
template <>
128-
inline double getRelativeErrorMargin<double>() {
131+
inline double getRelativeErrorMargin<double>(const int32_t) {
129132
/* Measured empirically with gemm. The dimensions of the matrices (even k)
130133
* don't seem to have an impact on the observed relative differences
131134
* In the cases where the relative error is relevant (not close to zero),
@@ -135,7 +138,7 @@ inline double getRelativeErrorMargin<double>() {
135138
}
136139

137140
template <>
138-
inline cl::sycl::half getRelativeErrorMargin<cl::sycl::half>() {
141+
inline cl::sycl::half getRelativeErrorMargin<cl::sycl::half>(const int32_t) {
139142
// Measured empirically with gemm
140143
return 0.05f;
141144
}
@@ -145,16 +148,19 @@ inline cl::sycl::half getRelativeErrorMargin<cl::sycl::half>() {
145148
* scalars are close to 0)
146149
*/
147150
template <typename scalar_t>
148-
inline scalar_t getAbsoluteErrorMargin() {
151+
inline scalar_t getAbsoluteErrorMargin(const int32_t margin_multiplier = 1) {
149152
/* Measured empirically with gemm.
150153
* In the cases where the relative error is irrelevant (close to zero),
151154
* absolute differences of up to 0.0006 were observed for float
152155
*/
153-
return 0.001f;
156+
scalar_t margin = 0.001f;
157+
// Increase the error margin for mixed-precision calculations
158+
// in the trsm operator.
159+
return margin * margin_multiplier;
154160
}
155161

156162
template <>
157-
inline double getAbsoluteErrorMargin<double>() {
163+
inline double getAbsoluteErrorMargin<double>(const int32_t) {
158164
/* Measured empirically with gemm.
159165
* In the cases where the relative error is irrelevant (close to zero),
160166
* absolute differences of up to 10^-12 were observed for double
@@ -163,7 +169,7 @@ inline double getAbsoluteErrorMargin<double>() {
163169
}
164170

165171
template <>
166-
inline cl::sycl::half getAbsoluteErrorMargin<cl::sycl::half>() {
172+
inline cl::sycl::half getAbsoluteErrorMargin<cl::sycl::half>(const int32_t) {
167173
// Measured empirically with gemm.
168174
return 1.0f;
169175
}
@@ -172,7 +178,8 @@ inline cl::sycl::half getAbsoluteErrorMargin<cl::sycl::half>() {
172178
* Compare two scalars and returns false if the difference is not acceptable.
173179
*/
174180
template <typename scalar_t, typename epsilon_t = scalar_t>
175-
inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) {
181+
inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2,
182+
const int32_t margin_multiplier = 1) {
176183
// Shortcut, also handles case where both are zero
177184
if (scalar1 == scalar2) {
178185
return true;
@@ -187,12 +194,14 @@ inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) {
187194

188195
// Close to zero, the relative error doesn't work, use absolute error
189196
if (scalar1 == scalar_t{0} || scalar2 == scalar_t{0} ||
190-
absolute_diff < getAbsoluteErrorMargin<epsilon_t>()) {
191-
return (absolute_diff < getAbsoluteErrorMargin<epsilon_t>());
197+
absolute_diff < getAbsoluteErrorMargin<epsilon_t>(margin_multiplier)) {
198+
return (absolute_diff <
199+
getAbsoluteErrorMargin<epsilon_t>(margin_multiplier));
192200
}
193201
// Use relative error
194202
const auto absolute_sum = utils::abs(scalar1) + utils::abs(scalar2);
195-
return (absolute_diff / absolute_sum) < getRelativeErrorMargin<epsilon_t>();
203+
return (absolute_diff / absolute_sum) <
204+
getRelativeErrorMargin<epsilon_t>(margin_multiplier);
196205
}
197206

198207
/**
@@ -206,15 +215,16 @@ template <typename scalar_t, typename epsilon_t = scalar_t>
206215
inline bool compare_vectors(std::vector<scalar_t> const& vec,
207216
std::vector<scalar_t> const& ref,
208217
std::ostream& err_stream = std::cerr,
209-
std::string end_line = "\n") {
218+
std::string end_line = "\n",
219+
const int32_t margin_multiplier = 1) {
210220
if (vec.size() != ref.size()) {
211221
err_stream << "Error: tried to compare vectors of different sizes"
212222
<< std::endl;
213223
return false;
214224
}
215225

216226
for (int i = 0; i < vec.size(); ++i) {
217-
if (!almost_equal<scalar_t, epsilon_t>(vec[i], ref[i])) {
227+
if (!almost_equal<scalar_t, epsilon_t>(vec[i], ref[i], margin_multiplier)) {
218228
err_stream << "Value mismatch at index " << i << ": " << vec[i]
219229
<< "; expected " << ref[i] << end_line;
220230
return false;

src/operations/blas3/gemm_load_store_joint_matrix.hpp

Lines changed: 57 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -57,18 +57,16 @@ struct PacketizeJointMatrix {
5757

5858
/*! @brief Performs a coalesced non-vectorized load when the current block is
5959
* not internal.
60-
* @tparam trans Whether the source matrix is transposed or not.
6160
* @tparam internal True if the current block is internal and no bounds
6261
* checking is required.
63-
* @tparam ld The leading dimension of the destination memory.
6462
*/
6563

66-
template <bool trans, bool internal, int ld, typename SrcPointerType,
67-
typename DestPointerType, typename EdgePredicate>
64+
template <bool internal, typename SrcPointerType, typename DestPointerType,
65+
typename EdgePredicate>
6866
static PORTBLAS_INLINE typename std::enable_if<!internal>::type load(
6967
const bool in_range, SrcPointerType src, DestPointerType dest,
7068
EdgePredicate) {
71-
value_t val = in_range ? *(src) : value_t{0};
69+
value_t val = in_range ? *src : value_t{0};
7270
using address_t = cl::sycl::access::address_space;
7371
if constexpr (std::is_same<cl::sycl::multi_ptr<cl::sycl::half,
7472
address_t::local_space>,
@@ -79,93 +77,96 @@ struct PacketizeJointMatrix {
7977
cl::sycl::ext::oneapi::bfloat16,
8078
address_t::local_space>,
8179
DestPointerType>::value) {
82-
using dtype = cl::sycl::ext::oneapi::bfloat16;
83-
*dest = static_cast<dtype>(val);
80+
using namespace cl::sycl::ext::oneapi;
81+
*dest = bfloat16(val);
8482
} else {
8583
using namespace cl::sycl::ext::oneapi::experimental::matrix;
8684
*dest = round_to_tf32(val);
8785
}
8886
}
87+
8988
/*! @brief Performs a vectorised load using sycl::vec::load when the current
9089
* block is internal. When k is smaller than the number of elements
9190
* being loaded, edge loads are performed element-wise with
9291
* additional bounds checking.
93-
* @tparam trans Whether the source matrix is transposed or not.
9492
* @tparam internal True if the current block is internal and no bounds
9593
* checking is required.
96-
* @tparam ld The leading dimension of the destination memory. */
97-
template <bool trans, bool internal, index_t ld, typename SrcPointerType,
98-
typename DestPointerType, typename EdgePredicate>
94+
*/
95+
template <bool internal, typename SrcPointerType, typename DestPointerType,
96+
typename EdgePredicate>
9997
static PORTBLAS_INLINE typename std::enable_if<internal>::type load(
10098
const bool in_range, SrcPointerType src, DestPointerType dest,
10199
EdgePredicate edge_in_range) {
102100
PacketType packet{};
103101

102+
using address_t = cl::sycl::access::address_space;
104103
if (in_range) {
105-
using address_t = cl::sycl::access::address_space;
106104
packet.template load<address_t::global_space>(
107105
0, cl::sycl::multi_ptr<const value_t, address_t::global_space>(src));
106+
store(packet, dest);
108107
} else {
108+
// Avoid writing to a variable; instead, write directly to
109+
// shared local memory to avoid a race condition experienced
110+
// with the release compiler.
109111
#pragma unroll
110-
for (index_t i = 0; i < packet_size; i++) {
111-
reinterpret_cast<value_t *>(&packet)[i] =
112-
edge_in_range(i) ? *(src + i) : value_t{0};
113-
}
114-
}
115-
store<trans, ld>(packet, dest);
116-
}
117-
/*! @brief Store a vector packet into local memory when the source is
118-
* transposed. This will untranspose the elements individually when storing so
119-
* the data in local memory is always consistent.
120-
* @tparam trans Whether the source matrix is transposed or not.
121-
* @tparam ld The leading dimension of the destination memory.*/
122-
template <bool trans, index_t ld, typename DestPointerType>
123-
static PORTBLAS_INLINE typename std::enable_if<trans>::type store(
124-
PacketType &packet, DestPointerType dest) {
125-
using address_t = cl::sycl::access::address_space;
126-
#pragma unroll
127-
for (index_t i = 0; i < packet_size; i++) {
128-
value_t val = reinterpret_cast<value_t *>(&packet)[i];
129-
if constexpr (std::is_same<cl::sycl::multi_ptr<cl::sycl::half,
130-
address_t::local_space>,
131-
DestPointerType>::value) {
132-
using dtype = cl::sycl::half;
133-
*(dest + ld * i) = static_cast<dtype>(val);
134-
} else if constexpr (std::is_same<cl::sycl::multi_ptr<
135-
cl::sycl::ext::oneapi::bfloat16,
136-
address_t::local_space>,
137-
DestPointerType>::value) {
138-
using dtype = cl::sycl::ext::oneapi::bfloat16;
139-
*(dest + ld * i) = static_cast<dtype>(val);
140-
} else {
141-
using namespace cl::sycl::ext::oneapi::experimental::matrix;
142-
*(dest + ld * i) = round_to_tf32(val);
112+
for (index_t i = 0; i < packet_size; i++, dest++, src++) {
113+
if constexpr (std::is_same<cl::sycl::multi_ptr<cl::sycl::half,
114+
address_t::local_space>,
115+
DestPointerType>::value) {
116+
using dtype = cl::sycl::half;
117+
*dest = static_cast<dtype>(edge_in_range(i) ? *src : 0);
118+
} else if constexpr (std::is_same<cl::sycl::multi_ptr<
119+
cl::sycl::ext::oneapi::bfloat16,
120+
address_t::local_space>,
121+
DestPointerType>::value) {
122+
using namespace cl::sycl::ext::oneapi;
123+
*dest = bfloat16(edge_in_range(i) ? *src : 0.f);
124+
} else {
125+
using namespace cl::sycl::ext::oneapi::experimental::matrix;
126+
*dest = edge_in_range(i) ? round_to_tf32(*src) : 0.f;
127+
}
143128
}
144129
}
145130
}
146131

147-
/*! @brief Store a vector packet into local memory when the source is not
148-
* transposed. This will use sycl::vec::store function.
149-
* @tparam trans Whether the source matrix is transposed or not.
150-
* @tparam ld The leading dimension of the destination memory.*/
151-
template <bool trans, int ld, typename DestPointerType>
152-
static PORTBLAS_INLINE typename std::enable_if<!trans>::type store(
153-
PacketType &packet, DestPointerType dest) {
132+
/*! @brief Store a vector packet into local memory. This will use
133+
* sycl::vec::store function.
134+
*/
135+
template <typename DestPointerType>
136+
static PORTBLAS_INLINE void store(PacketType &packet, DestPointerType dest) {
154137
using address_t = cl::sycl::access::address_space;
155138
if constexpr (std::is_same<cl::sycl::multi_ptr<cl::sycl::half,
156139
address_t::local_space>,
157140
DestPointerType>::value) {
158141
using dtype = cl::sycl::half;
159-
*dest = static_cast<dtype>(packet[0]);
142+
cl::sycl::vec<dtype, vector_size> new_vec{};
143+
for (index_t i = 0; i < packet_size; i++) {
144+
reinterpret_cast<dtype *>(&new_vec)[i] =
145+
static_cast<dtype>(reinterpret_cast<value_t *>(&packet)[i]);
146+
}
147+
new_vec.template store<address_t::local_space>(
148+
0, cl::sycl::multi_ptr<dtype, address_t::local_space>(dest));
160149
} else if constexpr (std::is_same<cl::sycl::multi_ptr<
161150
cl::sycl::ext::oneapi::bfloat16,
162151
address_t::local_space>,
163152
DestPointerType>::value) {
164-
using dtype = cl::sycl::ext::oneapi::bfloat16;
165-
*dest = static_cast<dtype>(packet[0]);
153+
// sycl::vec doesn't accept bfloat16 as a valid input type
154+
// so we need to write the packet elements individually to
155+
// the shared memory.
156+
using namespace cl::sycl::ext::oneapi;
157+
for (index_t i = 0; i < packet_size; i++, dest++) {
158+
*dest = bfloat16(reinterpret_cast<value_t *>(&packet)[i]);
159+
}
166160
} else {
167161
using namespace cl::sycl::ext::oneapi::experimental::matrix;
168-
*dest = round_to_tf32(packet[0]);
162+
using dtype = float;
163+
cl::sycl::vec<dtype, vector_size> new_vec;
164+
for (index_t i = 0; i < packet_size; i++) {
165+
reinterpret_cast<dtype *>(&new_vec)[i] =
166+
round_to_tf32(reinterpret_cast<value_t *>(&packet)[i]);
167+
}
168+
new_vec.template store<address_t::local_space>(
169+
0, cl::sycl::multi_ptr<dtype, address_t::local_space>(dest));
169170
}
170171
}
171172
};

0 commit comments

Comments
 (0)