diff --git a/setup.py b/setup.py index 7e60acbfa8..b05ff696b3 100644 --- a/setup.py +++ b/setup.py @@ -55,6 +55,10 @@ def read_version(file_path="version.txt"): and platform.system() == "Darwin" ) +use_cpp_avx512 = os.getenv("USE_AVX512", "1") == "1" and platform.system() == "Linux" + +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 + version_prefix = read_version() # Version is version.dev year month date if using nightlies and version if not version = ( @@ -284,6 +288,17 @@ def get_extensions(): ["-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always"] ) + if use_cpp_avx512 and TORCH_VERSION_AT_LEAST_2_7: + if torch._C._cpu._is_avx512_supported(): + extra_compile_args["cxx"].extend( + [ + "-DCPU_CAPABILITY_AVX512", + "-march=native", + "-mfma", + "-fopenmp", + ] + ) + if debug_mode: extra_compile_args["cxx"].append("-g") if "nvcc" in extra_compile_args: @@ -305,6 +320,12 @@ def get_extensions(): # Collect C++ source files sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True)) + if IS_WINDOWS: + # Remove csrc/cpu/*.cpp on Windows due to the link issue: unresolved external symbol PyInit__C + excluded_sources = list( + glob.glob(os.path.join(extensions_dir, "cpu/*.cpp"), recursive=True) + ) + sources = [s for s in sources if s not in excluded_sources] extensions_cuda_dir = os.path.join(extensions_dir, "cuda") cuda_sources = list( diff --git a/test/prototype/inductor/test_int8_sdpa_fusion.py b/test/prototype/inductor/test_int8_sdpa_fusion.py new file mode 100644 index 0000000000..9596e71a7a --- /dev/null +++ b/test/prototype/inductor/test_int8_sdpa_fusion.py @@ -0,0 +1,217 @@ +import itertools + +import pytest +import torch +import torch.utils.checkpoint +from torch._dynamo.utils import counters +from torch._inductor import config +from torch._inductor.test_case import TestCase, run_tests +from torch._inductor.utils import run_and_get_code +from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm +from torch.testing._internal.inductor_utils import HAS_CPU +from torch.utils.cpp_extension import IS_WINDOWS + +import torchao +from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import _int8_sdpa_init +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 + + +class SelfAttnLikeModule(torch.nn.Module): + def __init__( + self, + input_dim, + has_mask, + num_attention_heads=None, + attention_head_size=None, + ) -> None: + super().__init__() + self.input_dim = input_dim + self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False) + self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False) + self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False) + self.softmax = torch.nn.Softmax(dim=-1) + assert num_attention_heads is not None + assert attention_head_size is not None + self.num_attention_heads = num_attention_heads + self.attention_head_size = attention_head_size + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size) + self.dropout = torch.nn.Dropout(0) + self.has_mask = has_mask + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(new_x_shape) + return x.permute([0, 2, 1, 3]) + + def forward(self, x, mask): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + q = self.transpose_for_scores(q) + k = self.transpose_for_scores(k) + v = self.transpose_for_scores(v) + scores = torch.matmul(q, 
k.transpose(-1, -2)) / (self.input_dim**0.5) + if self.has_mask and mask.dtype != scores.dtype: + scores = scores + mask + attention = self.softmax(scores) + attention = self.dropout(attention) + context_layer = torch.matmul(attention, v) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + context_layer = context_layer.view( + context_layer.size()[:-2] + (self.all_head_size,) + ) + return self.dense(context_layer) + + +class TestSDPAPatternRewriterTemplate(TestCase): + def _clone_inputs(self, inputs): + def clone(x): + if not isinstance(x, torch.Tensor): + return x + return x.clone() + + return [clone(x) for x in inputs] + + def _check_common( + self, + dot_prod_attention, + args1=None, + contains=True, + atol=1e-5, + has_fuse_pattern=True, + has_dropout=False, + check_train=True, + override_check_equal=False, + dtype=torch.float, + rtol=1.3e-6, + ): + if args1 is None: + tensor_shape = (4, 2, 16, 32) + args1 = [ + torch.randn(tensor_shape, device=self.device, dtype=dtype), + torch.randn(tensor_shape, device=self.device, dtype=dtype), + torch.randn(tensor_shape, device=self.device, dtype=dtype), + ] + else: + args1 = list(args1) + args2 = self._clone_inputs(args1) + + for training in [False, True] if check_train else [False]: + for x in itertools.chain(args1[:], args2[:]): + if isinstance(x, torch.Tensor) and x.is_floating_point(): + x.requires_grad = training + + dropout_arg = [training] if has_dropout else [] + torch.manual_seed(1234) + result1 = dot_prod_attention(*(args1 + dropout_arg)) + + counters.clear() + torch.manual_seed(1234) + compiled_model = torch.compile(dot_prod_attention, fullgraph=True) + result2, source_code = run_and_get_code( + compiled_model, + *(args2 + dropout_arg), + ) + source_code = "\n".join(source_code) + if has_fuse_pattern: + self.assertGreaterEqual(counters["inductor"]["int8_fuse_attention"], 1) + if contains: + # many of the patterns get re-expanded in dispatcher + self.assertIn( + "torchao.scaled_dot_product_int8", + source_code, + ) + + # some tests configured with very low dropout where we still want to check equality + if not has_dropout or override_check_equal: + self.assertEqual(result1, result2, atol=atol, rtol=1.3e-6) + + if training: + result1.sum().backward() + result2.sum().backward() + for arg1, arg2 in zip(args1, args2): + if ( + isinstance(arg1, torch.Tensor) + and arg1.is_floating_point() + and (not has_dropout or override_check_equal) + ): + self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol) + + @skipIfRocm + @pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later" + ) + @pytest.mark.skipif(IS_WINDOWS, reason="int8 sdpa does not support windows yet") + @config.patch({"freezing": True}) + def _test_sdpa_int8_rewriter(self): + from torch.export import export_for_training + + import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq + from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import ( + X86InductorQuantizer, + ) + + # pattern is different for bs=1 + torch.manual_seed(1234) + for dtype, has_mask, bs in itertools.product( + [torch.float32, torch.bfloat16], [True, False], [56, 1] + ): + seqlen, numhead, headsize = 197, 16, 64 + mod = SelfAttnLikeModule( + input_dim=headsize * numhead, + has_mask=has_mask, + num_attention_heads=numhead, + attention_head_size=headsize, + ).eval() + inputs = ( + torch.randn( + (bs, seqlen, headsize * numhead), device=self.device, 
dtype=dtype + ), + torch.randn((bs, 1, 1, seqlen), device=self.device) + if has_mask + else None, + ) + enable_autocast = dtype == torch.bfloat16 + with ( + torch.no_grad(), + torch.amp.autocast( + self.device, enabled=enable_autocast, dtype=torch.bfloat16 + ), + ): + _int8_sdpa_init() + quantizer = X86InductorQuantizer() + quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) + quantizer.set_function_type_qconfig( + torch.matmul, quantizer.get_global_quantization_config() + ) + export_model = export_for_training( + mod, + inputs, + strict=True, + ).module() + prepare_model = prepare_pt2e(export_model, quantizer) + prepare_model(*inputs) + convert_model = convert_pt2e(prepare_model) + torchao.quantization.pt2e.move_exported_model_to_eval(convert_model) + self._check_common( + convert_model, args1=inputs, check_train=False, atol=1.0 + ) + + +if HAS_CPU: + + class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate): + device = "cpu" + test_sdpa_int8_rewriter_cpu = ( + TestSDPAPatternRewriterTemplate._test_sdpa_int8_rewriter + ) + + +if __name__ == "__main__": + if IS_LINUX: + run_tests() diff --git a/test/test_ops.py b/test/test_ops.py index 1cdce2cd81..646d1c76af 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. import itertools +import math import sys import pytest @@ -14,6 +15,7 @@ parametrize, ) from torch.testing._internal.optests import opcheck +from torch.utils.cpp_extension import IS_WINDOWS import torchao from torchao.dtypes.floatx import from_scaled_tc_floatx @@ -23,7 +25,11 @@ ) from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_qqq from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24 -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_7, + compute_max_diff, +) if torch.version.hip is not None: pytest.skip("Skipping the test in ROCm", allow_module_level=True) @@ -109,6 +115,135 @@ def test_quant_llm_linear_correctness( rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3 assert relative_error < rtol + def _scaled_dot_product_int8_op_ref( + self, + q, + k, + v, + attn_mask=None, + dropout_p=0, + is_causal=False, + q_scale=1.0, + q_zp=0, + k_scale=1.0, + k_zp=0, + v_scale=1.0, + v_zp=0, + a_scale=1.0, + a_zp=0, + o_scale=1.0, + o_zp=0, + ): + q = (q.to(torch.float) - q_zp) * q_scale + k = (k.to(torch.float) - k_zp) * k_scale + v = (v.to(torch.float) - v_zp) * v_scale + scale_factor = 1 / math.sqrt(q.size(-1)) + attn = q @ k.transpose(-2, -1) + attn = attn * scale_factor + if attn_mask is not None: + attn = attn + attn_mask.to(torch.float) + attn_max = attn.max(dim=-1, keepdim=True).values + attn = attn - attn_max + attn = torch.exp(attn) + attn_sum = torch.sum(attn, dim=-1, keepdim=True) + attn = attn / attn_sum + attn = torch.clamp(torch.round(attn / a_scale) + a_zp, min=0, max=255) + attn = (attn - a_zp) * a_scale + out = attn @ v + out = torch.clamp(torch.round(out / o_scale) + o_zp, min=0, max=255) + return out.to(torch.uint8) + + @pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later" + ) + @pytest.mark.skipif(IS_WINDOWS, reason="int8 sdpa does not support windows yet") + @parametrize("batch_size", [56, 120]) + @parametrize("n_head", [2, 16]) + @parametrize("q_seq_len", [18, 89]) + 
@parametrize("kv_seq_len", [100, 253]) + @parametrize("head_dim", [32, 64]) + @parametrize("mask_dtype", [None, torch.float32, torch.bfloat16]) + def test_scaled_dot_product_int8_op( + self, batch_size, n_head, q_seq_len, kv_seq_len, head_dim, mask_dtype + ): + torch.manual_seed(1234) + device = "cpu" + q_scale = float(1.7907238006591797) + q_zp = int(127) + k_scale = float(1.8039721250534058) + k_zp = int(125) + v_scale = float(1.839004635810852) + v_zp = int(127) + a_scale = float(0.003919653594493866) + a_zp = int(120) + o_scale = float(1.8191684484481812) + o_zp = int(128) + q_shape = [batch_size, q_seq_len, n_head, head_dim] + kv_shape = [batch_size, kv_seq_len, n_head, head_dim] + mask_shape = [batch_size, 1, 1, kv_seq_len] + q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2) * 100 + k = ( + torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) + * 100 + ) + v = ( + torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) + * 100 + ) + q = q.to(torch.uint8) + k = k.to(torch.uint8) + v = v.to(torch.uint8) + attn_mask = ( + torch.randn(mask_shape, dtype=mask_dtype, device=device) + if mask_dtype is not None + else None + ) + q2, k2, v2, attn_mask_2 = ( + q.clone(), + k.clone(), + v.clone(), + attn_mask.clone() if mask_dtype is not None else None, + ) + + math_ref = self._scaled_dot_product_int8_op_ref( + q2, + k2, + v2, + attn_mask=attn_mask, + dropout_p=0.0, + is_causal=False, + q_scale=q_scale, + q_zp=q_zp, + k_scale=k_scale, + k_zp=k_zp, + v_scale=v_scale, + v_zp=v_zp, + a_scale=a_scale, + a_zp=a_zp, + o_scale=o_scale, + o_zp=o_zp, + ) + actual = torch.ops.torchao.scaled_dot_product_int8( + q, + k, + v, + attn_mask=attn_mask_2, + dropout_p=0.0, + is_causal=False, + q_scale=q_scale, + q_zp=q_zp, + k_scale=k_scale, + k_zp=k_zp, + v_scale=v_scale, + v_zp=v_zp, + a_scale=a_scale, + a_zp=a_zp, + o_scale=o_scale, + o_zp=o_zp, + ) + + self.assertEqual(actual, math_ref, atol=1.0, rtol=5e-6) + instantiate_parametrized_tests(TestOps) diff --git a/torchao/csrc/cpu/int8_sdpa.cpp b/torchao/csrc/cpu/int8_sdpa.cpp new file mode 100644 index 0000000000..36cd24ab5e --- /dev/null +++ b/torchao/csrc/cpu/int8_sdpa.cpp @@ -0,0 +1,1907 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#include +#include +#include +#include + +namespace torchao { + +namespace { + +inline double calculate_scale( + const at::Tensor& query, + double scale) { + return scale == 0.0 ? 
1.0 / std::sqrt(query.size(-1)) : scale; +} + +#ifdef CPU_CAPABILITY_AVX512 + +template +inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) { + const int32_t vec_size = at::vec::Vectorized::size(); + auto data_vec = at::vec::Vectorized(val); + int64_t d = 0; + for (; d < size - (size % vec_size); d += vec_size) { + data_vec.store(data + d); + } + if (d < size) { + data_vec.store(data + d, size - d); + } +} + +void reshape_attn_mask_to_4d( + at::Tensor& attn_mask, + int64_t batchSize, + int64_t num_head, + int64_t qSize, + int64_t kvSize) { + // Support mask shapes: + // 2d: ({Q_seq_len, 1} x {KV_seq_len, 1}) + // 4d: ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1}) + // Guaranteed in check_attn_mask_shape + int64_t attn_mask_size_0 = 1; + int64_t attn_mask_size_1 = 1; + if (attn_mask.dim() == 4) { + if (attn_mask.size(0) == batchSize) { + attn_mask_size_0 = batchSize; + } + if (attn_mask.size(1) == num_head) { + attn_mask_size_1 = num_head; + } + } + attn_mask = attn_mask + .view({attn_mask_size_0, attn_mask_size_1, attn_mask.size(-2), attn_mask.size(-1)}) + .expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize}); +} + +// TODO: Use at::native::_store instead when it supports Half. +template +inline void _store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { + src.store(dst, size); +} + +template +inline typename std::enable_if_t || std::is_same_v, void> +_store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { + auto res = at::vec::convert(src); + res.store(dst, size); +} + +/* +1. dequant +2. add mask +3. max reduce for softmax +*/ +template +inline void _dequant_mask_max_fusion_kernel( + const int32_t* in, + const mask_t* mask_ptr, + const int32_t* sum_a_ptr, + const int32_t* sum_b_ptr, + const int& M, + const int& N, + const int& ldi, + const int& ldm, // leading dimension mask + const int& ldo, + const int32_t& beta, // zp_a*zp_b*k + const float& alpha, // scale_a*scale_b*scale_sdpa + float* out, + float* sfm_max_ptr) { + const int32_t vec_size = at::vec::Vectorized::size(); + auto vec_beta = at::vec::Vectorized(beta); + auto vec_alpha = at::vec::Vectorized(alpha); + for (long row = 0; row < M; row += 1) { + auto sum_a = sum_a_ptr[row]; + auto vec_sum_a = at::vec::Vectorized(sum_a); + const int32_t* tmp_in = in + row * ldi; + float* tmp_out = out + row * ldo; + const mask_t* mask_data_ptr = mask_ptr + row * ldm; + float tmp_max = -std::numeric_limits::infinity(); + auto vec_tmp_max = at::vec::Vectorized(tmp_max); + long col = 0; + for (; col < vec_size * (N / vec_size); col += vec_size) { + auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col); + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 - vec_sum_b; + auto tmp2 = tmp1 - vec_sum_a; + auto tmp3 = tmp2 + vec_beta; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + auto tmp6 = at::vec::Vectorized::loadu(mask_data_ptr + col); + auto tmp7 = at::vec::convert(tmp6); + auto tmp8 = tmp5 + tmp7; + vec_tmp_max = at::vec::clamp_min(vec_tmp_max, tmp8); + _store(tmp_out + col, tmp8); + } + if (col < N) { + auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col, N - col); + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, N - col); + auto tmp1 = tmp0 - vec_sum_b; + auto tmp2 = tmp1 - vec_sum_a; + auto tmp3 = tmp2 + vec_beta; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + auto tmp6 = at::vec::Vectorized::loadu(mask_data_ptr + col, N - col); + auto tmp7 = 
at::vec::convert(tmp6); + auto tmp8 = tmp5 + tmp7; + _store(tmp_out + col, tmp8, N - col); + vec_tmp_max = at::vec::Vectorized::set(vec_tmp_max, at::vec::clamp_min(vec_tmp_max, tmp8), N - col); + } + sfm_max_ptr[row] = std::max(sfm_max_ptr[row], vec_tmp_max.reduce_max()); + } +} + +/* +1. dequant +2. max reduce for softmax +*/ +inline void _dequant_max_fusion_kernel( + const int32_t* in, + const int32_t* sum_a_ptr, + const int32_t* sum_b_ptr, + const int& M, + const int& N, + const int& ldi, + const int& ldo, + const int32_t& beta, // zp_a*zp_b*k + const float& alpha, // scale_a*scale_b*scale_sdpa + float* out, + float* sfm_max_ptr) { + const int32_t vec_size = at::vec::Vectorized::size(); + auto vec_beta = at::vec::Vectorized(beta); + auto vec_alpha = at::vec::Vectorized(alpha); + for (long row = 0; row < M; row += 1) { + auto sum_a = sum_a_ptr[row]; + auto vec_sum_a = at::vec::Vectorized(sum_a); + const int32_t* tmp_in = in + row * ldi; + float* tmp_out = out + row * ldo; + float tmp_max = -std::numeric_limits::infinity(); + auto vec_tmp_max = at::vec::Vectorized(tmp_max); + long col = 0; + for (; col < vec_size * (N / vec_size); col += vec_size) { + auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col); + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 - vec_sum_b; + auto tmp2 = tmp1 - vec_sum_a; + auto tmp3 = tmp2 + vec_beta; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + vec_tmp_max = at::vec::clamp_min(vec_tmp_max, tmp5); + _store(tmp_out + col, tmp5); + } + if (col < N) { + auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col, N - col); + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, N - col); + auto tmp1 = tmp0 - vec_sum_b; + auto tmp2 = tmp1 - vec_sum_a; + auto tmp3 = tmp2 + vec_beta; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + _store(tmp_out + col, tmp5, N - col); + vec_tmp_max = at::vec::Vectorized::set(vec_tmp_max, at::vec::clamp_min(vec_tmp_max, tmp5), N - col); + } + sfm_max_ptr[row] = std::max(sfm_max_ptr[row], vec_tmp_max.reduce_max()); + } +} + +/* +1. Softmax: sub max, exp, sum reduce, div sum +2. quant +3. 
sum for attention +*/ +template +inline void _sub_exp_sum_div_quant_sum_fusion_kernel( + const float* in, + const int64_t& M, + const int64_t& N_step, + const int64_t& NSlice, + const int& ldi, + const int& ldo, + const int& kvSize, + const int& rndkvSplitSize, + const int& av_gemm_K, + const int32_t& beta1, // zp_a + const int32_t& beta2, // zp_b + const float& alpha, // scale_a + float* local, + scalar_t* out, + float* sfm_max_ptr, + float* sfm_sum_ptr, + int32_t* sum_a_ptr) { + const int32_t vec_size = at::vec::Vectorized::size(); + float min_val = 0; + float max_val = 255; + auto vec_min_val = at::vec::Vectorized(min_val); + auto vec_max_val = at::vec::Vectorized(max_val); + scalar_t zero = 0; + auto vec_zero = at::vec::Vectorized(zero); + float beta1_float = (float) beta1; + auto vec_beta1 = at::vec::Vectorized(beta1_float); + for (int64_t row = 0; row < M; ++row) { + auto sfm_max = sfm_max_ptr[row]; + auto vec_max = at::vec::Vectorized(sfm_max); + // sub max, exp, sum reduce + const float* qk_block_data = in + row * rndkvSplitSize; + for (int64_t l = 0; l < NSlice; l ++) { + int64_t n = l * N_step; + int64_t kvBlockSize = std::min(N_step, kvSize - n); + const float* tmp_in = qk_block_data + l * ldi; + float tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + float* tmp_out = local + n; + long col = 0; + for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum += tmp2; + _store(tmp_out + col, tmp2); + } + if (col < kvBlockSize) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, kvBlockSize - col); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + _store(tmp_out + col, tmp2, kvBlockSize - col); + vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp2, kvBlockSize - col); + } + sfm_sum_ptr[row] += vec_tmp_sum.reduce_add(); + } + // div sum, sum for attention + auto sum_scale = 1 / sfm_sum_ptr[row] / alpha; + auto vec_sum_scale = at::vec::Vectorized(sum_scale); + scalar_t* qk_reduced_block_data = out + row * av_gemm_K; + for (int64_t l = 0; l < NSlice; l ++) { + int64_t n = l * N_step; + int64_t kvBlockSize = std::min(N_step, kvSize - n); + int32_t tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + float* tmp_in = local + n; + scalar_t* tmp_out = qk_reduced_block_data + l * ldo; + long col = 0; + for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 * vec_sum_scale; + auto tmp2 = tmp1.round(); + auto tmp3 = tmp2 + vec_beta1; + auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); + _store(tmp_out + col, tmp4); + auto tmp6 = at::vec::convert(tmp4); + vec_tmp_sum += tmp6; + } + if (col < kvBlockSize) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, kvBlockSize - col); + auto tmp1 = tmp0 * vec_sum_scale; + auto tmp2 = tmp1.round(); + auto tmp3 = tmp2 + vec_beta1; + auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); + _store(tmp_out + col, tmp4, kvBlockSize - col); + auto tmp6 = at::vec::convert(tmp4); + vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp6, kvBlockSize - col); + } + sum_a_ptr[row] += vec_tmp_sum.reduce_add() * beta2; + // set zero + col = kvBlockSize; + for (; col < vec_size * (av_gemm_K / vec_size); col += vec_size) { + _store(tmp_out + col, vec_zero); + } + if (col < av_gemm_K) { + _store(tmp_out + col, vec_zero, av_gemm_K - col); + } + } + } 
+} + +/* +1. Softmax: sub max, exp, sum reduce, div sum +2. quant +*/ +template +inline void _sub_exp_sum_div_quant_fusion_kernel( + const float* in, + const int64_t& M, + const int64_t& N_step, + const int64_t& NSlice, + const int& ldi, + const int& ldo, + const int& kvSize, + const int& rndkvSplitSize, + const int& av_gemm_K, + const int32_t& beta1, // zp_a + const float& alpha, // scale_a + float* local, + scalar_t* out, + float* sfm_max_ptr, + float* sfm_sum_ptr) { + const int32_t vec_size = at::vec::Vectorized::size(); + float min_val = 0; + float max_val = 255; + auto vec_min_val = at::vec::Vectorized(min_val); + auto vec_max_val = at::vec::Vectorized(max_val); + scalar_t zero = 0; + auto vec_zero = at::vec::Vectorized(zero); + float beta1_float = (float) beta1; + auto vec_beta1 = at::vec::Vectorized(beta1_float); + for (int64_t row = 0; row < M; ++row) { + auto sfm_max = sfm_max_ptr[row]; + auto vec_max = at::vec::Vectorized(sfm_max); + // sub max, exp, sum reduce + const float* qk_block_data = in + row * rndkvSplitSize; + for (int64_t l = 0; l < NSlice; l ++) { + int64_t n = l * N_step; + int64_t kvBlockSize = std::min(N_step, kvSize - n); + const float* tmp_in = qk_block_data + l * ldi; + float tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + float* tmp_out = local + n; + long col = 0; + for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum += tmp2; + _store(tmp_out + col, tmp2); + } + if (col < kvBlockSize) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, kvBlockSize - col); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp2, kvBlockSize - col); + _store(tmp_out + col, tmp2, kvBlockSize - col); + } + sfm_sum_ptr[row] += vec_tmp_sum.reduce_add(); + } + // div sum, sum for attention + auto sum_scale = 1 / sfm_sum_ptr[row] / alpha; + auto vec_sum_scale = at::vec::Vectorized(sum_scale); + scalar_t* qk_reduced_block_data = out + row * av_gemm_K; + for (int64_t l = 0; l < NSlice; l ++) { + int64_t n = l * N_step; + int64_t kvBlockSize = std::min(N_step, kvSize - n); + float* tmp_in = local + n; + scalar_t* tmp_out = qk_reduced_block_data + l * ldo; + long col = 0; + for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 * vec_sum_scale; + auto tmp2 = tmp1.round(); + auto tmp3 = tmp2 + vec_beta1; + auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); + _store(tmp_out + col, tmp4); + } + if (col < kvBlockSize) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, kvBlockSize - col); + auto tmp1 = tmp0 * vec_sum_scale; + auto tmp2 = tmp1.round(); + auto tmp3 = tmp2 + vec_beta1; + auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); + _store(tmp_out + col, tmp4, kvBlockSize - col); + } + // set zero + col = kvBlockSize; + for (; col < vec_size * (av_gemm_K / vec_size); col += vec_size) { + _store(tmp_out + col, vec_zero); + } + if (col < av_gemm_K) { + _store(tmp_out + col, vec_zero, av_gemm_K - col); + } + } + } +} + +/* +1. dequant +2. 
quant +*/ +template +inline void _dequant_quant_fusion_kernel( + const int32_t* in, + const int32_t* sum_a_ptr, + const int32_t* sum_b_ptr, + const int& M, + const int& N, + const int& ldi, + const int& ldo, + const int32_t& beta1, // zp_a*zp_b*k + const int32_t& beta2, // zp_c + const float& alpha, // scale_a*scale_b/scale_c + scalar_t* out) { + const int32_t vec_size = at::vec::Vectorized::size(); + float min_val = 0; + float max_val = 255; + auto vec_min_val = at::vec::Vectorized(min_val); + auto vec_max_val = at::vec::Vectorized(max_val); + auto vec_beta1 = at::vec::Vectorized(beta1); + auto vec_alpha = at::vec::Vectorized(alpha); + float beta2_float = (float) beta2; + auto vec_beta2 = at::vec::Vectorized(beta2_float); + for (long row = 0; row < M; row += 1) { + auto sum_a = sum_a_ptr[row]; + auto vec_sum_a = at::vec::Vectorized(sum_a); + const int32_t* tmp_in = in + row * ldi; + scalar_t* tmp_out = out + row * ldo; + long col = 0; + for (; col < vec_size * (N / vec_size); col += vec_size) { + auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col); + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp1 = tmp0 - vec_sum_b; + auto tmp2 = tmp1 - vec_sum_a; + auto tmp3 = tmp2 + vec_beta1; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + auto tmp6 = tmp5.round(); + auto tmp7 = tmp6 + vec_beta2; + auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); + _store(tmp_out + col, tmp8); + } + if (col < N) { + auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col, N - col); + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, N - col); + auto tmp1 = tmp0 - vec_sum_b; + auto tmp2 = tmp1 - vec_sum_a; + auto tmp3 = tmp2 + vec_beta1; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + auto tmp6 = tmp5.round(); + auto tmp7 = tmp6 + vec_beta2; + auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); + _store(tmp_out + col, tmp8, N - col); + } + } +} + +/* +1. dequant +2. 
quant +*/ +template +inline void _dequant_quant_fusion_kernel( + const int32_t* in, + const int32_t* sum_a_ptr, + const int& M, + const int& N, + const int& ldi, + const int& ldo, + const int32_t& beta2, // zp_c + const float& alpha, // scale_a*scale_b/scale_c + scalar_t* out) { + const int32_t vec_size = at::vec::Vectorized::size(); + float min_val = 0; + float max_val = 255; + auto vec_min_val = at::vec::Vectorized(min_val); + auto vec_max_val = at::vec::Vectorized(max_val); + // auto vec_beta1 = at::vec::Vectorized(beta1); + auto vec_alpha = at::vec::Vectorized(alpha); + float beta2_float = (float) beta2; + auto vec_beta2 = at::vec::Vectorized(beta2_float); + for (long row = 0; row < M; row += 1) { + auto sum_a = sum_a_ptr[row]; + auto vec_sum_a = at::vec::Vectorized(sum_a); + const int32_t* tmp_in = in + row * ldi; + scalar_t* tmp_out = out + row * ldo; + long col = 0; + for (; col < vec_size * (N / vec_size); col += vec_size) { + auto tmp1 = at::vec::Vectorized::loadu(tmp_in + col); + auto tmp3 = tmp1 - vec_sum_a; + // auto tmp3 = tmp2 + vec_beta1; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + auto tmp6 = tmp5.round(); + auto tmp7 = tmp6 + vec_beta2; + auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); + _store(tmp_out + col, tmp8); + } + if (col < N) { + auto tmp1 = at::vec::Vectorized::loadu(tmp_in + col, N - col); + auto tmp3 = tmp1 - vec_sum_a; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = tmp4 * vec_alpha; + auto tmp6 = tmp5.round(); + auto tmp7 = tmp6 + vec_beta2; + auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); + _store(tmp_out + col, tmp8, N - col); + } + } +} + +template +inline void _int_sum_b_contiguous_kernel_helper( + const scalar_t* in, + int32_t* out, + const int& N, + const int32_t& scale) { + const int32_t vec_size = at::vec::Vectorized::size(); + int32_t tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + long i = 0; + for (; i < vec_size * (N / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(in + i); + auto tmp1 = at::vec::convert(tmp0); + vec_tmp_sum = vec_tmp_sum + tmp1; + } + if (i < N) { + auto tmp0 = at::vec::Vectorized::loadu(in + i, N - i); + auto tmp1 = at::vec::convert(tmp0); + vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp1, N - i); + } + out[0] = vec_tmp_sum.reduce_add() * scale; +} + +// reduce along dim b for shape [a, b], with sum shape [a] +template +inline void _int_sum_b_contiguous_kernel( + const scalar_t* in, + int32_t* out, + const int& M, + const int& N, + const int& ld, + const int32_t& scale) { + for (long r = 0; r < M; r += 1) { + _int_sum_b_contiguous_kernel_helper(in + r * ld, out + r, N, scale); + } +} + +// reduce along dim a for shape [a, b], with sum shape [b] +template +inline void _int_sum_a_contiguous_kernel( + const scalar_t* in, + int32_t* out, + const int& M, + const int& N, + const int& ld, + const int32_t& scale) { + const int32_t vec_size = at::vec::Vectorized::size(); + auto vec_scale = at::vec::Vectorized(scale); + // initialization with 0 + int32_t zero = 0; + auto vec_zero = at::vec::Vectorized(zero); + long i = 0; + for (; i < vec_size * (M / vec_size); i += vec_size) { + _store(out + i, vec_zero); + } + if (i < M) { + _store(out + i, vec_zero, M - i); + } + // sum + for (long j = 0; j < N; j++) { + const scalar_t* tmp_in = in + j * ld; + long k = 0; + for (; k < vec_size * (M / vec_size); k += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + k); + auto tmp1 = at::vec::Vectorized::loadu(out + k); + 
auto tmp2 = at::vec::convert(tmp0); + auto tmp3 = tmp1 + tmp2; + _store(out + k, tmp3); + } + if (k < M) { + auto tmp0 = at::vec::Vectorized::loadu(tmp_in + k, M - k); + auto tmp1 = at::vec::Vectorized::loadu(out + k, M - k); + auto tmp2 = at::vec::convert(tmp0); + auto tmp3 = tmp1 + tmp2; + _store(out + k, tmp3, M - k); + } + } + // scale + i = 0; + for (; i < vec_size * (M / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(out + i); + auto tmp1 = tmp0 * vec_scale; + _store(out + i, tmp1); + } + if (i < M) { + auto tmp0 = at::vec::Vectorized::loadu(out + i, M - i); + auto tmp1 = tmp0 * vec_scale; + _store(out + i, tmp1, M - i); + } +} + +// do the transpose: [in_rows, in_cols] -> [in_cols, in_rows] +template +inline void do_transpose( + scalar_t* src, + scalar_t* dst, + int64_t in_rows, + int64_t in_cols, + int64_t ldi, + int64_t ldo) { + for (int64_t r=0; r [prows, pcols] +template +inline void pad_remain_row_col( + scalar_t* value_ptr, + int rows, + int cols, + int prows, + int pcols, + int ldi, + scalar_t pad_val=0) { + auto psize = pcols - cols; + if (psize == 0 && prows == rows) { + return; + } + const int32_t vec_size = at::vec::Vectorized::size(); + auto pad = at::vec::Vectorized(pad_val); + if (psize > 0) { + for (int i = 0; i < rows; i++) { + int j = 0; + for (; j < psize - (psize % vec_size); j += vec_size) { + pad.store(value_ptr + i * ldi + cols + j); + } + if (j < psize) { + pad.store(value_ptr + i * ldi + cols + j, psize - j); + } + } + } + + for (int i = rows; i < prows; i++) { + int j = 0; + for (; j < pcols - (pcols % vec_size); j += vec_size) { + pad.store(value_ptr + i * ldi + j); + } + if (j < pcols) { + pad.store(value_ptr + i * ldi + j, pcols - j); + } + } +} + +// copy value_ptr to dst_ptr with padding: [rows, cols] -> [prows, pcols] +template +inline void copy_value_with_pad( + scalar_t* value_ptr, + scalar_t* dst_ptr, + int rows, + int cols, + int prows, + int pcols, + int ldi, + scalar_t pad_val=0) { + const int32_t vec_size = at::vec::Vectorized::size(); + auto pad = at::vec::Vectorized(pad_val); + int i = 0; + for (; i < rows; i++) { + int j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(dst_ptr + i * pcols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - j); + vec_v.store(dst_ptr + i * pcols + j, cols - j); + } + + // col padding + auto psize = pcols - cols; + if (psize > 0) { + int pj = 0; + for (; pj < psize - (psize % vec_size); pj += vec_size) { + pad.store(dst_ptr + i * pcols + cols + pj); + } + if (pj < psize) { + pad.store(dst_ptr + i * pcols + cols + pj, psize - pj); + } + } + } + + // row padding + for (; i < prows; i++) { + int j = 0; + for (; j < pcols - (pcols % vec_size); j += vec_size) { + pad.store(dst_ptr + i * pcols + j); + } + if (j < pcols) { + pad.store(dst_ptr + i * pcols + j, pcols - j); + } + + } + +} + +// UINT8 - one parallel loop with u8u8s32 GEMM +template = 0> +inline typename std::enable_if_t, void> +sdpa_int8_fused_kernel_impl( + const at::Tensor& output, + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + double dropout_p, + bool is_causal, + std::optional attention_mask, + double scale, + float q_scale, + int32_t q_zp, + float k_scale, + int32_t k_zp, + float v_scale, + int32_t v_zp, + float a_scale, + int32_t a_zp, + float o_scale, + int32_t o_zp) { + // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) + // -> (Batch x Q_seq_len x 
Num_heads x Dim_per_head) + // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) + // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) + // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + at::Tensor query = q.transpose(1, 2); + at::Tensor key = k.transpose(1, 2); + at::Tensor value = v.transpose(1, 2); + + using accum_t = float; + accum_t scaling_factor = calculate_scale(query, scale); + int block_64 = 64; + auto u8_dt = at::ScalarType::Byte; + + // Sizes + TORCH_CHECK( + (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), + "scaled_dot_product_attention_sdpa: Q/K/V should have the same head size"); + TORCH_CHECK( + kv_split_size % block_64 == 0, "kv_split_size is not divisble by ", block_64); + + int64_t batchSize = query.size(0); + int64_t qSize = query.size(1); + int64_t kvSize = value.size(1); + int64_t num_head = query.size(2); + int64_t headSize = query.size(3); + + bool has_attn_mask = attention_mask.has_value() && attention_mask.value().numel(); + if (has_attn_mask) { + reshape_attn_mask_to_4d(attention_mask.value(), batchSize, num_head, qSize, kvSize); + } + + // Strides + int64_t qStrideB = query.stride(0); + int64_t qStrideM = query.stride(1); + int64_t qStrideH = query.stride(2); + int64_t kStrideB = key.stride(0); + int64_t kStrideN = key.stride(1); + int64_t kStrideH = key.stride(2); + int64_t vStrideB = value.stride(0); + int64_t vStrideN = value.stride(1); + int64_t vStrideH = value.stride(2); + int64_t oStrideB = output.stride(0); + int64_t oStrideM = output.stride(1); + int64_t oStrideH = output.stride(2); + int64_t mStrideB = + (has_attn_mask && attention_mask.value().size(0) > 1) + ? attention_mask.value().stride(0) + : 0; + int64_t mStrideH = + (has_attn_mask && attention_mask.value().size(1) > 1) + ? attention_mask.value().stride(1) + : 0; + int64_t mStrideM = + (has_attn_mask && attention_mask.value().size(2) > 1) + ? attention_mask.value().stride(2) + : 0; + int64_t mStrideN = + (has_attn_mask && attention_mask.value().size(3) > 1) + ? attention_mask.value().stride(3) + : 0; + + int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; + int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size; + int64_t qSlice = (qSize - 1) / qSplitSize + 1; + int64_t kvSlice = (kvSize - 1) / kvSplitSize + 1; + int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; + int64_t num_thread = at::get_num_threads(); + + int64_t rndHeadSize = (headSize + block_64 - 1L) / block_64 * block_64; + int64_t rndkvSplitSize = (kvSplitSize + block_64 - 1L) / block_64 * block_64; + int64_t rndkvTail = (kvTail + block_64 - 1L) / block_64 * block_64; + int64_t rndkvSize = kv_split_size > kvSize ? rndkvTail : rndkvSplitSize * kvSlice + rndkvTail; + + bool av_gemm_K_mul4 = kvSplitSize % 4 == 0; + int av_gemm_K_padding = av_gemm_K_mul4 ? 0 : 4 - kvSplitSize % 4; + int av_gemm_K = kvSplitSize + av_gemm_K_padding; + + // Data ptrs + scalar_t* q_data = query.data_ptr(); + scalar_t* k_data = key.data_ptr(); + scalar_t* v_data = value.data_ptr(); + mask_t* mask_data = attention_mask.has_value() + ? attention_mask.value().data_ptr() + : nullptr; + scalar_t* out_data = output.data_ptr(); + + bool headSize_mul64 = headSize % 64 == 0; + int qk_gemm_K_padding = headSize_mul64 ? 
0 : 64 - headSize % 64; + int qk_gemm_K = headSize + qk_gemm_K_padding; + + int64_t qk_reduce_strideL = qSplitSize * av_gemm_K; + int64_t v_reorder_strideL = av_gemm_K * rndHeadSize; + + int64_t total_size_uint8_per_thread = + /* qk */ kvSlice * qSplitSize * rndkvSplitSize * 4 + + /* qk_local */ kvSlice * av_gemm_K * 4 + + /* qk_reduce */ kvSlice * qk_reduce_strideL + + /* qk_s32 */ qSplitSize * rndkvSplitSize * 4 + + /* dst_s32 */ qSplitSize * rndHeadSize * 4 + + /* softmax_sum */ qSplitSize * 4 + + /* query_sum */ qSplitSize * 4 + + /* attention_sum */ qSplitSize * 4 + + /* softmax max */ qSplitSize * 4 + + /* query_padding_data */ qSplitSize * qk_gemm_K + + /* key_sum */ kvSize * 4 + + /* value_sum */ headSize * 4 + + /* key_t_reorder */ qk_gemm_K * rndkvSize + + /* value_t_reorder */ kvSlice * v_reorder_strideL; + + at::Tensor total_buf = at::empty( + {num_thread, total_size_uint8_per_thread}, + query.options()); + scalar_t* total_buf_data = total_buf.data_ptr(); + + at::parallel_for( + 0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0; + at::native::data_index_init( + begin, i, batchSize, j, num_head); + int ompIdx = at::get_thread_num(); + scalar_t* total_buf_ptr = total_buf_data + ompIdx * total_size_uint8_per_thread; + int32_t offset = 0; + accum_t* qk_data = reinterpret_cast(total_buf_ptr); + offset += kvSlice * qSplitSize * rndkvSplitSize * 4; + accum_t* qk_local_data = reinterpret_cast(total_buf_ptr + offset); + offset += kvSlice * av_gemm_K * 4; + scalar_t* qk_reduced_data = reinterpret_cast(total_buf_ptr + offset); + offset += kvSlice * qk_reduce_strideL; + int32_t* qk_s32_data = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * rndkvSplitSize * 4; + int32_t* dst_s32_data = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * rndHeadSize * 4; + accum_t* sfm_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + int32_t* q_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + int32_t* a_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + accum_t* sfm_max_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + scalar_t* query_t_padding_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * qk_gemm_K; + + int32_t* k_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += kvSize * 4; + int32_t* v_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += headSize * 4; + scalar_t* key_reorder_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qk_gemm_K * rndkvSize; + scalar_t* value_reorder_ptr = reinterpret_cast(total_buf_ptr + offset); + + uint8_t* B_blocked_xform_u8 = new uint8_t[qk_gemm_K * block_64]; + + for (const auto z : c10::irange(begin, end)) { + (void)z; // Suppress unused variable + + // sum k and v + if (q_zp == 0) { + fill_stub(k_sum_ptr, static_cast(0), kvSize); + } else { + _int_sum_b_contiguous_kernel(k_data + i * kStrideB + j * kStrideH, + k_sum_ptr, + kvSize, headSize, kStrideN, q_zp); + } + if (a_zp == 0) { + fill_stub(v_sum_ptr, static_cast(0), headSize); + } else { + _int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, + v_sum_ptr, + headSize, kvSize, vStrideN, a_zp); + } + + // transpose and packing + for (int64_t n = 0; n < kvSize; n += kvSplitSize) { + int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + for (int64_t b = 0; b < kvBlockSize; b += block_64) { + bool istail = kvBlockSize - b < block_64; + int64_t trans_rows = 
istail ? kvBlockSize - b : block_64; + do_transpose( + k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + B_blocked_xform_u8, + trans_rows, + headSize, + kStrideN, + block_64); + if (!headSize_mul64 || istail) { + pad_remain_row_col( + B_blocked_xform_u8, + headSize, + trans_rows, + qk_gemm_K, + block_64, + block_64 + ); + } + at::native::cpublas::pack( + qk_gemm_K, // K + block_64, // N + block_64, // ld_in + block_64, // ld_out + u8_dt, // dt_in + u8_dt, // dt_out + B_blocked_xform_u8, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K); + } + // split headSize to block_64, block_64, block_64 ... + // [av_gemm_K, headSize] -> [av_gemm_K, block_64 ...] + for (int64_t b = 0; b < rndHeadSize; b += block_64) { + at::native::cpublas::pack( + av_gemm_K, + block_64, + vStrideN, + block_64, + u8_dt, + u8_dt, + v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + value_reorder_ptr + n * rndHeadSize + + av_gemm_K * b); + } + } + + // sdpa core + for (int64_t k = 0; k < qSlice; k++) { + int64_t m = k * qSplitSize; + int64_t qBlockSize = std::min(qSplitSize, qSize - m); + // Initialize sum and max + fill_stub( + sfm_sum_ptr, static_cast(0), qSplitSize); + fill_stub( + a_sum_ptr, static_cast(0), qSplitSize); + fill_stub( + sfm_max_ptr, static_cast(-std::numeric_limits::infinity()), qSplitSize); + int64_t num_keys = + is_causal ? std::min(m + qBlockSize, kvSize) : kvSize; + copy_value_with_pad( + q_data + i * qStrideB + j * qStrideH + m * qStrideM, + query_t_padding_ptr, + qBlockSize, + headSize, + qBlockSize, + qk_gemm_K, + qStrideM); + // sum q + if (k_zp != 0) { + _int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, + q_sum_ptr, qBlockSize, headSize, qStrideM, k_zp); + } else { + fill_stub( + q_sum_ptr, static_cast(0), qSplitSize); + } + const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1; + for (int64_t l = 0; l < rkvSlice; l++) { + int64_t n = l * kvSplitSize; + int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + // Calculate q @ k.T + for (int64_t b = 0; b < kvBlockSize; b += block_64) { + at::native::cpublas::brgemm( + qSplitSize, block_64, qk_gemm_K, + qk_gemm_K, // lda + block_64, //ldb + rndkvSplitSize, //ldc, + false, + query_t_padding_ptr, + key_reorder_ptr + n * qk_gemm_K + + b * qk_gemm_K, + qk_s32_data + b); + } + + // do dequant compensation, add mask, max reduce for softmax, and convert qk from s32 to fp32 + accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; + if (has_attn_mask) { + mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 
0 : n); + _dequant_mask_max_fusion_kernel( + qk_s32_data, //in + mask_data_offset, //mask_ptr + q_sum_ptr, //sum_a_ptr + k_sum_ptr + n, //sum_b_ptr + qBlockSize, //M + kvBlockSize, //N + rndkvSplitSize, //ldi + mStrideM, //ldm + rndkvSplitSize, //ldo + q_zp * k_zp * headSize, //zp_a*zp_b*k=beta + q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha + qk_block_data, //out + sfm_max_ptr // sfm_max_ptr + ); + } else { + _dequant_max_fusion_kernel( + qk_s32_data, //in + q_sum_ptr, //sum_a_ptr + k_sum_ptr + n, //sum_b_ptr + qBlockSize, //M + kvBlockSize, //N + rndkvSplitSize, //ldi + rndkvSplitSize, //ldo + q_zp * k_zp * headSize, //zp_a*zp_b*k=beta + q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha + qk_block_data, //out + sfm_max_ptr // sfm_max_ptr + ); + } + } + // sub max, exp, sum reduce, div sum for softmax + // and quant + // and sum for attention + if (v_zp == 0) { + _sub_exp_sum_div_quant_fusion_kernel( + qk_data, //in + qBlockSize, //M + kvSplitSize, //N_step + rkvSlice, //NSlices + qSplitSize * rndkvSplitSize, //ldi + qk_reduce_strideL, //ldo + kvSize, //kvSize + rndkvSplitSize, //rndkvSplitSize + av_gemm_K, //av_gemm_K + a_zp, // zp_a=beta1 + a_scale, // scale_a=alpha + qk_local_data, //local + qk_reduced_data, //out + sfm_max_ptr, //sfm_max_ptr + sfm_sum_ptr //sfm_sum_ptr + ); + } else { + _sub_exp_sum_div_quant_sum_fusion_kernel( + qk_data, //in + qBlockSize, //M + kvSplitSize, //N_step + rkvSlice, //NSlice + qSplitSize * rndkvSplitSize, //ldi + qk_reduce_strideL, //ldo + kvSize, //kvSize + rndkvSplitSize, //rndkvSplitSize + av_gemm_K, //av_gemm_K + a_zp, // zp_a=beta1 + v_zp, // zp_b=beta2 + a_scale, // scale_a=alpha + qk_local_data, //local + qk_reduced_data, //out + sfm_max_ptr, //sfm_max_ptr + sfm_sum_ptr, //sfm_sum_ptr + a_sum_ptr //a_sum_ptr + ); + } + // Calculate Softmax(q @ k.T) @ v + for (int64_t b = 0; b < headSize; b += block_64) { + auto value_reorder_b = value_reorder_ptr + b * av_gemm_K; + auto dst_s32_b = dst_s32_data + b; + for (int64_t s = 0; s < kvSlice; s++) { + at::native::cpublas::brgemm( + qSplitSize, block_64, av_gemm_K, + av_gemm_K, // lda + rndHeadSize, //block_64, //ldb + rndHeadSize, //ldc + s != 0, + qk_reduced_data + s * qk_reduce_strideL, + value_reorder_b + s * v_reorder_strideL, + dst_s32_b); + } + } + + // After the last gemm, + // do dequant compensation, quant and convert from s32 to int8 + if (a_zp == 0) { + _dequant_quant_fusion_kernel( + dst_s32_data, //in + a_sum_ptr, //sum_a_ptr + qBlockSize, //M + headSize, //N + rndHeadSize, //ldi + oStrideM, //ldo + o_zp, //zp_c=beta2 + a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha + out_data + i * oStrideB + j * oStrideH + m * oStrideM //out + ); + } else { + _dequant_quant_fusion_kernel( + dst_s32_data, //in + a_sum_ptr, //sum_a_ptr + v_sum_ptr, //sum_b_ptr + qBlockSize, //M + headSize, //N + rndHeadSize, //ldi + oStrideM, //ldo + a_zp * v_zp * kvSize, //zp_a*zp_b*k=beta1 + o_zp, //zp_c=beta2 + a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha + out_data + i * oStrideB + j * oStrideH + m * oStrideM //out + ); + } + } + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head); + } + }); + // Once all computations are done, need to release HW context. 
+ at::native::cpublas::brgemm_release(); +} + +// UINT8 - several parallel loops with u8u8s32 GEMM +template = 0> +inline typename std::enable_if_t, void> +sdpa_int8_fused_kernel_impl( + const at::Tensor& output, + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + double dropout_p, + bool is_causal, + std::optional attention_mask, + double scale, + float q_scale, + int32_t q_zp, + float k_scale, + int32_t k_zp, + float v_scale, + int32_t v_zp, + float a_scale, + int32_t a_zp, + float o_scale, + int32_t o_zp) { + // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) + // -> (Batch x Q_seq_len x Num_heads x Dim_per_head) + // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) + // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) + // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + at::Tensor query = q.transpose(1, 2); + at::Tensor key = k.transpose(1, 2); + at::Tensor value = v.transpose(1, 2); + + using accum_t = float; + accum_t scaling_factor = calculate_scale(query, scale); + int block_64 = 64; + auto u8_dt = at::ScalarType::Byte; + + // Sizes + TORCH_CHECK( + (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), + "scaled_dot_product_attention_sdpa: Q/K/V should have the same head size"); + TORCH_CHECK( + kv_split_size % block_64 == 0, "kv_split_size is not divisble by ", block_64); + + int64_t batchSize = query.size(0); + int64_t qSize = query.size(1); + int64_t kvSize = value.size(1); + int64_t num_head = query.size(2); + int64_t headSize = query.size(3); + + bool has_attn_mask = attention_mask.has_value() && attention_mask.value().numel(); + if (has_attn_mask) { + reshape_attn_mask_to_4d(attention_mask.value(), batchSize, num_head, qSize, kvSize); + } + + // Strides + int64_t qStrideB = query.stride(0); + int64_t qStrideM = query.stride(1); + int64_t qStrideH = query.stride(2); + int64_t kStrideB = key.stride(0); + int64_t kStrideN = key.stride(1); + int64_t kStrideH = key.stride(2); + int64_t vStrideB = value.stride(0); + int64_t vStrideN = value.stride(1); + int64_t vStrideH = value.stride(2); + int64_t oStrideB = output.stride(0); + int64_t oStrideM = output.stride(1); + int64_t oStrideH = output.stride(2); + int64_t mStrideB = + (has_attn_mask && attention_mask.value().size(0) > 1) + ? attention_mask.value().stride(0) + : 0; + int64_t mStrideH = + (has_attn_mask && attention_mask.value().size(1) > 1) + ? attention_mask.value().stride(1) + : 0; + int64_t mStrideM = + (has_attn_mask && attention_mask.value().size(2) > 1) + ? attention_mask.value().stride(2) + : 0; + int64_t mStrideN = + (has_attn_mask && attention_mask.value().size(3) > 1) + ? attention_mask.value().stride(3) + : 0; + + int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; + int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size; + int64_t qSlice = (qSize - 1) / qSplitSize + 1; + int64_t kvSlice = (kvSize - 1) / kvSplitSize + 1; + int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; + int64_t num_thread = at::get_num_threads(); + + int64_t rndHeadSize = (headSize + block_64 - 1L) / block_64 * block_64; + int64_t rndkvSplitSize = (kvSplitSize + block_64 - 1L) / block_64 * block_64; + int64_t rndkvTail = (kvTail + block_64 - 1L) / block_64 * block_64; + int64_t rndkvSize = kv_split_size > kvSize ? rndkvTail : rndkvSplitSize * kvSlice + rndkvTail; + + bool av_gemm_K_mul4 = kvSplitSize % 4 == 0; + int av_gemm_K_padding = av_gemm_K_mul4 ? 
0 : 4 - kvSplitSize % 4; + int av_gemm_K = kvSplitSize + av_gemm_K_padding; + + // Data ptrs + scalar_t* q_data = query.data_ptr(); + scalar_t* k_data = key.data_ptr(); + scalar_t* v_data = value.data_ptr(); + mask_t* mask_data = attention_mask.has_value() + ? attention_mask.value().data_ptr() + : nullptr; + scalar_t* out_data = output.data_ptr(); + + bool headSize_mul64 = headSize % 64 == 0; + int qk_gemm_K_padding = headSize_mul64 ? 0 : 64 - headSize % 64; + int qk_gemm_K = headSize + qk_gemm_K_padding; + + int64_t qk_reduce_strideL = qSplitSize * av_gemm_K; + int64_t v_reorder_strideL = av_gemm_K * rndHeadSize; + + int64_t total_size_uint8_per_thread = + /* qk */ kvSlice * qSplitSize * rndkvSplitSize * 4 + + /* qk_local */ kvSlice * av_gemm_K * 4 + + /* qk_reduce */ kvSlice * qk_reduce_strideL + + /* qk_s32 */ qSplitSize * rndkvSplitSize * 4 + + /* dst_s32 */ qSplitSize * rndHeadSize * 4 + + /* softmax_sum */ qSplitSize * 4 + + /* query_sum */ qSplitSize * 4 + + /* attention_sum */ qSplitSize * 4 + + /* softmax max */ qSplitSize * 4 + + /* query_padding_data */ qSplitSize * qk_gemm_K; + + at::Tensor total_buf = at::empty( + {num_thread, total_size_uint8_per_thread}, + query.options()); + scalar_t* total_buf_data = total_buf.data_ptr(); + + int64_t kv_sum_size_per_BH = + /* key_sum */ kvSize + + /* value_sum */ headSize; + + at::Tensor kv_sum_buf = at::empty( + {batchSize, num_head, kv_sum_size_per_BH}, + query.options().dtype(at::kInt)); + int32_t* kv_sum_buf_data = kv_sum_buf.data_ptr(); + + int64_t kv_reorder_size_per_BH = + /* key_t_reorder */ qk_gemm_K * rndkvSize + + /* value_t_reorder */ kvSlice * v_reorder_strideL; + + at::Tensor kv_reorder_buf = at::empty( + {batchSize, num_head, kv_reorder_size_per_BH}, + query.options()); + scalar_t* kv_reorder_buf_data = kv_reorder_buf.data_ptr(); + scalar_t* key_reorder_ptr = kv_reorder_buf_data; + scalar_t* value_reorder_ptr = kv_reorder_buf_data + batchSize * num_head * qk_gemm_K * rndkvSize; + + // sum k and v + at::parallel_for( + 0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0; + at::native::data_index_init( + begin, i, batchSize, j, num_head); + for (const auto z : c10::irange(begin, end)) { + (void)z; // Suppress unused variable + int32_t* kv_sum_ptr = kv_sum_buf_data + + i * num_head * kv_sum_size_per_BH + + j * kv_sum_size_per_BH; + int32_t* k_sum_ptr = kv_sum_ptr; + int32_t* v_sum_ptr = kv_sum_ptr + kvSize; + if (q_zp == 0) { + fill_stub(k_sum_ptr, static_cast(0), kvSize); + } else { + _int_sum_b_contiguous_kernel(k_data + i * kStrideB + j * kStrideH, + k_sum_ptr, + kvSize, headSize, kStrideN, q_zp); + } + if (a_zp == 0) { + fill_stub(v_sum_ptr, static_cast(0), headSize); + } else { + _int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, + v_sum_ptr, + headSize, kvSize, vStrideN, a_zp); + } + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head); + } + }); + + // transpose and packing + at::parallel_for( + 0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0, l = 0, n = 0; + at::native::data_index_init( + begin, i, batchSize, j, num_head, l, kvSlice); + uint8_t* B_blocked_xform_u8 = new uint8_t[qk_gemm_K * block_64]; + for (const auto z : c10::irange(begin, end)) { + (void)z; // Suppress unused variable + n = l * kvSplitSize; + auto k_reorder = key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K; + auto v_reorder = value_reorder_ptr + + i * num_head * kvSlice * 
v_reorder_strideL + + j * kvSlice * v_reorder_strideL + n * rndHeadSize; + int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + for (int64_t b = 0; b < kvBlockSize; b += block_64) { + bool istail = kvBlockSize - b < block_64; + int64_t trans_rows = istail ? kvBlockSize - b : block_64; + do_transpose( + k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + B_blocked_xform_u8, + trans_rows, + headSize, + kStrideN, + block_64); + if (!headSize_mul64 || istail) { + pad_remain_row_col( + B_blocked_xform_u8, + headSize, + trans_rows, + qk_gemm_K, + block_64, + block_64 + ); + } + at::native::cpublas::pack( + qk_gemm_K, // K + block_64, // N + block_64, // ld_in + block_64, // ld_out + u8_dt, // dt_in + u8_dt, // dt_out + B_blocked_xform_u8, + k_reorder + b * qk_gemm_K); + } + // split headSize to block_64, block_64, block_64 ... + // [av_gemm_K, headSize] -> [av_gemm_K, block_64 ...] + for (int64_t b = 0; b < rndHeadSize; b += block_64) { + at::native::cpublas::pack( + av_gemm_K, + block_64, + vStrideN, + block_64, + u8_dt, + u8_dt, + v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + v_reorder + av_gemm_K * b); + } + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice); + } + }); + + at::parallel_for( + 0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0, k = 0; + at::native::data_index_init( + begin, i, batchSize, j, num_head, k, qSlice); + int ompIdx = at::get_thread_num(); + scalar_t* total_buf_ptr = total_buf_data + ompIdx * total_size_uint8_per_thread; + int32_t offset = 0; + accum_t* qk_data = reinterpret_cast(total_buf_ptr); + offset += kvSlice * qSplitSize * rndkvSplitSize * 4; + accum_t* qk_local_data = reinterpret_cast(total_buf_ptr + offset); + offset += kvSlice * av_gemm_K * 4; + scalar_t* qk_reduced_data = reinterpret_cast(total_buf_ptr + offset); + offset += kvSlice * qk_reduce_strideL; + int32_t* qk_s32_data = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * rndkvSplitSize * 4; + int32_t* dst_s32_data = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * rndHeadSize * 4; + accum_t* sfm_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + int32_t* q_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + int32_t* a_sum_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + accum_t* sfm_max_ptr = reinterpret_cast(total_buf_ptr + offset); + offset += qSplitSize * 4; + scalar_t* query_t_padding_ptr = reinterpret_cast(total_buf_ptr + offset); + + for (const auto z : c10::irange(begin, end)) { + (void)z; // Suppress unused variable + + int32_t* kv_sum_ptr = kv_sum_buf_data + + i * num_head * kv_sum_size_per_BH + + j * kv_sum_size_per_BH; + int32_t* k_sum_ptr = kv_sum_ptr; + int32_t* v_sum_ptr = kv_sum_ptr + kvSize; + + // sdpa core + int64_t m = k * qSplitSize; + int64_t qBlockSize = std::min(qSplitSize, qSize - m); + // Initialize sum and max + fill_stub( + sfm_sum_ptr, static_cast(0), qSplitSize); + fill_stub( + a_sum_ptr, static_cast(0), qSplitSize); + fill_stub( + sfm_max_ptr, static_cast(-std::numeric_limits::infinity()), qSplitSize); + copy_value_with_pad( + q_data + i * qStrideB + j * qStrideH + m * qStrideM, + query_t_padding_ptr, + qBlockSize, + headSize, + qSplitSize, + qk_gemm_K, + qStrideM); + // sum q + if (k_zp != 0) { + _int_sum_b_contiguous_kernel(query_t_padding_ptr, + q_sum_ptr, qBlockSize, headSize, qk_gemm_K, k_zp); + } else { + fill_stub( + 
q_sum_ptr, static_cast(0), qSplitSize); + } + const int64_t rkvSlice = (kvSize - 1) / kvSplitSize + 1; + for (int64_t l = 0; l < rkvSlice; l++) { + int64_t n = l * kvSplitSize; + int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + auto k_reorder = key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + + j * qk_gemm_K * rndkvSize + n * qk_gemm_K; + // Calculate q @ k.T + for (int64_t b = 0; b < kvBlockSize; b += block_64) { + at::native::cpublas::brgemm( + qSplitSize, block_64, qk_gemm_K, + qk_gemm_K, // lda + block_64, //ldb + rndkvSplitSize, //ldc, + false, + query_t_padding_ptr, + k_reorder + b * qk_gemm_K, + qk_s32_data + b); + } + + // do dequant compensation, add mask, max reduce for softmax, and convert qk from s32 to fp32 + accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; + if (has_attn_mask) { + mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 0 : n); + _dequant_mask_max_fusion_kernel( + qk_s32_data, //in + mask_data_offset, //mask_ptr + q_sum_ptr, //sum_a_ptr + k_sum_ptr + n, //sum_b_ptr + qBlockSize, //M + kvBlockSize, //N + rndkvSplitSize, //ldi + mStrideM, //ldm + rndkvSplitSize, //ldo + q_zp * k_zp * headSize, //zp_a*zp_b*k=beta + q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha + qk_block_data, //out + sfm_max_ptr // sfm_max_ptr + ); + } else { + _dequant_max_fusion_kernel( + qk_s32_data, //in + q_sum_ptr, //sum_a_ptr + k_sum_ptr + n, //sum_b_ptr + qBlockSize, //M + kvBlockSize, //N + rndkvSplitSize, //ldi + rndkvSplitSize, //ldo + q_zp * k_zp * headSize, //zp_a*zp_b*k=beta + q_scale * k_scale * scaling_factor, //scale_a*scale_b*scale_sdpa=alpha + qk_block_data, //out + sfm_max_ptr // sfm_max_ptr + ); + } + } + // sub max, exp, sum reduce, div sum for softmax + // and quant + // and sum for attention + if (v_zp == 0) { + _sub_exp_sum_div_quant_fusion_kernel( + qk_data, //in + qBlockSize, //M + kvSplitSize, //N_step + rkvSlice, //NSlices + qSplitSize * rndkvSplitSize, //ldi + qk_reduce_strideL, //ldo + kvSize, //kvSize + rndkvSplitSize, //rndkvSplitSize + av_gemm_K, //av_gemm_K + a_zp, // zp_a=beta1 + a_scale, // scale_a=alpha + qk_local_data, //local + qk_reduced_data, //out + sfm_max_ptr, //sfm_max_ptr + sfm_sum_ptr //sfm_sum_ptr + ); + } else { + _sub_exp_sum_div_quant_sum_fusion_kernel( + qk_data, //in + qBlockSize, //M + kvSplitSize, //N_step + rkvSlice, //NSlice + qSplitSize * rndkvSplitSize, //ldi + qk_reduce_strideL, //ldo + kvSize, //kvSize + rndkvSplitSize, //rndkvSplitSize + av_gemm_K, //av_gemm_K + a_zp, // zp_a=beta1 + v_zp, // zp_b=beta2 + a_scale, // scale_a=alpha + qk_local_data, //local + qk_reduced_data, //out + sfm_max_ptr, //sfm_max_ptr + sfm_sum_ptr, //sfm_sum_ptr + a_sum_ptr //a_sum_ptr + ); + } + // Calculate Softmax(q @ k.T) @ v + auto v_reorder = value_reorder_ptr + + i * num_head * kvSlice * v_reorder_strideL + + j * kvSlice * v_reorder_strideL; + for (int64_t b = 0; b < headSize; b += block_64) { + auto value_reorder_b = v_reorder + b * av_gemm_K; + auto dst_s32_b = dst_s32_data + b; + for (int64_t s = 0; s < kvSlice; s++) { + at::native::cpublas::brgemm( + qSplitSize, block_64, av_gemm_K, + av_gemm_K, // lda + rndHeadSize, //ldb + rndHeadSize, //ldc + s != 0, + qk_reduced_data + s * qk_reduce_strideL, + value_reorder_b + s * v_reorder_strideL, + dst_s32_b); + } + } + + // After the last gemm, + // do dequant compensation, quant and convert from s32 to int8 + if (a_zp == 0) { + _dequant_quant_fusion_kernel( + dst_s32_data, //in + a_sum_ptr, 
//sum_a_ptr + qBlockSize, //M + headSize, //N + rndHeadSize, //ldi + oStrideM, //ldo + o_zp, //zp_c=beta2 + a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha + out_data + i * oStrideB + j * oStrideH + m * oStrideM //out + ); + } else { + _dequant_quant_fusion_kernel( + dst_s32_data, //in + a_sum_ptr, //sum_a_ptr + v_sum_ptr, //sum_b_ptr + qBlockSize, //M + headSize, //N + rndHeadSize, //ldi + oStrideM, //ldo + a_zp * v_zp * kvSize, //zp_a*zp_b*k=beta1 + o_zp, //zp_c=beta2 + a_scale * v_scale / o_scale, //scale_a*scale_b/scale_c=alpha + out_data + i * oStrideB + j * oStrideH + m * oStrideM //out + ); + } + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head, k, qSlice); + } + }); + // Once all computations are done, need to release HW context. + at::native::cpublas::brgemm_release(); +} + + +template +inline typename std::enable_if_t, void> +sdpa_int8_fused_kernel_impl( + bool use_one_parallel_loop, + const at::Tensor& output, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + double dropout_p, + bool is_causal, + std::optional attn_mask, + double scale, + float q_scale, + int32_t q_zp, + float k_scale, + int32_t k_zp, + float v_scale, + int32_t v_zp, + float a_scale, + int32_t a_zp, + float o_scale, + int32_t o_zp) { + if (use_one_parallel_loop) { + sdpa_int8_fused_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp); + } else { + sdpa_int8_fused_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp); + } +} + + +#define AT_DISPATCH_MASK_TYPES(TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Bool, mask_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Float, mask_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Double, mask_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::BFloat16, mask_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Half, mask_t, __VA_ARGS__)) + +void sdpa_int8_fused_kernel( + const at::Tensor& output, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + double dropout_p, + bool is_causal, + std::optional attn_mask, + double scale, + float q_scale, + int32_t q_zp, + float k_scale, + int32_t k_zp, + float v_scale, + int32_t v_zp, + float a_scale, + int32_t a_zp, + float o_scale, + int32_t o_zp) { + TORCH_CHECK(query.scalar_type() == c10::kByte); + int64_t batchSize = query.size(0); + int64_t num_head = query.size(1); + int64_t q_seq_len = query.size(2); + int64_t kv_seq_len = key.size(2); + int64_t q_split_size = 32; + if (q_seq_len >= 768) { + q_split_size = 256; + } else if (q_seq_len >= 192) { + q_split_size = 64; + } + // Heuristic to decide whether to use one parallel loop or not + // true: one parallel loop for sum+packing+core + // false: three parallel loops for sum, packing, core + uint32_t l2_cache_size = at::cpu::L2_cache_size(); + int64_t num_thread = at::get_num_threads(); + int64_t attn_size = q_split_size * kv_seq_len * sizeof(int32_t) * num_thread; + bool use_one_parallel_loop = (batchSize * num_head > num_thread) && + (attn_size > 1.5 * l2_cache_size); + if (!attn_mask.has_value()) { + if (q_split_size == 256) { + sdpa_int8_fused_kernel_impl( + use_one_parallel_loop, + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp); + } else if (q_split_size == 64) { + sdpa_int8_fused_kernel_impl( + use_one_parallel_loop, + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp); + } else { + sdpa_int8_fused_kernel_impl( + use_one_parallel_loop, + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp); + } + } else { + AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "sdpa_mask", [&]() { + if (q_split_size == 256) { + sdpa_int8_fused_kernel_impl( + use_one_parallel_loop, + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp); + } else if (q_split_size == 64) { + sdpa_int8_fused_kernel_impl( + use_one_parallel_loop, + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp); + } else { + sdpa_int8_fused_kernel_impl( + use_one_parallel_loop, + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp); + } + }); + } +} +#endif // CPU_CAPABILITY_AVX512 + +at::Tensor sdpa_int8_math_kernel( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + double dropout_p, + bool is_causal, + std::optional attn_mask, + double scale, + float q_scale, + int32_t q_zp, + float k_scale, + int32_t k_zp, + float v_scale, + int32_t v_zp, + float 
a_scale, + int32_t a_zp, + float o_scale, + int32_t o_zp) { + // dequant q/k/v + auto q = (query.to(at::kFloat) - q_zp) * q_scale; + auto k = (key.to(at::kFloat) - k_zp) * k_scale; + auto v = (value.to(at::kFloat) - v_zp) * v_scale; + const auto scaling_factor = calculate_scale(q, scale); + auto attn = at::matmul(q, k.transpose(-2, -1)) * scaling_factor; + if (attn_mask.has_value() && attn_mask.value().numel()) { + attn = attn.add(attn_mask.value().to(at::kFloat)); + } + attn = at::softmax(attn, -1); + // quant attn + attn = at::clamp_max( + at::clamp_min(at::round(attn / a_scale) + a_zp, 0), 255 + ); + // dequant attn + attn = (attn - a_zp) * a_scale; + auto output = at::matmul(attn, v); + // quant output + output = at::clamp_max( + at::clamp_min(at::round(output / o_scale) + o_zp, 0), 255 + ).to(at::kByte); + return output; +} + + +at::Tensor _scaled_dot_product_int8_cpu( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + std::optional attn_mask, + double dropout_p, + bool is_causal, + double scale, + double q_scale, + int64_t q_zp, + double k_scale, + int64_t k_zp, + double v_scale, + int64_t v_zp, + double a_scale, + int64_t a_zp, + double o_scale, + int64_t o_zp) { + const auto dtype = query.scalar_type(); + TORCH_CHECK(!query.is_nested() && !key.is_nested() && !value.is_nested(), + "_scaled_dot_product_int8_cpu: Only accept plain inputs"); + TORCH_CHECK(!is_causal, + "_scaled_dot_product_int8_cpu: is_causal not supported."); + TORCH_CHECK(dtype == at::ScalarType::Byte, + "_scaled_dot_product_int8_cpu: Expected data type be U8, but got ", dtype, " instead."); + TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4, + "_scaled_dot_product_int8_cpu: Accept only 4 dims inputs shape of {B, H, T, K}"); + TORCH_CHECK(dropout_p == 0.0, + "_scaled_dot_product_int8_cpu: Currently do not support dropout > 0"); + TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), + "_scaled_dot_product_int8_cpu: Q/K/V should have the same head size"); + TORCH_CHECK(!attn_mask.has_value() || + attn_mask.value().scalar_type() == at::kFloat || + attn_mask.value().scalar_type() == at::kBFloat16, + "_scaled_dot_product_int8_cpu: Expected attention mask be float or bf16"); + TORCH_CHECK(!attn_mask.has_value() || + (attn_mask.value().dim() == 2 || attn_mask.value().dim() == 4), + "_scaled_dot_product_int8_cpu: Attention mask dim in {2, 4}"); + + #ifdef CPU_CAPABILITY_AVX512 + if (at::native::cpublas::could_pack(dtype)) { + at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); + sdpa_int8_fused_kernel(output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp); + return output.transpose(1, 2); + } else { + #endif // CPU_CAPABILITY_AVX512 + return sdpa_int8_math_kernel(query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp).transpose(1, 2).contiguous().transpose(1, 2); + #ifdef CPU_CAPABILITY_AVX512 + } + #endif // CPU_CAPABILITY_AVX512 +} + + +} // anonymous namespace + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("torchao::scaled_dot_product_int8", &_scaled_dot_product_int8_cpu); +} + +// } // at::native +} // namespace torchao diff --git a/torchao/ops.py b/torchao/ops.py index 82de7528ec..086294404d 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -49,6 +49,9 @@ "mx_fp4_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor", 
    tags=[torch._C.Tag.needs_fixed_stride_order],
)
+lib.define(
+    "scaled_dot_product_int8(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, float scale=0.0, float q_scale=1.0, int q_zp=0, float k_scale=1.0, int k_zp=0, float v_scale=1.0, int v_zp=0, float a_scale=1.0, int a_zp=0, float o_scale=1.0, int o_zp=0) -> Tensor"
+)


def register_custom_op(name):
@@ -153,6 +156,94 @@ def _(
    return _in_feats.new_empty((BS, OC))

+
+def scaled_dot_product_int8(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    attn_mask: Optional[Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: float = 0.0,
+    q_scale: float = 1.0,
+    q_zp: int = 0,
+    k_scale: float = 1.0,
+    k_zp: int = 0,
+    v_scale: float = 1.0,
+    v_zp: int = 0,
+    a_scale: float = 1.0,
+    a_zp: int = 0,
+    o_scale: float = 1.0,
+    o_zp: int = 0,
+) -> Tensor:
+    """
+    Quantized SDPA with uint8 inputs and outputs.
+
+    Arguments
+        query: input query tensor,
+        key: input key tensor,
+        value: input value tensor,
+        attn_mask: attention mask tensor,
+        dropout_p: dropout probability,
+        is_causal: causal flag,
+        scale: scaling factor applied prior to softmax,
+        q_scale: scale for query from linear quantization,
+        q_zp: zero point for query from linear quantization,
+        k_scale: scale for key from linear quantization,
+        k_zp: zero point for key from linear quantization,
+        v_scale: scale for value from linear quantization,
+        v_zp: zero point for value from linear quantization,
+        a_scale: scale for attention from softmax quantization,
+        a_zp: zero point for attention from softmax quantization,
+        o_scale: scale for output from linear quantization,
+        o_zp: zero point for output from linear quantization,
+
+    Returns
+        output of quantized SDPA
+    """
+    return torch.ops.torchao.scaled_dot_product_int8.default(
+        query,
+        key,
+        value,
+        attn_mask,
+        dropout_p,
+        is_causal,
+        scale,
+        q_scale,
+        q_zp,
+        k_scale,
+        k_zp,
+        v_scale,
+        v_zp,
+        a_scale,
+        a_zp,
+        o_scale,
+        o_zp,
+    )
+
+
+@register_custom_op("torchao::scaled_dot_product_int8")
+def _(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    attn_mask: Optional[Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: float = 0.0,
+    q_scale: float = 1.0,
+    q_zp: int = 0,
+    k_scale: float = 1.0,
+    k_zp: int = 0,
+    v_scale: float = 1.0,
+    v_zp: int = 0,
+    a_scale: float = 1.0,
+    a_zp: int = 0,
+    o_scale: float = 1.0,
+    o_zp: int = 0,
+) -> Tensor:
+    return query
+
+
def unpack_tensor_core_tiled_layout(packed_w: Tensor, inner_k_tiles: int) -> Tensor:
    """
    Unpacks weights that were packed with `torch.ops.aten._convert_weight_to_int4pack` to original tensor of shape `N x K`.
diff --git a/torchao/prototype/inductor/__init__.py b/torchao/prototype/inductor/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/torchao/prototype/inductor/fx_passes/README.md b/torchao/prototype/inductor/fx_passes/README.md
new file mode 100644
index 0000000000..9171f508a8
--- /dev/null
+++ b/torchao/prototype/inductor/fx_passes/README.md
@@ -0,0 +1,34 @@
+# Inductor FX Passes
+
+This directory contains TorchAO's customized FX passes for Inductor. FX passes are transformations applied to the FX graph to optimize it for better performance and functionality.
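+
+A custom graph pass is simply a callable that Inductor invokes with the corresponding `torch.fx.Graph` and that inspects or rewrites the graph in place. A minimal, hypothetical sketch (the pass name and its body are placeholders; it assumes the post-grad hook receives the post-grad graph, which is what the `patterns.apply` registration in `int8_sdpa_fusion.py` relies on):
+
+```python
+import torch
+from torch._inductor import config
+
+
+def log_bmm_count(graph: torch.fx.Graph) -> None:
+    # Illustrative only: count aten.bmm nodes in the post-grad graph.
+    n = sum(1 for node in graph.nodes if node.target == torch.ops.aten.bmm.default)
+    print(f"post-grad graph contains {n} bmm nodes")
+
+
+config.post_grad_custom_pre_pass = log_bmm_count  # hook the pass into Inductor
+```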
+
+In TorchAO, you can replace the following customized graph passes of Inductor:
+- `pre_grad_custom_pass`
+- `joint_custom_pre_pass`
+- `joint_custom_post_pass`
+- `post_grad_custom_post_pass`
+- `post_grad_custom_pre_pass`
+
+## Directory Structure
+
+- `int8_sdpa_fusion`: pattern matching for int8 SDPA fusion.
+
+## Getting Started
+
+To get started with the FX passes in TorchAO, register your patterns and hook them into Inductor as follows:
+
+```python
+from torch._inductor import config
+from torch._inductor.pattern_matcher import PatternMatcherPass
+
+# Example usage
+patterns = PatternMatcherPass()  # create a pattern matcher pass
+_register_patterns(...)  # register your own patterns
+config.post_grad_custom_pre_pass = patterns.apply  # install the pass on one of the hooks listed above
+```
+
+## Limitations
+
+For now, only one pass can be registered on each custom-pass hook.
+Extending this to a list of passes is left as future work.
diff --git a/torchao/prototype/inductor/fx_passes/__init__.py b/torchao/prototype/inductor/fx_passes/__init__.py
new file mode 100644
index 0000000000..aae6d5348a
--- /dev/null
+++ b/torchao/prototype/inductor/fx_passes/__init__.py
@@ -0,0 +1,5 @@
+from .int8_sdpa_fusion import _int8_sdpa_init
+
+__all__ = [
+    "_int8_sdpa_init",
+]
diff --git a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py
new file mode 100644
index 0000000000..a8f181f2db
--- /dev/null
+++ b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py
@@ -0,0 +1,370 @@
+import functools
+import itertools
+
+import torch
+from torch._dynamo.utils import counters
+from torch._inductor import config
+from torch._inductor.fx_passes.post_grad import register_lowering_pattern
+from torch._inductor.lowering import lowerings as L
+from torch._inductor.lowering import make_fallback
+from torch._inductor.pattern_matcher import (
+    Arg,
+    CallFunction,
+    KeywordArg,
+    Match,
+    PatternMatcherPass,
+)
+
+__all__ = [
+    "_int8_sdpa_init",
+]
+
+make_fallback(torch.ops.torchao.scaled_dot_product_int8.default)
+
+aten = torch.ops.aten
+patterns = PatternMatcherPass()
+
+
+def _is_valid_int8_sdpa_pattern():
+    def fn(match):
+        assert all(k in match.kwargs for k in ("query", "key", "value"))
+        query = match.kwargs["query"].meta["val"]
+        key = match.kwargs["key"].meta["val"]
+        value = match.kwargs["value"].meta["val"]
+        return (
+            query.dtype == torch.uint8
+            and key.dtype == torch.uint8
+            and value.dtype == torch.uint8
+            and query.device.type == "cpu"
+            and key.device == query.device
+            and value.device == query.device
+        )
+
+    return fn
+
+
+def _register_int8_sdpa_pattern(pattern):
+    @register_lowering_pattern(
+        pattern,
+        extra_check=_is_valid_int8_sdpa_pattern(),
+    )
+    def int8_sdpa(match: Match, *args, **kwargs):
+        query = kwargs["query"]
+        key = kwargs["key"]
+        value = kwargs["value"]
+        inv_scale = kwargs["inv_scale"]
+        attn_mask = kwargs["attn_mask"] if "attn_mask" in kwargs else None
+        q_scale = kwargs["q_scale"]
+        q_zp = kwargs["q_zp"]
+        k_scale = kwargs["k_scale"]
+        k_zp = kwargs["k_zp"]
+        v_scale = kwargs["v_scale"]
+        v_zp = kwargs["v_zp"]
+        a_scale = kwargs["a_scale"]
+        a_zp = kwargs["a_zp"]
+        o_scale = kwargs["o_scale"]
+        o_zp = kwargs["o_zp"]
+        counters["inductor"]["int8_fuse_attention"] += 1
+        counters["inductor"]["int8_sdpa_nodes"] += len(match.nodes)
+
+        trans_query = L[aten.permute.default](query, [0, 2, 1, 3])
+        trans_key = L[aten.permute.default](key, [0, 2, 1, 3])
+        trans_value = L[aten.permute.default](value, [0, 2, 1, 3])
+        output =
L[torch.ops.torchao.scaled_dot_product_int8.default]( + trans_query, + trans_key, + trans_value, + attn_mask, + 0.0, # dropout + False, # is_causal + 1.0 / inv_scale, # scale + q_scale, + q_zp, + k_scale, + k_zp, + v_scale, + v_zp, + a_scale, + a_zp, + o_scale, + o_zp, + ) + trans_output = L[aten.permute.default](output, [0, 2, 1, 3]) + return L[aten.clone.default]( + trans_output, memory_format=torch.contiguous_format + ) + + return int8_sdpa + + +def _get_int8_sdpa_qkv_pattern( + is_batch_size_1: bool, has_convert: bool, input_name: str +): + assert input_name in ["query", "key", "value"] + int8_sdpa_qkv_pattern_before_dequant = CallFunction( + aten.permute.default, + KeywordArg(input_name), + Arg(), + ) + if input_name == "key": + # do transpose + int8_sdpa_qkv_pattern_before_dequant = CallFunction( + aten.permute.default, + int8_sdpa_qkv_pattern_before_dequant, + Arg(), + ) + int8_sdpa_qkv_basic_pattern = CallFunction( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + int8_sdpa_qkv_pattern_before_dequant, + KeywordArg(input_name[0] + "_scale"), + KeywordArg(input_name[0] + "_zp"), + Arg(), + Arg(), + Arg(), + ) + if has_convert: + int8_sdpa_qkv_basic_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + int8_sdpa_qkv_basic_pattern, + Arg(), + ) + int8_sdpa_qkv_basic_pattern = CallFunction( + aten.expand.default, + int8_sdpa_qkv_basic_pattern, + Arg(), + ) + if is_batch_size_1: + # pattern is different for bs=1 + return CallFunction( + aten.reshape.default, + int8_sdpa_qkv_basic_pattern, + Arg(), + ) + else: + return CallFunction( + aten.reshape.default, + CallFunction( + aten.clone.default, + int8_sdpa_qkv_basic_pattern, + memory_format=Arg(), + ), + Arg(), + ) + + +def _get_int8_sdpa_score_pattern( + has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool +): + int8_sdpa_q_pattern = _get_int8_sdpa_qkv_pattern( + is_batch_size_1, has_convert, "query" + ) + int8_sdpa_k_pattern = _get_int8_sdpa_qkv_pattern( + is_batch_size_1, has_convert, "key" + ) + int8_sdpa_score_basic_pattern = CallFunction( + aten.reshape.default, + CallFunction( + aten.bmm.default, + int8_sdpa_q_pattern, + int8_sdpa_k_pattern, + ), + Arg(), + ) + if is_reduced_type and not has_mask: + int8_sdpa_score_basic_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + int8_sdpa_score_basic_pattern, + Arg(), + ) + if has_mask: + return CallFunction( + aten.add.Tensor, + CallFunction( + aten.div.Tensor, + int8_sdpa_score_basic_pattern, + KeywordArg("inv_scale"), + ), + KeywordArg("attn_mask"), + _users=2, + ) + else: + return CallFunction( + aten.mul.Tensor, + int8_sdpa_score_basic_pattern, + Arg(), + _users=2, + ) + + +def _get_int8_sdpa_exp_pattern( + has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool +): + int8_sdpa_score_pattern = _get_int8_sdpa_score_pattern( + has_mask, is_batch_size_1, is_reduced_type, has_convert + ) + int8_sdpa_exp_basic_pattern = CallFunction( + aten.sub.Tensor, + int8_sdpa_score_pattern, + CallFunction( + aten.amax.default, + int8_sdpa_score_pattern, + Arg(), + Arg(), + ), + ) + if has_mask: + return CallFunction( + aten.exp.default, + int8_sdpa_exp_basic_pattern, + _users=2, + ) + else: + return CallFunction( + aten.exp.default, + CallFunction( + aten.div.Tensor, + int8_sdpa_exp_basic_pattern, + KeywordArg("inv_scale"), + ), + _users=2, + ) + + +def _get_int8_sdpa_attn_pattern( + has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool +): + int8_sdpa_exp_pattern = 
_get_int8_sdpa_exp_pattern( + has_mask, is_batch_size_1, is_reduced_type, has_convert + ) + int8_sdpa_div_pattern = CallFunction( + aten.div.Tensor, + int8_sdpa_exp_pattern, + CallFunction( + aten.sum.dim_IntList, + int8_sdpa_exp_pattern, + Arg(), + Arg(), + ), + ) + int8_sdpa_softmax_pattern = CallFunction( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + CallFunction( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + int8_sdpa_div_pattern, + KeywordArg("a_scale"), + KeywordArg("a_zp"), + Arg(), + Arg(), + Arg(), + ), + KeywordArg("a_scale"), + KeywordArg("a_zp"), + Arg(), + Arg(), + Arg(), + ) + if is_reduced_type: + if has_mask: + int8_sdpa_softmax_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + int8_sdpa_softmax_pattern, + Arg(), + ) + else: + int8_sdpa_softmax_pattern = CallFunction( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + CallFunction( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + CallFunction( + torch.ops.prims.convert_element_type.default, + int8_sdpa_div_pattern, + Arg(), + ), + KeywordArg("a_scale"), + KeywordArg("a_zp"), + Arg(), + Arg(), + Arg(), + ), + KeywordArg("a_scale"), + KeywordArg("a_zp"), + Arg(), + Arg(), + Arg(), + ) + if has_convert: + int8_sdpa_softmax_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + int8_sdpa_softmax_pattern, + Arg(), + ) + return CallFunction( + aten.reshape.default, + CallFunction( + aten.expand.default, + int8_sdpa_softmax_pattern, + Arg(), + ), + Arg(), + ) + + +# Parameters to generate various patterns: +# has_mask: if SDPA has attention mask +# is_batch_size_1: if the batch size is 1 +# is_reduced_type: if autocast is enabled +# has_convert: convert type if dequant out dtype is assigned +def _get_int8_sdpa_final_pattern( + has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool +): + int8_sdpa_v_pattern = _get_int8_sdpa_qkv_pattern( + is_batch_size_1, has_convert, "value" + ) + int8_sdpa_attn_pattern = _get_int8_sdpa_attn_pattern( + has_mask, is_batch_size_1, is_reduced_type, has_convert + ) + return CallFunction( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + CallFunction( + aten.clone.default, + CallFunction( + aten.permute.default, + CallFunction( + aten.reshape.default, + CallFunction( + aten.bmm.default, + int8_sdpa_attn_pattern, + int8_sdpa_v_pattern, + ), + Arg(), + ), + Arg(), + ), + memory_format=Arg(), + ), + KeywordArg("o_scale"), + KeywordArg("o_zp"), + Arg(), + Arg(), + Arg(), + ) + + +def _register_int8_sdpa_lowerings(): + for has_mask, is_batch_size_1, is_reduced_type, has_convert in itertools.product( + [True, False], [True, False], [True, False], [True, False] + ): + _register_int8_sdpa_pattern( + _get_int8_sdpa_final_pattern( + has_mask=has_mask, + is_batch_size_1=is_batch_size_1, + is_reduced_type=is_reduced_type, + has_convert=has_convert, + ) + ) + + +@functools.lru_cache(None) +def _int8_sdpa_init(): + _register_int8_sdpa_lowerings() + config.post_grad_custom_pre_pass = patterns.apply
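
For reference, the op added in `torchao/ops.py` can also be exercised directly on already-quantized uint8 tensors. A minimal sketch (the scale/zero-point values are illustrative rather than calibrated, and it assumes torchao is built with its C++ CPU extension so the kernel above is registered; depending on the build and CPU, either the fused AVX512 path or the `sdpa_int8_math_kernel` fallback runs):

```python
import torch
from torchao.ops import scaled_dot_product_int8

B, H, T, K = 1, 16, 197, 64  # batch, heads, sequence length, head size
q = torch.randint(0, 256, (B, H, T, K), dtype=torch.uint8)
k = torch.randint(0, 256, (B, H, T, K), dtype=torch.uint8)
v = torch.randint(0, 256, (B, H, T, K), dtype=torch.uint8)

out = scaled_dot_product_int8(
    q, k, v,
    attn_mask=None,      # optional fp32/bf16 mask of dim 2 or 4
    dropout_p=0.0,       # dropout is not supported by the CPU kernel
    is_causal=False,     # is_causal is not supported by the CPU kernel
    scale=K**-0.5,
    q_scale=0.02, q_zp=128,
    k_scale=0.02, k_zp=128,
    v_scale=0.02, v_zp=128,
    a_scale=1.0 / 255, a_zp=0,   # softmax (attention) quantization
    o_scale=0.05, o_zp=128,      # output quantization
)
print(out.shape, out.dtype)  # torch.Size([1, 16, 197, 64]) torch.uint8
```

In the compile flow, one would instead call `_int8_sdpa_init()` so that the post-grad pattern above rewrites the dequantize, SDPA, and re-quantize subgraph into this op.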