import triton_kernels_benchmark as benchmark_suit


+def gen_args(BATCH, N_CTX, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, dtype, device):
+
+    b_seq_len_prefix = torch.randint(1, N_CTX // 2, (BATCH, ), dtype=torch.int32, device=device)
+    b_seq_len_extend = torch.randint(1, N_CTX // 2, (BATCH, ), dtype=torch.int32, device=device)
+    b_seq_len = b_seq_len_prefix + b_seq_len_extend
+    max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
+
+    b_req_idx = torch.arange(BATCH, dtype=torch.int32, device=device)
+    b_start_loc = torch.zeros((BATCH, ), dtype=torch.int32, device=device)
+    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
+    b_start_loc_extend = torch.zeros((BATCH, ), dtype=torch.int32, device=device)
+    b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
+
+    kv_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device=device)
+    kv_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len_prefix[:BATCH], dim=0)
+    kv_indices = torch.zeros((b_seq_len_prefix.sum().item(), ), dtype=torch.int32, device=device)
+
+    for i in range(BATCH):
+        kv_indices[kv_indptr[i]:kv_indptr[i + 1]] = torch.arange(b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i])
+
+    total_token_num = torch.sum(b_seq_len).item()
+    extend_token_num = torch.sum(b_seq_len_extend).item()
+    k_buffer = torch.empty((total_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype,
+                           device=device).normal_(mean=0.1, std=0.2)
+    v_buffer = torch.empty((total_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype,
+                           device=device).normal_(mean=0.1, std=0.2)
+
+    k_extend = torch.empty((extend_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device)
+    v_extend = torch.empty((extend_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device)
+    q_extend = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device)
+    for i in range(BATCH):
+        extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
+        extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
+        extend_start = b_start_loc_extend[i]
+        extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
+        k_extend[extend_start:extend_end] = k_buffer[extend_start_in_buffer:extend_end_in_buffer]
+        v_extend[extend_start:extend_end] = v_buffer[extend_start_in_buffer:extend_end_in_buffer]
+        q_extend[extend_start:extend_end] = torch.empty((b_seq_len_extend[i], Q_HEAD_NUM, HEAD_DIM), dtype=dtype,
+                                                        device=device).normal_(mean=0.1, std=0.2)
+
+    o_extend = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device)
+    o_redundant = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device)
+
+    b_seq_len_extend = b_seq_len - b_seq_len_prefix
+    max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
+    qo_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device=device)
+    qo_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len_extend[:BATCH], dim=0)
+
+    params = []
+    params.append((q_extend, k_extend, v_extend, o_extend, o_redundant))
+    params.append((k_buffer, v_buffer))
+    params.append((qo_indptr, kv_indptr, kv_indices, max_len_extend))
+    params.append((b_req_idx, b_start_loc, b_seq_len, b_seq_len_prefix, max_len_in_batch))
+    return params
+
+
# pylint: disable=unused-argument
@benchmark_suit.perf_report(
    benchmark_suit.Benchmark(
        # argument names to use as an x-axis for the plot
-        x_names=['BATCH', 'SEQ_LENS', 'Q_HEAD_NUM', 'KV_HEAD_NUM', 'HEAD_DIM', 'MODE', 'VALIDATE'],
+        x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE', 'VALIDATE'],
        x_vals=[  #
            [bs, [1024, 128, 512], 32, 8, 128, 'fwd', True] for bs in [1, 16, 32, 64, 128]
        ] + [  #
...
        # name for the plot. Used also as a file name for saving the plot.
        args={},
    ))
-def benchmark(BATCH, SEQ_LENS, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, MODE, VALIDATE, provider):
+def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, VALIDATE, provider):
    torch.manual_seed(0)
+
    dtype = torch.bfloat16
    N_CTX = sum(SEQ_LENS)

-    b_seq_len_prefix = torch.randint(1, N_CTX // 2, (BATCH, ), dtype=torch.int32, device='xpu')
-    b_seq_len_extend = torch.randint(1, N_CTX // 2, (BATCH, ), dtype=torch.int32, device='xpu')
-    b_seq_len = b_seq_len_prefix + b_seq_len_extend
-    max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
-
-    b_req_idx = torch.arange(BATCH, dtype=torch.int32, device='xpu')
-    b_start_loc = torch.zeros((BATCH, ), dtype=torch.int32, device='xpu')
-    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
-    b_start_loc_extend = torch.zeros((BATCH, ), dtype=torch.int32, device='xpu')
-    b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
-
-    kv_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device='xpu')
-    kv_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len_prefix[:BATCH], dim=0)
-    kv_indices = torch.zeros((b_seq_len_prefix.sum().item(), ), dtype=torch.int32, device='xpu')
-
-    for i in range(BATCH):
-        kv_indices[kv_indptr[i]:kv_indptr[i + 1]] = torch.arange(b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i])
-
-    total_token_num = torch.sum(b_seq_len).item()
-    extend_token_num = torch.sum(b_seq_len_extend).item()
-    k_buffer = torch.empty((total_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype,
-                           device='xpu').normal_(mean=0.1, std=0.2)
-    v_buffer = torch.empty((total_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype,
-                           device='xpu').normal_(mean=0.1, std=0.2)
-
-    k_extend = torch.empty((extend_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu')
-    v_extend = torch.empty((extend_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu')
-    q_extend = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu')
-    for i in range(BATCH):
-        extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
-        extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
-        extend_start = b_start_loc_extend[i]
-        extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
-        k_extend[extend_start:extend_end] = k_buffer[extend_start_in_buffer:extend_end_in_buffer]
-        v_extend[extend_start:extend_end] = v_buffer[extend_start_in_buffer:extend_end_in_buffer]
-        q_extend[extend_start:extend_end] = torch.empty((b_seq_len_extend[i], Q_HEAD_NUM, HEAD_DIM), dtype=dtype,
-                                                        device='xpu').normal_(mean=0.1, std=0.2)
-
-    o_extend = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu')
-    o_redundant = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu')
-
-    b_seq_len_extend = b_seq_len - b_seq_len_prefix
-    max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
-    qo_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device='xpu')
-    qo_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len_extend[:BATCH], dim=0)
-
+    params = gen_args(B, N_CTX, H_Q, H_KV, D, dtype, 'xpu')
+    q_extend, k_extend, v_extend, o_extend, o_redundant = params[0]
+    k_buffer, v_buffer = params[1]
+    qo_indptr, kv_indptr, kv_indices, max_len_extend = params[2]
+    b_req_idx, b_start_loc, b_seq_len, b_seq_len_prefix, max_len_in_batch = params[3]
    custom_mask = None
    mask_indptr = None

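For intuition, `gen_args` can be exercised on its own. The sketch below uses hypothetical sizes and the CPU device (so no XPU is required), unpacks the four returned tuples the same way the new `benchmark` body does, and checks the shape invariants: the extend tensors hold one row per newly added token, while the KV buffers hold prefix plus extend tokens.

```python
# Standalone sanity check for gen_args -- a sketch with made-up sizes,
# run on CPU so it does not require an XPU device.
import torch

BATCH, N_CTX, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM = 2, 64, 32, 8, 128
params = gen_args(BATCH, N_CTX, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, torch.float32, 'cpu')

q_extend, k_extend, v_extend, o_extend, o_redundant = params[0]
k_buffer, v_buffer = params[1]
qo_indptr, kv_indptr, kv_indices, max_len_extend = params[2]
b_req_idx, b_start_loc, b_seq_len, b_seq_len_prefix, max_len_in_batch = params[3]

# qo_indptr/kv_indptr are CSR-style offsets: entry i+1 minus entry i is the
# extend/prefix length of request i, so the last entry is the token total.
assert q_extend.shape[0] == qo_indptr[-1].item()
assert kv_indices.numel() == kv_indptr[-1].item()
assert k_buffer.shape[0] == kv_indices.numel() + k_extend.shape[0]
```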
@@ -97,7 +113,6 @@ def triton_fn():
                kv_indices, custom_mask, mask_indptr, max_len_extend)
            return o_extend

-    # TODO: decode attention do not have validation function
    if VALIDATE:

        def refer_fn():
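When `VALIDATE` is set, the Triton output is checked against the reference path's `o_redundant`. The comparison helper itself sits outside this hunk; a generic check of that shape might look like the sketch below (the `check_outputs` name and the tolerances are assumptions, not the suite's own helper):

```python
# Generic closeness check for two attention outputs -- a sketch; the
# benchmark suite's own assert helper (not shown in this diff) may differ.
import torch

def check_outputs(o_triton: torch.Tensor, o_ref: torch.Tensor) -> None:
    # Compare in float32 to avoid bfloat16 rounding inside the check itself.
    torch.testing.assert_close(o_triton.float(), o_ref.float(), atol=1e-2, rtol=1e-2)
```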
@@ -112,9 +127,8 @@ def refer_fn():
    else:
        raise NotImplementedError(f'Unsupported provider {provider}')

-    tflops = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM * N_CTX) * N_CTX * HEAD_DIM * (1e-12) / (ms * 1e-3)
-
-    gbps = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM * N_CTX) * HEAD_DIM * 2 * (1e-9) / (ms * 1e-3)
+    tflops = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * N_CTX * D * (1e-12) / (ms * 1e-3)
+    gbps = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * D * 2 * (1e-9) / (ms * 1e-3)

    return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
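The two estimators convert a measured latency in milliseconds into TFLOPS and GB/s. Plugging in the first `x_vals` row gives a feel for the magnitudes; the 1.0 ms latency below is hypothetical, chosen only to show the unit conversions:

```python
# Evaluate the renamed estimators for B=1, SEQ_LENS=[1024, 128, 512],
# H_Q=32, H_KV=8, D=128; 1.0 ms is a made-up latency.
B, H_Q, H_KV, D = 1, 32, 8, 128
N_CTX = sum([1024, 128, 512])  # 1664

tflops = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * N_CTX * D * (1e-12) / (ms * 1e-3)
gbps = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * D * 2 * (1e-9) / (ms * 1e-3)

print(f'{tflops(1.0):.2f} TFLOPS, {gbps(1.0):.2f} GB/s')  # ~5.68 TFLOPS, ~6.83 GB/s
```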