Skip to content

Commit 3727406

Browse files
authored
refactor rope part and optimize mask memory for Phi3 (#2828)
1 parent b7bc2ad commit 3727406

File tree

14 files changed

+516
-160
lines changed

14 files changed

+516
-160
lines changed

csrc/cpu/aten/MaskedMultiHeadAttention.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ namespace torch_ipex {
66
namespace cpu {
77

88
IPEX_DEFINE_DISPATCH(masked_multihead_self_attention_kernel_stub);
9+
IPEX_DEFINE_DISPATCH(prepare_4d_causal_attention_mask_kernel_stub);
910

1011
/*
1112
*Calculate the masked multihead attention for decoder layer in decoder only
@@ -54,6 +55,21 @@ masked_multihead_self_attention_forward_cpu(
5455
add_casual_mask);
5556
}
5657

58+
/*
 * CPU entry point that builds the 4D causal attention mask by forwarding
 * all arguments to the kernel registered via the dispatch stub.
 *
 * @param attention_mask the caller-provided attention mask tensor.
 * @param inputs_embeds  the input embedding tensor.
 * @param past_kv_len    tensor holding the cached key/value length.
 * @param finfo_min      tensor holding the fill value for masked positions
 *                       (presumably dtype's lowest finite value — confirm
 *                       against the kernel implementation).
 * @param sliding_window sliding-window size (added for Phi3-style models).
 * @return the expanded 4D causal attention mask produced by the kernel.
 */
at::Tensor prepare_4d_causal_attention_mask_forward_cpu(
    at::Tensor& attention_mask,
    at::Tensor& inputs_embeds,
    at::Tensor& past_kv_len,
    at::Tensor& finfo_min,
    int64_t sliding_window) {
  // All real work happens in the architecture-specific kernel bound to kCPU;
  // this wrapper only selects the device key and relays the arguments.
  auto result = prepare_4d_causal_attention_mask_kernel_stub(
      kCPU,
      attention_mask,
      inputs_embeds,
      past_kv_len,
      finfo_min,
      sliding_window);
  return result;
}
72+
5773
} // namespace cpu
5874
} // namespace torch_ipex
5975

@@ -69,4 +85,13 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
6985
c10::DispatchKey::CPU,
7086
torch_ipex::cpu::masked_multihead_self_attention_forward_cpu);
7187
}
88+
89+
// Register the custom op's schema with the torch_ipex library and bind its
// CPU implementation to the CPU dispatch key.
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
  // Declare the operator signature; callers resolve the op by this schema.
  m.def(
      "prepare_4d_causal_attention_mask(Tensor attention_mask, Tensor inputs_embeds, Tensor past_kv_len, Tensor finfo_min, int sliding_window)-> (Tensor)");
  // Route CPU-key invocations to the wrapper defined above in this file.
  m.impl(
      "prepare_4d_causal_attention_mask",
      c10::DispatchKey::CPU,
      torch_ipex::cpu::prepare_4d_causal_attention_mask_forward_cpu);
}
7297
} // namespace

csrc/cpu/aten/MaskedMultiHeadAttention.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,14 @@ masked_multihead_self_attention(
2222
const c10::optional<at::Tensor>& head_mask /* optional */,
2323
const c10::optional<at::Tensor>& attention_mask /* optional */,
2424
c10::optional<bool> add_casual_mask /* optional */);
25-
}
25+
26+
// Builds the 4D causal attention mask on CPU by dispatching to the
// registered kernel (see MaskedMultiHeadAttention.cpp for the definition).
at::Tensor prepare_4d_causal_attention_mask_forward_cpu(
    at::Tensor& attention_mask,
    at::Tensor& inputs_embeds,
    at::Tensor& past_kv_len,
    at::Tensor& finfo_min,
    int64_t sliding_window);
} // namespace
2633

2734
using masked_multihead_self_attention_kernel_fn =
2835
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> (*)(
@@ -42,6 +49,16 @@ using masked_multihead_self_attention_kernel_fn =
4249
IPEX_DECLARE_DISPATCH(
4350
masked_multihead_self_attention_kernel_fn,
4451
masked_multihead_self_attention_kernel_stub);
52+
using prepare_4d_causal_attention_mask_kernel_fn = at::Tensor (*)(
53+
at::Tensor& attention_mask,
54+
at::Tensor& inputs_embeds,
55+
at::Tensor& past_kv_len,
56+
at::Tensor& finfo_min,
57+
int64_t sliding_window);
58+
59+
IPEX_DECLARE_DISPATCH(
60+
prepare_4d_causal_attention_mask_kernel_fn,
61+
prepare_4d_causal_attention_mask_kernel_stub);
4562

4663
} // namespace cpu
4764
} // namespace torch_ipex

0 commit comments

Comments
 (0)