|
| 1 | +import torch |
| 2 | +from sglang.srt.layers.attention.triton_ops.extend_attention import ( |
| 3 | + extend_attention_fwd, |
| 4 | + redundant_attention, |
| 5 | +) |
| 6 | +import triton_kernels_benchmark as benchmark_suit |
| 7 | + |
| 8 | + |
| 9 | +# pylint: disable=unused-argument |
| 10 | +@benchmark_suit.perf_report( |
| 11 | + benchmark_suit.Benchmark( |
| 12 | + # argument names to use as an x-axis for the plot |
| 13 | + x_names=['BATCH', 'SEQ_LENS', 'Q_HEAD_NUM', 'KV_HEAD_NUM', 'HEAD_DIM', 'MODE', 'VALIDATE'], |
| 14 | + x_vals=[ # |
| 15 | + [bs, [1024, 128, 512], 32, 8, 128, 'fwd', True] for bs in [1, 16, 32, 64, 128] |
| 16 | + ] + [ # |
| 17 | + [bs, [1024, 128, 512], 32, 32, 96, 'fwd', True] for bs in [1, 16, 32, 64, 128] |
| 18 | + ] + [ # |
| 19 | + [bs, [1024, 128, 512], 28, 4, 128, 'fwd', True] for bs in [1, 16, 32, 64, 128] |
| 20 | + ], |
| 21 | + line_arg='provider', |
| 22 | + # argument name whose value corresponds to a different line in the plot |
| 23 | + # possible values for `line_arg`` |
| 24 | + line_vals=[ |
| 25 | + 'triton', |
| 26 | + ], |
| 27 | + # label name for the lines |
| 28 | + line_names=[ |
| 29 | + 'Triton', |
| 30 | + ], |
| 31 | + # line styles |
| 32 | + styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], |
| 33 | + ylabel=['GB/s', 'TFlops'], # label name for the y-axis |
| 34 | + plot_name='extended-attn-performance', |
| 35 | + # name for the plot. Used also as a file name for saving the plot. |
| 36 | + args={}, |
| 37 | + )) |
| 38 | +def benchmark(BATCH, SEQ_LENS, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, MODE, VALIDATE, provider): |
| 39 | + torch.manual_seed(0) |
| 40 | + dtype = torch.bfloat16 |
| 41 | + N_CTX = sum(SEQ_LENS) |
| 42 | + |
| 43 | + b_seq_len_prefix = torch.randint(1, N_CTX // 2, (BATCH, ), dtype=torch.int32, device='xpu') |
| 44 | + b_seq_len_extend = torch.randint(1, N_CTX // 2, (BATCH, ), dtype=torch.int32, device='xpu') |
| 45 | + b_seq_len = b_seq_len_prefix + b_seq_len_extend |
| 46 | + max_len_in_batch = torch.max(b_seq_len, 0)[0].item() |
| 47 | + |
| 48 | + b_req_idx = torch.arange(BATCH, dtype=torch.int32, device='xpu') |
| 49 | + b_start_loc = torch.zeros((BATCH, ), dtype=torch.int32, device='xpu') |
| 50 | + b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) |
| 51 | + b_start_loc_extend = torch.zeros((BATCH, ), dtype=torch.int32, device='xpu') |
| 52 | + b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) |
| 53 | + |
| 54 | + kv_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device='xpu') |
| 55 | + kv_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len_prefix[:BATCH], dim=0) |
| 56 | + kv_indices = torch.zeros((b_seq_len_prefix.sum().item(), ), dtype=torch.int32, device='xpu') |
| 57 | + |
| 58 | + for i in range(BATCH): |
| 59 | + kv_indices[kv_indptr[i]:kv_indptr[i + 1]] = torch.arange(b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]) |
| 60 | + |
| 61 | + total_token_num = torch.sum(b_seq_len).item() |
| 62 | + extend_token_num = torch.sum(b_seq_len_extend).item() |
| 63 | + k_buffer = torch.empty((total_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, |
| 64 | + device='xpu').normal_(mean=0.1, std=0.2) |
| 65 | + v_buffer = torch.empty((total_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, |
| 66 | + device='xpu').normal_(mean=0.1, std=0.2) |
| 67 | + |
| 68 | + k_extend = torch.empty((extend_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu') |
| 69 | + v_extend = torch.empty((extend_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu') |
| 70 | + q_extend = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu') |
| 71 | + for i in range(BATCH): |
| 72 | + extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i] |
| 73 | + extend_end_in_buffer = b_start_loc[i] + b_seq_len[i] |
| 74 | + extend_start = b_start_loc_extend[i] |
| 75 | + extend_end = b_start_loc_extend[i] + b_seq_len_extend[i] |
| 76 | + k_extend[extend_start:extend_end] = k_buffer[extend_start_in_buffer:extend_end_in_buffer] |
| 77 | + v_extend[extend_start:extend_end] = v_buffer[extend_start_in_buffer:extend_end_in_buffer] |
| 78 | + q_extend[extend_start:extend_end] = torch.empty((b_seq_len_extend[i], Q_HEAD_NUM, HEAD_DIM), dtype=dtype, |
| 79 | + device='xpu').normal_(mean=0.1, std=0.2) |
| 80 | + |
| 81 | + o_extend = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu') |
| 82 | + o_redundant = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu') |
| 83 | + |
| 84 | + b_seq_len_extend = b_seq_len - b_seq_len_prefix |
| 85 | + max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() |
| 86 | + qo_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device='xpu') |
| 87 | + qo_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len_extend[:BATCH], dim=0) |
| 88 | + |
| 89 | + custom_mask = None |
| 90 | + mask_indptr = None |
| 91 | + |
| 92 | + quantiles = [0.5, 0.0, 1.0] |
| 93 | + if provider == 'triton': |
| 94 | + |
| 95 | + def triton_fn(): |
| 96 | + extend_attention_fwd(q_extend, k_extend, v_extend, o_extend, k_buffer, v_buffer, qo_indptr, kv_indptr, |
| 97 | + kv_indices, custom_mask, mask_indptr, max_len_extend) |
| 98 | + return o_extend |
| 99 | + |
| 100 | + # TODO: decode attention do not have validation function |
| 101 | + if VALIDATE: |
| 102 | + |
| 103 | + def refer_fn(): |
| 104 | + redundant_attention(q_extend, o_redundant, k_buffer, v_buffer, b_req_idx, b_start_loc, b_seq_len, |
| 105 | + b_seq_len_prefix, max_len_in_batch) |
| 106 | + return o_redundant |
| 107 | + |
| 108 | + benchmark_suit.assert_close(triton_fn, refer_fn, atol=1e-3, rtol=1e-2, err_msg='extend to refer') |
| 109 | + |
| 110 | + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles) |
| 111 | + |
| 112 | + else: |
| 113 | + raise NotImplementedError(f'Unsupported provider {provider}') |
| 114 | + |
| 115 | + tflops = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM * N_CTX) * N_CTX * HEAD_DIM * (1e-12) / (ms * 1e-3) |
| 116 | + |
| 117 | + gbps = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM * N_CTX) * HEAD_DIM * 2 * (1e-9) / (ms * 1e-3) |
| 118 | + |
| 119 | + return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv |
| 120 | + |
| 121 | + |
| 122 | +if __name__ == '__main__': |
| 123 | + benchmark.run(show_plots=False, print_data=True) |
0 commit comments