Commit e55f485

make a WA to support both CPU and XPU
Signed-off-by: Liu, Kaixuan <[email protected]>
1 parent 8aef4c5 commit e55f485

File tree

1 file changed: +36 −17 lines

backends/python/server/text_embeddings_server/utils/flash_attn.py

Lines changed: 36 additions & 17 deletions
@@ -93,23 +93,42 @@ def attention(
     if use_ipex:
         import intel_extension_for_pytorch as ipex
 
-        return ipex.llm.functional.varlen_attention(
-            q.contiguous() if q.device.type == "xpu" else q,
-            k.contiguous() if k.device.type == "xpu" else k,
-            v.contiguous() if v.device.type == "xpu" else v,
-            out,
-            cu_seqlens,
-            cu_seqlens,
-            None,
-            max_s,
-            max_s,
-            0,
-            softmax_scale,
-            zero_tensors=False,
-            is_causal=False,
-            return_softmax=False,
-            gen_=None,
-        )
+        if q.device.type == "xpu":
+            return ipex.llm.functional.varlen_attention(
+                q.contiguous(),
+                k.contiguous(),
+                v.contiguous(),
+                out,
+                cu_seqlens,
+                cu_seqlens,
+                None,
+                max_s,
+                max_s,
+                0,
+                softmax_scale,
+                zero_tensors=False,
+                is_causal=False,
+                return_softmax=False,
+                gen_=None,
+            )
+        elif q.device.type == "cpu":
+            return ipex.llm.functional.varlen_attention(
+                q,
+                k,
+                v,
+                out,
+                cu_seqlens,
+                cu_seqlens,
+                max_s,
+                max_s,
+                0,
+                softmax_scale,
+                zero_tensors=False,
+                is_causal=False,
+                return_softmax=False,
+                gen_=None,
+            )
+
     elif is_hpu:
         return hpu_attn(
             q,