
Commit 5384c00

CaoE and MingxuZh authored
Leverage flash attention for fp16 first_token_masked_mha (#2846)
* leverage flash attention for fp16 first_token_masked_mha
* fix format

Co-authored-by: Zhang, Mingxu <[email protected]>
1 parent 3727406 · commit 5384c00

File tree

2 files changed: +118 −17 lines changed


csrc/cpu/aten/kernels/MaskedMultiHeadAttentionKrnl.cpp

Lines changed: 16 additions & 17 deletions
@@ -1294,30 +1294,28 @@ first_token_masked_mha(
   auto key_lenght = key.size(1);
   auto kv_head_num = key.size(2);
   auto head_size = key.size(3);
-  if (origin_type == at::kHalf) {
-    key = key.to(at::kFloat);
-    query = query.to(at::kFloat);
-    value = value.to(at::kFloat);
-    key_cache = key_cache.to(at::kFloat);
-    value_cache = value_cache.to(at::kFloat);
-  }
   if (add_casual_mask) {
-    auto casual_mask =
-        at::full({query_length, key_lenght}, -1e6, query.options());
+    auto casual_mask = at::full(
+        {query_length, key_lenght},
+        origin_type == at::kHalf ? -6e4 : -1e6,
+        query.options());
     casual_mask = at::triu(casual_mask, 1);
     casual_mask = casual_mask.unsqueeze(0).unsqueeze(0);
     attention_mask = attention_mask + casual_mask;
   }
-  if (key.scalar_type() != at::kBFloat16 && key.scalar_type() != at::kFloat) {
+  if (key.scalar_type() != at::kBFloat16 && key.scalar_type() != at::kFloat &&
+      key.scalar_type() != at::kHalf) {
     TORCH_CHECK(
         false,
-        "key and value must be float or bfloat16 to use ipex::masked_multihead_self_attention_kernel_impl");
+        "key and value must be float, float16 or bfloat16 to use ipex::masked_multihead_self_attention_kernel_impl");
   }
   if (key.scalar_type() == at::kFloat) {
     copy_key_value<float>(key_cache, key, value_cache, value, beam_batch);
-  } else {
+  } else if (key.scalar_type() == at::kBFloat16) {
     copy_key_value<at::BFloat16>(
         key_cache, key, value_cache, value, beam_batch);
+  } else {
+    copy_key_value<at::Half>(key_cache, key, value_cache, value, beam_batch);
   }
   // support MGQ/MQA
   // expand the head dimensiopn of key/value to be same to the query
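The notable change in this hunk, beyond routing fp16 through without the old eager upconvert, is the causal-mask fill value. Since query now stays fp16, query.options() materializes the mask in half precision, where the largest finite magnitude is 65504: the old -1e6 sentinel would saturate to -inf, and -inf scores can turn whole softmax rows into NaN. -6e4 is a "very negative but still representable" substitute. A minimal sketch of that saturation, assuming only ATen's at::Half float conversions (variable names are illustrative, not from the commit):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // fp16's max finite magnitude is 65504, so -1e6 saturates to -inf,
  // while -6e4 (exactly representable in fp16) stays finite.
  at::Half overflowed = -1e6f;
  at::Half safe = -6e4f;
  std::cout << "fp16(-1e6) = " << static_cast<float>(overflowed)
            << ", fp16(-6e4) = " << static_cast<float>(safe) << std::endl;
  return 0;
}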
@@ -1344,6 +1342,11 @@ first_token_masked_mha(
         attention_mask,
         1. / scale_attn));
   } else {
+    if (origin_type == at::kHalf) {
+      key = key.to(at::kFloat);
+      query = query.to(at::kFloat);
+      value = value.to(at::kFloat);
+    }
     key = key.permute({0, 2, 1, 3});
     query = query.permute({0, 2, 1, 3});
     value = value.permute({0, 2, 1, 3});
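Moving the fp32 conversion inside the else branch means the flash-attention path in the if branch above now consumes fp16 query/key/value directly; only the reference matmul-softmax fallback pays for the float round trip. A rough sketch of the resulting control flow, using at::scaled_dot_product_attention as a stand-in for the IPEX flash kernel (the real function dispatches differently and also threads masks, scales, and KV caches through):

#include <ATen/ATen.h>
#include <cmath>

at::Tensor first_token_attn_sketch(
    at::Tensor query, at::Tensor key, at::Tensor value,
    bool use_flash, at::ScalarType origin_type) {
  if (use_flash) {
    // fp16 tensors flow into the fused kernel unconverted.
    return at::scaled_dot_product_attention(query, key, value);
  }
  if (origin_type == at::kHalf) {
    // Fallback only: run the reference path in fp32.
    query = query.to(at::kFloat);
    key = key.to(at::kFloat);
    value = value.to(at::kFloat);
  }
  auto scale = 1.0 / std::sqrt(static_cast<double>(query.size(-1)));
  auto attn = (query.matmul(key.transpose(-2, -1)) * scale)
                  .softmax(-1)
                  .matmul(value);
  return origin_type == at::kHalf ? attn.to(origin_type) : attn;
}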
@@ -1355,13 +1358,9 @@ first_token_masked_mha(
     attn_outputs = attn_weights.matmul(value);
     if (origin_type == at::kHalf) {
       attn_weights = attn_weights.to(origin_type);
+      attn_outputs = attn_outputs.to(origin_type);
     }
   }
-  if (origin_type == at::kHalf) {
-    attn_outputs = attn_outputs.to(origin_type);
-    key_cache = key_cache.to(origin_type);
-    value_cache = value_cache.to(origin_type);
-  }
   return std::make_tuple(
       attn_outputs, attn_weights, key_cache, value_cache, beam_idx);
 }
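With copy_key_value<at::Half> now writing the caches in half precision and the eager upconvert gone, the trailing block that converted key_cache/value_cache back to fp16 had nothing left to undo, so it is deleted; only the fallback's attn_weights/attn_outputs still need narrowing, which moved into the branch above. Dropping it is also harmless by construction, since Tensor::to() returns the very same tensor when the dtype already matches (a quick check of my own, tensor shape illustrative):

#include <ATen/ATen.h>
#include <cassert>

int main() {
  auto cache = at::zeros({2, 4}, at::kHalf);
  // Same dtype and copy=false: to() is a no-op that returns this tensor.
  assert(cache.to(at::kHalf).is_same(cache));
  return 0;
}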

csrc/cpu/vec/vec512/vec512_half.h

Lines changed: 102 additions & 0 deletions
@@ -53,6 +53,75 @@ IPEX_FORCE_INLINE void move_ker(at::Half* out, const float* in, int64_t len) {
   cvt_fp32_to_fp16(out, in, len);
 }
 
+template <>
+IPEX_FORCE_INLINE void move_ker(
+    at::Half* out,
+    const at::Half* in,
+    int64_t len) {
+  int64_t i = 0;
+#pragma unroll(4)
+  for (i = 0; i < len - 31; i += 32) {
+    auto in0 = _mm512_loadu_si512(in + i);
+    _mm512_storeu_si512(out + i, in0);
+  }
+
+  if (i < len) {
+    auto mask = (1 << (len - i)) - 1;
+    auto in0 = _mm512_maskz_loadu_epi16(mask, in + i);
+    _mm512_mask_storeu_epi16(out + i, mask, in0);
+  }
+}
+
+static IPEX_FORCE_INLINE void zero_ker(at::Half* out, int64_t len) {
+  int64_t i = 0;
+  __m512i zero_512 = _mm512_setzero_si512();
+#pragma unroll(4)
+  for (i = 0; i < len - 31; i += 32) {
+    _mm512_storeu_si512(out + i, zero_512);
+  }
+
+  if (i < len) {
+    auto mask = ((1 << (len - i)) - 1);
+    _mm512_mask_storeu_epi16(out + i, mask, zero_512);
+  }
+}
+
+template <>
+IPEX_FORCE_INLINE void add_ker(
+    at::Half* inout,
+    const at::Half* in,
+    int64_t len) {
+  int64_t i = 0;
+#pragma unroll(2)
+  for (i = 0; i < len - 31; i += 32) {
+    auto inout1 = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(inout + i)));
+    auto inout2 =
+        cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(inout + i + 16)));
+    auto in1 = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(in + i)));
+    auto in2 = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(in + i + 16)));
+    inout1 = _mm512_add_ps(inout1, in1);
+    inout2 = _mm512_add_ps(inout2, in2);
+    _mm256_storeu_si256((__m256i*)(inout + i), cvt_fp32_to_fp16(inout1));
+    _mm256_storeu_si256((__m256i*)(inout + i + 16), cvt_fp32_to_fp16(inout2));
+  }
+
+  if (i < len - 15) {
+    auto inout1 = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(inout + i)));
+    auto in1 = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(in + i)));
+    inout1 = _mm512_add_ps(inout1, in1);
+    _mm256_storeu_si256((__m256i*)(inout + i), cvt_fp32_to_fp16(inout1));
+    i += 16;
+  }
+
+  if (i < len) {
+    auto mask = (1 << (len - i)) - 1;
+    auto inout1 = cvt_fp16_to_fp32(_mm256_maskz_loadu_epi16(mask, inout + i));
+    auto in1 = cvt_fp16_to_fp32(_mm256_maskz_loadu_epi16(mask, in + i));
+    inout1 = _mm512_add_ps(inout1, in1);
+    _mm256_mask_storeu_epi16(inout + i, mask, cvt_fp32_to_fp16(inout1));
+  }
+}
+
 template <>
 IPEX_FORCE_INLINE void add_ker(float* inout, const at::Half* in, int64_t len) {
   int64_t i = 0;
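These at::Half specializations follow the file's existing pattern: move_ker and zero_ker shuttle raw 16-bit lanes with full 512-bit loads/stores plus a masked tail, while add_ker widens each 16-lane half vector to fp32 via cvt_fp16_to_fp32, adds in single precision, and rounds back, so the sum never accumulates in fp16. A scalar reference for that add, assuming only at::Half's implicit float conversions (my sketch, not part of the commit):

#include <ATen/ATen.h>
#include <cstdint>

// Scalar semantics of the vectorized add_ker above: widen both fp16
// operands to fp32, add, then round the sum back to fp16.
void add_ker_scalar(at::Half* inout, const at::Half* in, int64_t len) {
  for (int64_t i = 0; i < len; ++i) {
    inout[i] =
        at::Half(static_cast<float>(inout[i]) + static_cast<float>(in[i]));
  }
}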
@@ -85,6 +154,39 @@ IPEX_FORCE_INLINE void add_ker(float* inout, const at::Half* in, int64_t len) {
   }
 }
 
+template <>
+IPEX_FORCE_INLINE void add_ker(at::Half* inout, const float* in, int64_t len) {
+  int64_t i = 0;
+#pragma unroll(2)
+  for (i = 0; i < len - 31; i += 32) {
+    auto in1 = _mm512_loadu_ps(in + i);
+    auto in2 = _mm512_loadu_ps(in + i + 16);
+    auto inout1 = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(inout + i)));
+    auto inout2 =
+        cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(inout + i + 16)));
+    inout1 = _mm512_add_ps(inout1, in1);
+    inout2 = _mm512_add_ps(inout2, in2);
+    _mm256_storeu_si256((__m256i*)(inout + i), cvt_fp32_to_fp16(inout1));
+    _mm256_storeu_si256((__m256i*)(inout + i + 16), cvt_fp32_to_fp16(inout2));
+  }
+
+  if (i < len - 15) {
+    auto in1 = _mm512_loadu_ps(in + i);
+    auto inout1 = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(inout + i)));
+    inout1 = _mm512_add_ps(inout1, in1);
+    _mm256_storeu_si256((__m256i*)(inout + i), cvt_fp32_to_fp16(inout1));
+    i += 16;
+  }
+
+  if (i < len) {
+    auto mask = (1 << (len - i)) - 1;
+    auto in1 = _mm512_maskz_loadu_ps(mask, in + i);
+    auto inout1 = cvt_fp16_to_fp32(_mm256_maskz_loadu_epi16(mask, inout + i));
+    inout1 = _mm512_add_ps(inout1, in1);
+    _mm256_mask_storeu_epi16(inout + i, mask, cvt_fp32_to_fp16(inout1));
+  }
+}
+
 } // namespace kernel
 } // namespace cpu
 } // namespace torch_ipex
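Both hunks lean on the same tail trick for lengths that are not a multiple of the vector width: with r = len - i elements left, (1 << r) - 1 sets exactly the low r bits, so the masked epi16/ps loads and stores touch exactly r lanes and never read or write past the buffer. A tiny standalone illustration (my sketch):

#include <cstdint>
#include <cstdio>

int main() {
  // For r remaining elements, the lane mask has exactly the low r bits set.
  for (int64_t r = 1; r <= 4; ++r) {
    unsigned mask = (1u << r) - 1;
    std::printf("remaining=%lld -> mask=0x%x\n", (long long)r, mask);
  }
  return 0;
}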
