
Commit a07c604

add alibi param on varlen_attention to align API (#3750)
* add alibi param on varlen_attention to align API
* add ut
* update ut

Co-authored-by: jianan-gu <[email protected]>
1 parent 2978fd6 commit a07c604

File tree (4 files changed, +103 −0 lines changed):

  intel_extension_for_pytorch/llm/functional/fusions.py
  intel_extension_for_pytorch/llm/modules/mha_fusion.py
  intel_extension_for_pytorch/transformers/models/cpu/fusions/mha_fusion.py
  tests/cpu/test_ipex_llm_module.py


intel_extension_for_pytorch/llm/functional/fusions.py

Lines changed: 2 additions & 0 deletions

@@ -192,6 +192,7 @@ def varlen_attention(
     out: torch.Tensor,
     seqlen_q: torch.Tensor,
     seqlen_k: torch.Tensor,
+    alibi_slopes: torch.Tensor,
     max_seqlen_q: int,
     max_seqlen_k: int,
     pdropout: float,
@@ -240,6 +241,7 @@ def varlen_attention(
         out,
         seqlen_q,
         seqlen_k,
+        alibi_slopes,
         max_seqlen_q,
         max_seqlen_k,
         pdropout,
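
With this change the public ipex.llm.functional.varlen_attention takes alibi_slopes right after seqlen_k and before max_seqlen_q. Below is a minimal call sketch in the new positional order, mirroring the unit test added in this commit; the shapes and values are illustrative assumptions, and the CPU backend in this commit only accepts alibi_slopes=None.

import math
import torch
import intel_extension_for_pytorch as ipex

# Illustrative packed-sequence setup: two sequences of length 4 (8 tokens total),
# 8 heads, head_size 64; cu_seqlen holds the cumulative sequence boundaries.
dtype = torch.float16
query = torch.randn([8, 8, 64], dtype=dtype)
key = torch.randn([8, 8, 64], dtype=dtype)
value = torch.randn([8, 8, 64], dtype=dtype)
out = query.clone()
cu_seqlen = torch.tensor([0, 4, 8], dtype=torch.int32)

ipex.llm.functional.varlen_attention(
    query, key, value, out,
    cu_seqlen, cu_seqlen,
    None,                # alibi_slopes -- the parameter added by this commit
    4, 4,                # max_seqlen_q, max_seqlen_k
    0.0,                 # pdropout
    1 / math.sqrt(64),   # softmax_scale
    False,               # same flag position as in the commit's unit test
    True,                # is_causal
    False,               # return_softmax
    None,                # gen_
    window_size_left=-1,
    window_size_right=-1,
)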

intel_extension_for_pytorch/llm/modules/mha_fusion.py

Lines changed: 10 additions & 0 deletions

@@ -362,6 +362,7 @@ def apply_function(
         out: torch.Tensor,
         seqlen_q: torch.Tensor,
         seqlen_k: torch.Tensor,
+        alibi_slopes: torch.Tensor,
         max_seqlen_q: int,
         max_seqlen_k: int,
         pdropout: float,
@@ -383,6 +384,7 @@ def apply_function(
             out,
             seqlen_q,
             seqlen_k,
+            alibi_slopes,
             max_seqlen_q,
             max_seqlen_k,
             pdropout,
@@ -404,6 +406,7 @@ def forward(
         out: torch.Tensor,
         seqlen_q: torch.Tensor,
         seqlen_k: torch.Tensor,
+        alibi_slopes: torch.Tensor,
         max_seqlen_q: int,
         max_seqlen_k: int,
         pdropout: float,
@@ -412,6 +415,9 @@ def forward(
         is_causal: bool,
         return_softmax: bool,
         gen_: torch.Generator,
+        window_size_left: int,
+        window_size_right: int,
+        softcap: float,
     ):
         runtime_module = self.runtime_ops.get_module_from_device(
             query.device.type, IPEXCustomOpType.VARLEN_ATTENTION, True
@@ -423,6 +429,7 @@ def forward(
             out,
             seqlen_q,
             seqlen_k,
+            alibi_slopes,
             max_seqlen_q,
             max_seqlen_k,
             pdropout,
@@ -431,6 +438,9 @@ def forward(
             is_causal,
             return_softmax,
             gen_,
+            window_size_left,
+            window_size_right,
+            softcap,
         )
intel_extension_for_pytorch/transformers/models/cpu/fusions/mha_fusion.py

Lines changed: 4 additions & 0 deletions

@@ -510,6 +510,7 @@ def apply_function(
         out,  # [total_q, num_head, head_size]
         seqlen_q,  # [batch_size + 1]
         seqlen_k,  # [batch_size + 1]
+        alibi_slopes,
         max_seqlen_q,
         max_seqlen_k,
         pdropout=0.0,
@@ -528,6 +529,7 @@ def apply_function(
         assert window_size_left == -1, "ipex do not support window_size_left option"
         assert window_size_right == -1, "ipex do not support window_size_right option"
         assert softcap == -1.0, "ipex do not support softcap option"
+        assert alibi_slopes is None, "ipex do not support alibi_slopes"

         # Repeat kv if it is GQA.
         key = cls.repeat_kv(key, int(query.shape[1] / key.shape[1]))
@@ -600,6 +602,7 @@ def forward(
             out,
             seqlen_q,
             seqlen_k,
+            alibi_slopes,
             max_seqlen_q,
             max_seqlen_k,
             pdropout,
@@ -619,6 +622,7 @@ def forward(
             out,
             seqlen_q,
             seqlen_k,
+            alibi_slopes,
             max_seqlen_q,
             max_seqlen_k,
             pdropout,
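
On the CPU backend the new argument is accepted only to keep the signature aligned with other backends: apply_function asserts that alibi_slopes is None, alongside the existing window_size_left / window_size_right / softcap checks. A small sketch of what this means for callers; the head count is an illustrative assumption, and the slopes tensor is built the same way as the commit's unit test would when USE_ALIBI is enabled.

import torch

# Flash-attn style ALiBi slopes: one value per query head (8 heads assumed here).
alibi_slopes = torch.tensor([2 ** (-1 - i) for i in range(8)], dtype=torch.float)

# With the CPU fusion in this commit:
#   alibi_slopes=None      -> accepted, behaves exactly as before this change
#   alibi_slopes=<tensor>  -> AssertionError: "ipex do not support alibi_slopes"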

tests/cpu/test_ipex_llm_module.py

Lines changed: 87 additions & 0 deletions

@@ -1425,6 +1425,93 @@ def selective_scan_ipex(
         self.assertEqual(out_ref, out_ipex, rtol=rtol, atol=atol)
         self.assertEqual(state_ref, state_ipex, rtol=rtolw, atol=atolw)

+    @skipIfNoEINPOS
+    def test_varlen_fwd(self):
+        HEAD_DIM = [64, 70]
+        NUM_HEADS = [(32, 32), (32, 8)]
+        BATCH_SIZE = [1, 3, 8]
+        DTYPE = [torch.float16]
+        USE_ALIBI = [False]
+        SEQLEN_RANGE = [10]
+        IS_CAUSAL = [False, True]
+        WINDOW_SIZE = [(-1, -1)]
+
+        for (
+            head_dim,
+            num_heads,
+            batch_size,
+            dtype,
+            use_alibi,
+            seqlen_range,
+            is_causal,
+            window_size,
+        ) in itertools.product(
+            HEAD_DIM,
+            NUM_HEADS,
+            BATCH_SIZE,
+            DTYPE,
+            USE_ALIBI,
+            SEQLEN_RANGE,
+            IS_CAUSAL,
+            WINDOW_SIZE,
+        ):
+            torch.manual_seed(15)
+            seqlen_list = torch.randint(
+                1, seqlen_range, [batch_size], dtype=torch.int32
+            )
+            max_seqlen = torch.max(seqlen_list)
+            cu_seqlen = torch.cumsum(seqlen_list, dim=0)
+            num_heads_query, num_heads_kv = num_heads
+            cu_seqlen = (
+                torch.cat([torch.tensor([0]), cu_seqlen], dim=0)
+                .to(torch.int32)
+                .to("cpu")
+            )
+
+            query = torch.randn(
+                [cu_seqlen[-1], num_heads_query, head_dim], dtype=dtype, device="cpu"
+            )
+            key = torch.randn(
+                [cu_seqlen[-1], num_heads_kv, head_dim], dtype=dtype, device="cpu"
+            )
+            value = torch.randn(
+                [cu_seqlen[-1], num_heads_kv, head_dim], dtype=dtype, device="cpu"
+            )
+            alibi_slopes = None
+            softmax_scale = 1 / math.sqrt(head_dim)
+            if use_alibi:
+                alibi_slopes = torch.tensor(
+                    [2 ** (-1 - i) for i in range(num_heads_query)],
+                    dtype=torch.float,
+                    device="cpu",
+                )
+                alibi_slopes = (
+                    alibi_slopes.unsqueeze(0)
+                    .expand(batch_size, num_heads_query)
+                    .contiguous()
+                )
+            out = query.clone()
+
+            ipex.llm.functional.varlen_attention(
+                query,
+                key,
+                value,
+                out,
+                cu_seqlen,
+                cu_seqlen,
+                alibi_slopes,
+                max_seqlen,
+                max_seqlen,
+                0.0,
+                softmax_scale,
+                False,
+                is_causal,
+                False,
+                None,
+                window_size_left=window_size[0],
+                window_size_right=window_size[1],
+            )
+

 if __name__ == "__main__":
     test = unittest.main()
