forked from ROCm/aiter
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathvec_convert.h
More file actions
271 lines (256 loc) · 10.8 KB
/
vec_convert.h
File metadata and controls
271 lines (256 loc) · 10.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
// SPDX-License-Identifier: MIT
// Copyright (C) 2018-2026, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "aiter_hip_common.h"
namespace ck_tile {
// vec_t<T, N>: the N-element vector type used throughout this header.
// It is currently an alias for thread_buffer<T, N>.
template <typename T, int N>
using vec_t = thread_buffer<T, N>;
// Alternative backing type, kept for reference but disabled:
// using vec_t = ext_vector_t<T, N>;
// Two-element vector aliases; these are the operand/result types of the
// pairwise conversion helpers defined later in this file.
using int8x2_v = vec_t<int8_t, 2>;
using fp8x2_v = vec_t<fp8_t, 2>;
using fp16x2_v = vec_t<fp16_t, 2>;
using bf16x2_v = vec_t<bf16_t, 2>;
using fp32x2_v = vec_t<fp32_t, 2>;
// One byte of storage holding a packed pair of 4-bit values (the result type of
// the *_pk_fp4_* conversion helpers in this file).
// The struct only stores the raw byte; which nibble belongs to which element is
// not defined here.
struct fp4x2_t
{
    // Underlying storage type.
    using type = uint8_t;

    // Raw packed bits.
    type data;

    // Default construction yields an all-zero byte.
    __host__ __device__ constexpr fp4x2_t() : data(0) {}

    // Wraps an already-packed byte; intentionally implicit so a raw uint8_t converts.
    __host__ __device__ constexpr fp4x2_t(type init) : data(init) {}
};
// Vectors whose elements are fp4x2_t: fp4x2xK_v has K fp4x2_t elements.
using fp4x2x2_v = vec_t<fp4x2_t, 2>;
using fp4x2x4_v = vec_t<fp4x2_t, 4>;
using fp4x2x8_v = vec_t<fp4x2_t, 8>;
// vector_traits for fp4x2_t: the packed pair is described as a single uint8_t
// scalar (vector_size == 1), i.e. it is not exposed as a 2-element vector.
template <>
struct vector_traits<fp4x2_t>
{
using scalar_type = uint8_t;
static constexpr index_t vector_size = 1;
};
// numeric limits for fp4x2_t. Only max() is provided in this specialization.
template <>
struct numeric<fp4x2_t>
{
// maximum finite value, returned as fp32
// NOTE(review): 6.0 is presumably the largest magnitude of the 4-bit float
// format in use — confirm against the format definition.
CK_TILE_HOST_DEVICE static constexpr fp32_t max() { return 6.0f; }
};
// Multiplies two 2-wide fp32 operands with a single v_pk_mul_f32 instruction.
// a and b are passed to the inline asm as "v" operands and the result is
// written to c, which is returned.
// amd_scalar_mul_f32 is the element-wise variant used on targets that do not
// have this instruction.
CK_TILE_DEVICE fp32x2_v amd_assembly_pk_mul_f32(fp32x2_v a, fp32x2_t b)
{
fp32x2_v c;
asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b));
return c;
}
// Element-wise fp32 multiply written in plain C++.
// Used in place of amd_assembly_pk_mul_f32 on RDNA3/RDNA4, which lack
// v_pk_mul_f32. Returns {a[0] * b[0], a[1] * b[1]}.
CK_TILE_DEVICE fp32x2_v amd_scalar_mul_f32(fp32x2_v a, fp32x2_t b)
{
    // `a` is a by-value copy, so it doubles as the result buffer.
    a[0] *= b[0];
    a[1] *= b[1];
    return a;
}
// Clamps a and b to the finite fp8 range and packs them into two fp8 values
// with v_cvt_pk_fp8_f32.
//
// The clamp bounds are +/-240 when fp8_t is interpreted as E4M3_FNUZ and
// +/-448 otherwise. Clamping is done with v_med3_f32 (median of value, upper
// bound, lower bound), which writes its result back into the register that
// held the value.
//
// Because those clamps overwrite the registers of a and b, both are declared
// as read-write ("+v") operands. Declaring them input-only would let the
// compiler assume the registers still hold the original values after the asm
// statement. Operand numbering is unchanged: %0 = c, %1 = a, %2 = b, %3 = d,
// %4 = e.
//
// Returns element 0 of the int16x2_t result reinterpreted as fp8x2_v.
CK_TILE_DEVICE fp8x2_v amd_assembly_cvt_pk_fp8_f32(fp32_t a, fp32_t b)
{
    int16x2_t c;
    static constexpr bool is_e4m3_fnuz =
        (numeric_traits<fp8_t>::f8_interpret == fp8_interpretation::E4M3_FNUZ);
    static constexpr float d = is_e4m3_fnuz ? 240.0f : 448.0f;   // upper clamp bound
    static constexpr float e = is_e4m3_fnuz ? -240.0f : -448.0f; // lower clamp bound
    asm volatile("v_med3_f32 %1, %1, %3, %4\n"
                 "v_med3_f32 %2, %2, %3, %4\n"
                 "v_cvt_pk_fp8_f32 %0, %1, %2"
                 : "=v"(c), "+v"(a), "+v"(b)
                 : "v"(d), "v"(e));
    return bit_cast<fp8x2_v>(c[0]);
}
// Clamps a and b to +/-57344 and packs them into two 8-bit values with
// v_cvt_pk_bf8_f32.
//
// Clamping is done with v_med3_f32 (median of value, upper bound, lower
// bound), which writes its result back into the register that held the value.
//
// Because those clamps overwrite the registers of a and b, both are declared
// as read-write ("+v") operands. Declaring them input-only would let the
// compiler assume the registers still hold the original values after the asm
// statement. Operand numbering is unchanged: %0 = c, %1 = a, %2 = b, %3 = d,
// %4 = e.
//
// Returns element 0 of the int16x2_t result reinterpreted as fp8x2_v; the
// return type is the generic 8-bit pair type even though the payload is bf8.
CK_TILE_DEVICE fp8x2_v amd_assembly_cvt_pk_bf8_f32(fp32_t a, fp32_t b)
{
    int16x2_t c;
    static constexpr float d = 57344.0f;  // upper clamp bound
    static constexpr float e = -57344.0f; // lower clamp bound
    asm volatile("v_med3_f32 %1, %1, %3, %4\n"
                 "v_med3_f32 %2, %2, %3, %4\n"
                 "v_cvt_pk_bf8_f32 %0, %1, %2"
                 : "=v"(c), "+v"(a), "+v"(b)
                 : "v"(d), "v"(e));
    return bit_cast<fp8x2_v>(c[0]);
}
// Converts two fp32 values into one packed fp4x2_t with the
// v_cvt_scalef32_pk_fp4_f32 instruction, passing `scale` as its scale operand.
// Only implemented for gfx950: on every other target the function performs no
// conversion and returns a default-constructed (zero) fp4x2_t.
// NOTE(review): how the instruction applies `scale`, and which nibble receives
// a versus b, are not visible in this file — confirm against the ISA docs.
CK_TILE_DEVICE fp4x2_t amd_assembly_cvt_scalef32_pk_fp4_f32(fp32_t a, fp32_t b, fp32_t scale)
{
#if defined(__gfx950__)
int16x2_t c;
// NOTE(review): an earlier remark here spoke of permuting high and low bits to
// match the source order; no explicit permute is performed in this function.
asm volatile("v_cvt_scalef32_pk_fp4_f32 %0, %1, %2, %3" : "=v"(c) : "v"(a), "v"(b), "v"(scale));
// Keep only element 0 of c[0] viewed as two int8 values; that byte is the result.
return bit_cast<fp4x2_t>(bit_cast<int8x2_t>(c[0])[0]);
#else
return fp4x2_t{};
#endif
}
// Converts a pair of fp16 values into one packed fp4x2_t with the
// v_cvt_scalef32_pk_fp4_f16 instruction, passing `scale` as its scale operand.
// Only implemented for gfx950: on every other target the function performs no
// conversion and returns a default-constructed (zero) fp4x2_t.
// NOTE(review): how the instruction applies `scale`, and the nibble order of
// the two inputs, are not visible in this file — confirm against the ISA docs.
CK_TILE_DEVICE fp4x2_t amd_assembly_cvt_scalef32_pk_fp4_f16(fp16x2_v a, fp32_t scale)
{
#if defined(__gfx950__)
int16x2_t c;
// NOTE(review): an earlier remark here spoke of permuting high and low bits to
// match the source order; no explicit permute is performed in this function.
asm volatile("v_cvt_scalef32_pk_fp4_f16 %0, %1, %2" : "=v"(c) : "v"(a), "v"(scale));
// Keep only element 0 of c[0] viewed as two int8 values; that byte is the result.
return bit_cast<fp4x2_t>(bit_cast<int8x2_t>(c[0])[0]);
#else
return fp4x2_t{};
#endif
}
// Converts a pair of bf16 values into one packed fp4x2_t with the
// v_cvt_scalef32_pk_fp4_bf16 instruction, passing `scale` as its scale operand.
// Only implemented for gfx950: on every other target the function performs no
// conversion and returns a default-constructed (zero) fp4x2_t.
// NOTE(review): how the instruction applies `scale`, and the nibble order of
// the two inputs, are not visible in this file — confirm against the ISA docs.
CK_TILE_DEVICE fp4x2_t amd_assembly_cvt_scalef32_pk_fp4_bf16(bf16x2_v a, fp32_t scale)
{
#if defined(__gfx950__)
int16x2_t c;
// NOTE(review): an earlier remark here spoke of permuting high and low bits to
// match the source order; no explicit permute is performed in this function.
asm volatile("v_cvt_scalef32_pk_fp4_bf16 %0, %1, %2" : "=v"(c) : "v"(a), "v"(scale));
// Keep only element 0 of c[0] viewed as two int8 values; that byte is the result.
return bit_cast<fp4x2_t>(bit_cast<int8x2_t>(c[0])[0]);
#else
return fp4x2_t{};
#endif
}
// Unscaled conversion of an N-element vector to fp32.
// Participates in overload resolution only when Y is fp32_t. Every element is
// converted independently through type_convert<Y>, in index order.
template <typename Y,
          typename X,
          index_t N,
          std::enable_if_t<(std::is_same_v<Y, fp32_t>), bool> = false>
CK_TILE_HOST_DEVICE constexpr vec_t<Y, N> vec_convert(vec_t<X, N> x)
{
    vec_t<Y, N> widened;
    for(size_t lane = 0; lane < N; ++lane)
    {
        widened[lane] = type_convert<Y>(x[lane]);
    }
    return widened;
}
// Scaled conversion of an N-element vector to element type Y (any Y other than
// fp4x2_t; N must be even).
//
// This generic body only routes the call:
//  - X is not fp32_t: the input is first converted to fp32 with the unscaled
//    vec_convert, then forwarded to vec_convert<Y, fp32_t, N>.
//  - X is fp32_t: the call is forwarded to vec_convert<Y, fp32_t, N> directly.
//    That names the same template arguments as this function, so it only makes
//    progress when an explicit specialization for <Y, fp32_t, N> exists (the
//    CK_TILE_TYPE_CONVERT instantiations in this file provide fp8 and int8 for
//    N = 2, 4, 8, 16, 32).
template <typename Y,
          typename X,
          index_t N,
          std::enable_if_t<(N % 2 == 0), bool>                     = false,
          std::enable_if_t<(!(std::is_same_v<Y, fp4x2_t>)), bool> = false>
CK_TILE_HOST_DEVICE constexpr vec_t<Y, N> vec_convert(vec_t<X, N> x, fp32_t inverted_scale)
{
    if constexpr(std::is_same_v<X, fp32_t>)
    {
        // fp32 -> Y: handled by the explicit specializations.
        return vec_convert<Y, fp32_t, N>(x, inverted_scale);
    }
    else
    {
        // X -> fp32 -> Y
        vec_t<fp32_t, N> as_fp32 = vec_convert<fp32_t, X, N>(x);
        return vec_convert<Y, fp32_t, N>(as_fp32, inverted_scale);
    }
}
// fp32x2 -> fp8x2
// Multiplies both lanes of x by inverted_scale and converts the products to a
// pair of 8-bit floats.
//
// The multiply uses amd_scalar_mul_f32 on gfx11/gfx12 and
// amd_assembly_pk_mul_f32 elsewhere. The conversion instruction is chosen from
// the compile-time interpretation of fp8_t: the E4M3 variants (FNUZ or OCP) use
// amd_assembly_cvt_pk_fp8_f32, every other interpretation uses
// amd_assembly_cvt_pk_bf8_f32.
CK_TILE_HOST_DEVICE constexpr fp8x2_v fp32x2_t_to_fp8x2_t(fp32x2_v x, fp32_t inverted_scale)
{
    constexpr auto interpret = numeric_traits<fp8_t>::f8_interpret;
    constexpr bool is_e4m3   = (interpret == fp8_interpretation::E4M3_FNUZ) ||
                             (interpret == fp8_interpretation::E4M3_OCP);

    fp32x2_v tmp;
#if defined(__gfx11__) || defined(__gfx12__)
    tmp = amd_scalar_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
#else
    tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
#endif
    return is_e4m3 ? amd_assembly_cvt_pk_fp8_f32(tmp[0], tmp[1])
                   : amd_assembly_cvt_pk_bf8_f32(tmp[0], tmp[1]);
}
// fp32x2 -> int8x2
// Multiplies both lanes of x by inverted_scale and converts each product to
// int8_t with static_cast (truncation toward zero).
// The multiply uses amd_scalar_mul_f32 on gfx11/gfx12 and
// amd_assembly_pk_mul_f32 elsewhere.
// NOTE(review): products are not clamped to the int8 range before the cast —
// confirm callers guarantee in-range values.
CK_TILE_HOST_DEVICE constexpr int8x2_v fp32x2_t_to_int8x2_t(fp32x2_v x, fp32_t inverted_scale)
{
    fp32x2_v scaled;
#if defined(__gfx11__) || defined(__gfx12__)
    scaled = amd_scalar_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
#else
    scaled = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
#endif

    int8x2_v quantized;
    quantized[0] = static_cast<int8_t>(scaled[0]);
    quantized[1] = static_cast<int8_t>(scaled[1]);
    return quantized;
}
// fp32x2 -> fp4x2
// Packs the two fp32 lanes of x into one fp4x2_t by forwarding x[0], x[1] and
// inverted_scale to amd_assembly_cvt_scalef32_pk_fp4_f32.
// Unlike the fp8/int8 helpers, no multiply is done here: the scale is handed
// to the conversion instruction itself.
// NOTE(review): the callee only converts on gfx950 and returns a zero fp4x2_t
// elsewhere; how the instruction applies the scale is not visible here.
CK_TILE_HOST_DEVICE constexpr fp4x2_t fp32x2_t_to_fp4x2_t(fp32x2_v x, fp32_t inverted_scale)
{
return amd_assembly_cvt_scalef32_pk_fp4_f32(x[0], x[1], inverted_scale);
}
// fp16x2 -> fp4x2
// Packs a pair of fp16 values into one fp4x2_t by forwarding x and
// inverted_scale to amd_assembly_cvt_scalef32_pk_fp4_f16. No intermediate
// fp32 conversion or separate multiply is performed.
// NOTE(review): the callee only converts on gfx950 and returns a zero fp4x2_t
// elsewhere; how the instruction applies the scale is not visible here.
CK_TILE_HOST_DEVICE constexpr fp4x2_t fp16x2_t_to_fp4x2_t(fp16x2_v x, fp32_t inverted_scale)
{
return amd_assembly_cvt_scalef32_pk_fp4_f16(x, inverted_scale);
}
// bf16x2 -> fp4x2
// Packs a pair of bf16 values into one fp4x2_t by forwarding x and
// inverted_scale to amd_assembly_cvt_scalef32_pk_fp4_bf16. No intermediate
// fp32 conversion or separate multiply is performed.
// NOTE(review): the callee only converts on gfx950 and returns a zero fp4x2_t
// elsewhere; how the instruction applies the scale is not visible here.
CK_TILE_HOST_DEVICE constexpr fp4x2_t bf16x2_t_to_fp4x2_t(bf16x2_v x, fp32_t inverted_scale)
{
return amd_assembly_cvt_scalef32_pk_fp4_bf16(x, inverted_scale);
}
// CK_TILE_TYPE_CONVERT(dtype_, stype_, vec_size_)
// Emits the explicit specialization
//   vec_convert<dtype__t, stype__t, vec_size_>(vec_t<stype__t, vec_size_>, fp32_t)
// of the scaled vec_convert template.
// The generated body treats input and output as vec_size_ / 2 two-element
// pairs (via get_as<vec_t<..., 2>>()). For each pair index i it calls the
// pairwise helper named <stype_>x2_t_to_<dtype_>x2_t(pair, inverted_scale)
// and stores the returned pair at index i of the output.
// The macro body contains no comments because every line is joined by a
// trailing backslash.
#define CK_TILE_TYPE_CONVERT(dtype_, stype_, vec_size_) \
template <> \
CK_TILE_HOST_DEVICE constexpr vec_t<dtype_##_t, vec_size_> \
vec_convert<dtype_##_t, stype_##_t, vec_size_>(vec_t<stype_##_t, vec_size_> x, \
fp32_t inverted_scale) \
{ \
constexpr int iter_num = vec_size_ / 2; \
vec_t<dtype_##_t, vec_size_> out; \
using vec_i2 = vec_t<stype_##_t, 2>; \
using vec_o2 = vec_t<dtype_##_t, 2>; \
_Pragma("unroll") for(size_t i = 0; i < iter_num; i++) \
{ \
vec_o2 tmp = stype_##x2##_t_to_##dtype_##x2##_t(x.template get_as<vec_i2>()(i), \
inverted_scale); \
out.template get_as<vec_o2>()(i) = tmp; \
} \
return out; \
}
// fp32 -> fp8 for 2, 4, 8, 16 and 32 elements (uses fp32x2_t_to_fp8x2_t).
CK_TILE_TYPE_CONVERT(fp8, fp32, 2)
CK_TILE_TYPE_CONVERT(fp8, fp32, 4)
CK_TILE_TYPE_CONVERT(fp8, fp32, 8)
CK_TILE_TYPE_CONVERT(fp8, fp32, 16)
CK_TILE_TYPE_CONVERT(fp8, fp32, 32)
// fp32 -> int8 for 2, 4, 8, 16 and 32 elements (uses fp32x2_t_to_int8x2_t).
CK_TILE_TYPE_CONVERT(int8, fp32, 2)
CK_TILE_TYPE_CONVERT(int8, fp32, 4)
CK_TILE_TYPE_CONVERT(int8, fp32, 8)
CK_TILE_TYPE_CONVERT(int8, fp32, 16)
CK_TILE_TYPE_CONVERT(int8, fp32, 32)
// The macro name is redefined with a different body for the 4-bit conversions.
#undef CK_TILE_TYPE_CONVERT
// 4 bit vec convert
// Scaled conversion to packed 4-bit output. Selected only when Y is fp4x2_t
// and N is even. The pairwise fp4 helpers in this file turn two source values
// into one fp4x2_t, so an N-element input produces a vec_t<Y, N / 2> result.
// This is a declaration only; the definitions are the explicit
// specializations generated by the CK_TILE_TYPE_CONVERT macro that follows.
template <typename Y,
typename X,
index_t N,
std::enable_if_t<(N % 2 == 0), bool> = false,
std::enable_if_t<((std::is_same_v<Y, fp4x2_t>)), bool> = false>
CK_TILE_HOST_DEVICE constexpr vec_t<Y, N / 2> vec_convert(vec_t<X, N> x, fp32_t inverted_scale);
// CK_TILE_TYPE_CONVERT(dtype_, stype_, vec_size_) — 4-bit variant
// Emits the explicit specialization
//   vec_convert<dtype__t, stype__t, vec_size_>(vec_t<stype__t, vec_size_>, fp32_t)
// of the fp4 vec_convert declaration; it returns vec_t<dtype__t, vec_size_ / 2>.
// The generated body views the input as vec_size_ / 2 two-element pairs. For
// each pair index i it calls <stype_>x2_t_to_<dtype_>_t(pair, inverted_scale),
// which yields a single dtype__t, and stores it as element i of the output.
// The macro body contains no comments because every line is joined by a
// trailing backslash.
#define CK_TILE_TYPE_CONVERT(dtype_, stype_, vec_size_) \
template <> \
CK_TILE_HOST_DEVICE constexpr vec_t<dtype_##_t, vec_size_ / 2> \
vec_convert<dtype_##_t, stype_##_t, vec_size_>(vec_t<stype_##_t, vec_size_> x, \
fp32_t inverted_scale) \
{ \
constexpr int iter_num = vec_size_ / 2; \
vec_t<dtype_##_t, iter_num> out; \
using vec_i2 = vec_t<stype_##_t, 2>; \
using vec_o2 = dtype_##_t; \
_Pragma("unroll") for(size_t i = 0; i < iter_num; i++) \
{ \
vec_o2 tmp = \
stype_##x2##_t_to_##dtype_##_t(x.template get_as<vec_i2>()(i), inverted_scale); \
out.template get_as<vec_o2>()(i) = tmp; \
} \
return out; \
}
// fp32 -> fp4x2 for 4, 8, 16 and 32 input elements (uses fp32x2_t_to_fp4x2_t).
CK_TILE_TYPE_CONVERT(fp4x2, fp32, 4)
CK_TILE_TYPE_CONVERT(fp4x2, fp32, 8)
CK_TILE_TYPE_CONVERT(fp4x2, fp32, 16)
CK_TILE_TYPE_CONVERT(fp4x2, fp32, 32)
// fp16 -> fp4x2 for 4, 8, 16 and 32 input elements (uses fp16x2_t_to_fp4x2_t).
CK_TILE_TYPE_CONVERT(fp4x2, fp16, 4)
CK_TILE_TYPE_CONVERT(fp4x2, fp16, 8)
CK_TILE_TYPE_CONVERT(fp4x2, fp16, 16)
CK_TILE_TYPE_CONVERT(fp4x2, fp16, 32)
// bf16 -> fp4x2 for 4, 8, 16 and 32 input elements (uses bf16x2_t_to_fp4x2_t).
CK_TILE_TYPE_CONVERT(fp4x2, bf16, 4)
CK_TILE_TYPE_CONVERT(fp4x2, bf16, 8)
CK_TILE_TYPE_CONVERT(fp4x2, bf16, 16)
CK_TILE_TYPE_CONVERT(fp4x2, bf16, 32)
#undef CK_TILE_TYPE_CONVERT
} // namespace ck_tile