[Do not Land] Re-land "Add INT8 SDPA path for CPU" (#2093) #2183

Open · wants to merge 1 commit into main

21 changes: 21 additions & 0 deletions setup.py
@@ -55,6 +55,10 @@ def read_version(file_path="version.txt"):
and platform.system() == "Darwin"
)

use_cpp_avx512 = os.getenv("USE_AVX512", "1") == "1" and platform.system() == "Linux"

from torchao.utils import TORCH_VERSION_AT_LEAST_2_7

version_prefix = read_version()
# Version is version.dev year month date if using nightlies and version if not
version = (
@@ -284,6 +288,17 @@ def get_extensions():
["-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always"]
)

if use_cpp_avx512 and TORCH_VERSION_AT_LEAST_2_7:
if torch._C._cpu._is_avx512_supported():
extra_compile_args["cxx"].extend(
[
"-DCPU_CAPABILITY_AVX512",
"-march=native",
"-mfma",
"-fopenmp",
]
)

if debug_mode:
extra_compile_args["cxx"].append("-g")
if "nvcc" in extra_compile_args:
@@ -305,6 +320,12 @@ def get_extensions():

# Collect C++ source files
sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))
if IS_WINDOWS:
# Remove csrc/cpu/*.cpp on Windows due to the link issue: unresolved external symbol PyInit__C
excluded_sources = list(
glob.glob(os.path.join(extensions_dir, "cpu/*.cpp"), recursive=True)
)
sources = [s for s in sources if s not in excluded_sources]

extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
cuda_sources = list(
217 changes: 217 additions & 0 deletions test/prototype/inductor/test_int8_sdpa_fusion.py
@@ -0,0 +1,217 @@
import itertools

import pytest
import torch
import torch.utils.checkpoint
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.test_case import TestCase, run_tests
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CPU
from torch.utils.cpp_extension import IS_WINDOWS

import torchao
from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import _int8_sdpa_init
from torchao.utils import TORCH_VERSION_AT_LEAST_2_7


class SelfAttnLikeModule(torch.nn.Module):
def __init__(
self,
input_dim,
has_mask,
num_attention_heads=None,
attention_head_size=None,
) -> None:
super().__init__()
self.input_dim = input_dim
self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
self.softmax = torch.nn.Softmax(dim=-1)
assert num_attention_heads is not None
assert attention_head_size is not None
self.num_attention_heads = num_attention_heads
self.attention_head_size = attention_head_size
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size)
self.dropout = torch.nn.Dropout(0)
self.has_mask = has_mask

def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
x = x.view(new_x_shape)
return x.permute([0, 2, 1, 3])

def forward(self, x, mask):
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
q = self.transpose_for_scores(q)
k = self.transpose_for_scores(k)
v = self.transpose_for_scores(v)
scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5)
if self.has_mask and mask.dtype != scores.dtype:
scores = scores + mask
attention = self.softmax(scores)
attention = self.dropout(attention)
context_layer = torch.matmul(attention, v)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
context_layer = context_layer.view(
context_layer.size()[:-2] + (self.all_head_size,)
)
return self.dense(context_layer)


class TestSDPAPatternRewriterTemplate(TestCase):
def _clone_inputs(self, inputs):
def clone(x):
if not isinstance(x, torch.Tensor):
return x
return x.clone()

return [clone(x) for x in inputs]

def _check_common(
self,
dot_prod_attention,
args1=None,
contains=True,
atol=1e-5,
has_fuse_pattern=True,
has_dropout=False,
check_train=True,
override_check_equal=False,
dtype=torch.float,
rtol=1.3e-6,
):
if args1 is None:
tensor_shape = (4, 2, 16, 32)
args1 = [
torch.randn(tensor_shape, device=self.device, dtype=dtype),
torch.randn(tensor_shape, device=self.device, dtype=dtype),
torch.randn(tensor_shape, device=self.device, dtype=dtype),
]
else:
args1 = list(args1)
args2 = self._clone_inputs(args1)

for training in [False, True] if check_train else [False]:
for x in itertools.chain(args1[:], args2[:]):
if isinstance(x, torch.Tensor) and x.is_floating_point():
x.requires_grad = training

dropout_arg = [training] if has_dropout else []
torch.manual_seed(1234)
result1 = dot_prod_attention(*(args1 + dropout_arg))

counters.clear()
torch.manual_seed(1234)
compiled_model = torch.compile(dot_prod_attention, fullgraph=True)
result2, source_code = run_and_get_code(
compiled_model,
*(args2 + dropout_arg),
)
source_code = "\n".join(source_code)
if has_fuse_pattern:
self.assertGreaterEqual(counters["inductor"]["int8_fuse_attention"], 1)
if contains:
# many of the patterns get re-expanded in dispatcher
self.assertIn(
"torchao.scaled_dot_product_int8",
source_code,
)

# some tests are configured with very low dropout where we still want to check equality
if not has_dropout or override_check_equal:
self.assertEqual(result1, result2, atol=atol, rtol=rtol)

if training:
result1.sum().backward()
result2.sum().backward()
for arg1, arg2 in zip(args1, args2):
if (
isinstance(arg1, torch.Tensor)
and arg1.is_floating_point()
and (not has_dropout or override_check_equal)
):
self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol)

@skipIfRocm
@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later"
)
@pytest.mark.skipif(IS_WINDOWS, reason="int8 sdpa does not support windows yet")
@config.patch({"freezing": True})
def _test_sdpa_int8_rewriter(self):
from torch.export import export_for_training

import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
X86InductorQuantizer,
)

# pattern is different for bs=1
torch.manual_seed(1234)
for dtype, has_mask, bs in itertools.product(
[torch.float32, torch.bfloat16], [True, False], [56, 1]
):
seqlen, numhead, headsize = 197, 16, 64
mod = SelfAttnLikeModule(
input_dim=headsize * numhead,
has_mask=has_mask,
num_attention_heads=numhead,
attention_head_size=headsize,
).eval()
inputs = (
torch.randn(
(bs, seqlen, headsize * numhead), device=self.device, dtype=dtype
),
torch.randn((bs, 1, 1, seqlen), device=self.device)
if has_mask
else None,
)
enable_autocast = dtype == torch.bfloat16
with (
torch.no_grad(),
torch.amp.autocast(
self.device, enabled=enable_autocast, dtype=torch.bfloat16
),
):
_int8_sdpa_init()
quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
quantizer.set_function_type_qconfig(
torch.matmul, quantizer.get_global_quantization_config()
)
export_model = export_for_training(
mod,
inputs,
strict=True,
).module()
prepare_model = prepare_pt2e(export_model, quantizer)
prepare_model(*inputs)
convert_model = convert_pt2e(prepare_model)
torchao.quantization.pt2e.move_exported_model_to_eval(convert_model)
self._check_common(
convert_model, args1=inputs, check_train=False, atol=1.0
)


if HAS_CPU:

class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate):
device = "cpu"
test_sdpa_int8_rewriter_cpu = (
TestSDPAPatternRewriterTemplate._test_sdpa_int8_rewriter
)


if __name__ == "__main__":
if IS_LINUX:
run_tests()
137 changes: 136 additions & 1 deletion test/test_ops.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import math
import sys

import pytest
@@ -14,6 +15,7 @@
parametrize,
)
from torch.testing._internal.optests import opcheck
from torch.utils.cpp_extension import IS_WINDOWS

import torchao
from torchao.dtypes.floatx import from_scaled_tc_floatx
@@ -23,7 +25,11 @@
)
from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_qqq
from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_7,
compute_max_diff,
)

if torch.version.hip is not None:
pytest.skip("Skipping the test in ROCm", allow_module_level=True)
@@ -109,6 +115,135 @@ def test_quant_llm_linear_correctness(
rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3
assert relative_error < rtol

def _scaled_dot_product_int8_op_ref(
self,
q,
k,
v,
attn_mask=None,
dropout_p=0,
is_causal=False,
q_scale=1.0,
q_zp=0,
k_scale=1.0,
k_zp=0,
v_scale=1.0,
v_zp=0,
a_scale=1.0,
a_zp=0,
o_scale=1.0,
o_zp=0,
):
q = (q.to(torch.float) - q_zp) * q_scale
k = (k.to(torch.float) - k_zp) * k_scale
v = (v.to(torch.float) - v_zp) * v_scale
scale_factor = 1 / math.sqrt(q.size(-1))
attn = q @ k.transpose(-2, -1)
attn = attn * scale_factor
if attn_mask is not None:
attn = attn + attn_mask.to(torch.float)
attn_max = attn.max(dim=-1, keepdim=True).values
attn = attn - attn_max
attn = torch.exp(attn)
attn_sum = torch.sum(attn, dim=-1, keepdim=True)
attn = attn / attn_sum
attn = torch.clamp(torch.round(attn / a_scale) + a_zp, min=0, max=255)
attn = (attn - a_zp) * a_scale
out = attn @ v
out = torch.clamp(torch.round(out / o_scale) + o_zp, min=0, max=255)
return out.to(torch.uint8)

@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later"
)
@pytest.mark.skipif(IS_WINDOWS, reason="int8 sdpa does not support windows yet")
@parametrize("batch_size", [56, 120])
@parametrize("n_head", [2, 16])
@parametrize("q_seq_len", [18, 89])
@parametrize("kv_seq_len", [100, 253])
@parametrize("head_dim", [32, 64])
@parametrize("mask_dtype", [None, torch.float32, torch.bfloat16])
def test_scaled_dot_product_int8_op(
self, batch_size, n_head, q_seq_len, kv_seq_len, head_dim, mask_dtype
):
torch.manual_seed(1234)
device = "cpu"
q_scale = float(1.7907238006591797)
q_zp = int(127)
k_scale = float(1.8039721250534058)
k_zp = int(125)
v_scale = float(1.839004635810852)
v_zp = int(127)
a_scale = float(0.003919653594493866)
a_zp = int(120)
o_scale = float(1.8191684484481812)
o_zp = int(128)
q_shape = [batch_size, q_seq_len, n_head, head_dim]
kv_shape = [batch_size, kv_seq_len, n_head, head_dim]
mask_shape = [batch_size, 1, 1, kv_seq_len]
q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2) * 100
k = (
torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
* 100
)
v = (
torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
* 100
)
q = q.to(torch.uint8)
k = k.to(torch.uint8)
v = v.to(torch.uint8)
attn_mask = (
torch.randn(mask_shape, dtype=mask_dtype, device=device)
if mask_dtype is not None
else None
)
q2, k2, v2, attn_mask_2 = (
q.clone(),
k.clone(),
v.clone(),
attn_mask.clone() if mask_dtype is not None else None,
)

math_ref = self._scaled_dot_product_int8_op_ref(
q2,
k2,
v2,
attn_mask=attn_mask,
dropout_p=0.0,
is_causal=False,
q_scale=q_scale,
q_zp=q_zp,
k_scale=k_scale,
k_zp=k_zp,
v_scale=v_scale,
v_zp=v_zp,
a_scale=a_scale,
a_zp=a_zp,
o_scale=o_scale,
o_zp=o_zp,
)
actual = torch.ops.torchao.scaled_dot_product_int8(
q,
k,
v,
attn_mask=attn_mask_2,
dropout_p=0.0,
is_causal=False,
q_scale=q_scale,
q_zp=q_zp,
k_scale=k_scale,
k_zp=k_zp,
v_scale=v_scale,
v_zp=v_zp,
a_scale=a_scale,
a_zp=a_zp,
o_scale=o_scale,
o_zp=o_zp,
)

self.assertEqual(actual, math_ref, atol=1.0, rtol=5e-6)


instantiate_parametrized_tests(TestOps)
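
The new op can also be exercised directly through its registered schema once `torchao` is imported. Below is a minimal sketch, assuming a Linux CPU build of torchao with the int8 SDPA kernel compiled in (per the setup.py changes above); the scale and zero-point values are hypothetical placeholders rather than calibrated values:

```python
import torch
import torchao  # noqa: F401  # importing torchao registers the custom ops

# uint8 query/key/value in (batch, num_heads, seq_len, head_dim) layout,
# matching the layout used by test_scaled_dot_product_int8_op above
q = torch.randint(0, 256, (2, 4, 16, 64), dtype=torch.uint8)
k = torch.randint(0, 256, (2, 4, 16, 64), dtype=torch.uint8)
v = torch.randint(0, 256, (2, 4, 16, 64), dtype=torch.uint8)

# Hypothetical scales/zero points; in practice these come from the PT2E
# quantization workflow (prepare_pt2e / convert_pt2e).
out = torch.ops.torchao.scaled_dot_product_int8(
    q, k, v,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    q_scale=0.02, q_zp=127,
    k_scale=0.02, k_zp=127,
    v_scale=0.02, v_zp=127,
    a_scale=1.0 / 255, a_zp=120,
    o_scale=0.02, o_zp=128,
)
# out is a re-quantized uint8 tensor with the same shape as q
```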

1,907 changes: 1,907 additions & 0 deletions torchao/csrc/cpu/int8_sdpa.cpp

Large diffs are not rendered by default.

91 changes: 91 additions & 0 deletions torchao/ops.py
@@ -49,6 +49,9 @@
"mx_fp4_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor",
tags=[torch._C.Tag.needs_fixed_stride_order],
)
lib.define(
"scaled_dot_product_int8(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, float scale=0.0, float q_scale=1.0, int q_zp=0, float k_scale=1.0, int k_zp=0, float v_scale=1.0, int v_zp=0, float a_scale=1.0, int a_zp=0, float o_scale=1.0, int o_zp=0) -> Tensor"
)


def register_custom_op(name):
@@ -153,6 +156,94 @@ def _(
return _in_feats.new_empty((BS, OC))


def scaled_dot_product_int8(
query: Tensor,
key: Tensor,
value: Tensor,
attn_mask: Optional[Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: float = 0.0,
q_scale: float = 1.0,
q_zp: int = 0,
k_scale: float = 1.0,
k_zp: int = 0,
v_scale: float = 1.0,
v_zp: int = 0,
a_scale: float = 1.0,
a_zp: int = 0,
o_scale: float = 1.0,
o_zp: int = 0,
) -> Tensor:
"""
Quantized SDPA with uint8 inputs and outputs.
Arguments
query: input query tensor,
key: input key tensor,
value: input value tensor,
attn_mask: attention mask tensor,
dropout_p: dropout probability,
is_causal: causal flag,
scale: scaling factor applied prior to softmax,
q_scale: scale for query from linear quantization,
q_zp: zero point for query from linear quantization,
k_scale: scale for key from linear quantization,
k_zp: zero point for key from linear quantization,
v_scale: scale for value from linear quantization,
v_zp: zero point for value from linear quantization,
a_scale: scale for attention from softmax quantization,
a_zp: zero point for attention from softmax quantization,
o_scale: scale for output from linear quantization,
o_zp: zero point for output from linear quantization,
Returns
output of quantized SDPA
"""
return torch.ops.torchao.scaled_dot_product_int8.default(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
q_scale,
q_zp,
k_scale,
k_zp,
v_scale,
v_zp,
a_scale,
a_zp,
o_scale,
o_zp,
)


@register_custom_op("torchao::scaled_dot_product_int8")
def _(
query: Tensor,
key: Tensor,
value: Tensor,
attn_mask: Optional[Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: float = 0.0,
q_scale: float = 1.0,
q_zp: int = 0,
k_scale: float = 1.0,
k_zp: int = 0,
v_scale: float = 1.0,
v_zp: int = 0,
a_scale: float = 1.0,
a_zp: int = 0,
o_scale: float = 1.0,
o_zp: int = 0,
) -> Tensor:
return query


def unpack_tensor_core_tiled_layout(packed_w: Tensor, inner_k_tiles: int) -> Tensor:
"""
Unpacks weights that were packed with `torch.ops.aten._convert_weight_to_int4pack` to original tensor of shape `N x K`.
Empty file.
34 changes: 34 additions & 0 deletions torchao/prototype/inductor/fx_passes/README.md
@@ -0,0 +1,34 @@
# Inductor FX Passes

This directory contains TorchAO's FX passes for Inductor. FX passes are transformations applied to the FX graph to optimize and modify it for better performance and functionality.

In TorchAO, you can install customized graph passes on the following Inductor hooks:
- `pre_grad_custom_pass`
- `joint_custom_pre_pass`
- `joint_custom_post_pass`
- `post_grad_custom_post_pass`
- `post_grad_custom_pre_pass`

## Directory Structure

- `int8_sdpa_fusion`: pattern-matching pass that fuses the int8 SDPA subgraph into the `torchao.scaled_dot_product_int8` op.

## Getting Started

To get started with using the FX passes in TorchAO, you can register and apply them to your FX graph as follows:

```python
from torch._inductor import config
from torch._inductor.pattern_matcher import PatternMatcherPass

# Example usage
patterns = PatternMatcherPass() # create a pattern matcher pass
_register_patterns(...) # register your own patterns
config.post_grad_custom_pre_pass = patterns.apply # install the pass on one of the hooks listed above

```
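
For the `int8_sdpa_fusion` pass specifically, a minimal end-to-end sketch is shown below. It assumes `quantized_model` is a module that has already gone through the PT2E X86 Inductor quantization flow (`prepare_pt2e` / `convert_pt2e`) and `example_inputs` are its inputs:

```python
import torch
from torch._inductor import config
from torchao.prototype.inductor.fx_passes import _int8_sdpa_init

_int8_sdpa_init()  # registers the int8 SDPA patterns on config.post_grad_custom_pre_pass

config.freezing = True  # the fusion test enables freezing so the quant/dequant chains are visible to the matcher

# quantized_model / example_inputs are assumed to come from the PT2E quantization workflow
compiled_model = torch.compile(quantized_model, fullgraph=True)
out = compiled_model(*example_inputs)
```

Whether the fusion fired can be checked through `counters["inductor"]["int8_fuse_attention"]`, as done in `test/prototype/inductor/test_int8_sdpa_fusion.py`.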

## Limitations

Currently, only one pass can be registered on each of these hooks.
Extending them to accept a list of passes is left for future work.
5 changes: 5 additions & 0 deletions torchao/prototype/inductor/fx_passes/__init__.py
@@ -0,0 +1,5 @@
from .int8_sdpa_fusion import _int8_sdpa_init

__all__ = [
"_int8_sdpa_init",
]
370 changes: 370 additions & 0 deletions torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py
@@ -0,0 +1,370 @@
import functools
import itertools

import torch
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.fx_passes.post_grad import register_lowering_pattern
from torch._inductor.lowering import lowerings as L
from torch._inductor.lowering import make_fallback
from torch._inductor.pattern_matcher import (
Arg,
CallFunction,
KeywordArg,
Match,
PatternMatcherPass,
)

__all__ = [
"_int8_sdpa_init",
]

make_fallback(torch.ops.torchao.scaled_dot_product_int8.default)

aten = torch.ops.aten
patterns = PatternMatcherPass()


def _is_valid_int8_sdpa_pattern():
def fn(match):
assert all(k in match.kwargs for k in ("query", "key", "value"))
query = match.kwargs["query"].meta["val"]
key = match.kwargs["key"].meta["val"]
value = match.kwargs["value"].meta["val"]
return (
query.dtype == torch.uint8
and key.dtype == torch.uint8
and value.dtype == torch.uint8
and query.device.type == "cpu"
and key.device == query.device
and value.device == query.device
)

return fn


def _register_int8_sdpa_pattern(pattern):
@register_lowering_pattern(
pattern,
extra_check=_is_valid_int8_sdpa_pattern(),
)
def int8_sdpa(match: Match, *args, **kwargs):
query = kwargs["query"]
key = kwargs["key"]
value = kwargs["value"]
inv_scale = kwargs["inv_scale"]
attn_mask = kwargs["attn_mask"] if "attn_mask" in kwargs else None
q_scale = kwargs["q_scale"]
q_zp = kwargs["q_zp"]
k_scale = kwargs["k_scale"]
k_zp = kwargs["k_zp"]
v_scale = kwargs["v_scale"]
v_zp = kwargs["v_zp"]
a_scale = kwargs["a_scale"]
a_zp = kwargs["a_zp"]
o_scale = kwargs["o_scale"]
o_zp = kwargs["o_zp"]
counters["inductor"]["int8_fuse_attention"] += 1
counters["inductor"]["int8_sdpa_nodes"] += len(match.nodes)

trans_query = L[aten.permute.default](query, [0, 2, 1, 3])
trans_key = L[aten.permute.default](key, [0, 2, 1, 3])
trans_value = L[aten.permute.default](value, [0, 2, 1, 3])
output = L[torch.ops.torchao.scaled_dot_product_int8.default](
trans_query,
trans_key,
trans_value,
attn_mask,
0.0, # dropout
False, # is_causal
1.0 / inv_scale, # scale
q_scale,
q_zp,
k_scale,
k_zp,
v_scale,
v_zp,
a_scale,
a_zp,
o_scale,
o_zp,
)
trans_output = L[aten.permute.default](output, [0, 2, 1, 3])
return L[aten.clone.default](
trans_output, memory_format=torch.contiguous_format
)

return int8_sdpa


def _get_int8_sdpa_qkv_pattern(
is_batch_size_1: bool, has_convert: bool, input_name: str
):
assert input_name in ["query", "key", "value"]
int8_sdpa_qkv_pattern_before_dequant = CallFunction(
aten.permute.default,
KeywordArg(input_name),
Arg(),
)
if input_name == "key":
# do transpose
int8_sdpa_qkv_pattern_before_dequant = CallFunction(
aten.permute.default,
int8_sdpa_qkv_pattern_before_dequant,
Arg(),
)
int8_sdpa_qkv_basic_pattern = CallFunction(
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
int8_sdpa_qkv_pattern_before_dequant,
KeywordArg(input_name[0] + "_scale"),
KeywordArg(input_name[0] + "_zp"),
Arg(),
Arg(),
Arg(),
)
if has_convert:
int8_sdpa_qkv_basic_pattern = CallFunction(
torch.ops.prims.convert_element_type.default,
int8_sdpa_qkv_basic_pattern,
Arg(),
)
int8_sdpa_qkv_basic_pattern = CallFunction(
aten.expand.default,
int8_sdpa_qkv_basic_pattern,
Arg(),
)
if is_batch_size_1:
# pattern is different for bs=1
return CallFunction(
aten.reshape.default,
int8_sdpa_qkv_basic_pattern,
Arg(),
)
else:
return CallFunction(
aten.reshape.default,
CallFunction(
aten.clone.default,
int8_sdpa_qkv_basic_pattern,
memory_format=Arg(),
),
Arg(),
)


def _get_int8_sdpa_score_pattern(
has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool
):
int8_sdpa_q_pattern = _get_int8_sdpa_qkv_pattern(
is_batch_size_1, has_convert, "query"
)
int8_sdpa_k_pattern = _get_int8_sdpa_qkv_pattern(
is_batch_size_1, has_convert, "key"
)
int8_sdpa_score_basic_pattern = CallFunction(
aten.reshape.default,
CallFunction(
aten.bmm.default,
int8_sdpa_q_pattern,
int8_sdpa_k_pattern,
),
Arg(),
)
if is_reduced_type and not has_mask:
int8_sdpa_score_basic_pattern = CallFunction(
torch.ops.prims.convert_element_type.default,
int8_sdpa_score_basic_pattern,
Arg(),
)
if has_mask:
return CallFunction(
aten.add.Tensor,
CallFunction(
aten.div.Tensor,
int8_sdpa_score_basic_pattern,
KeywordArg("inv_scale"),
),
KeywordArg("attn_mask"),
_users=2,
)
else:
return CallFunction(
aten.mul.Tensor,
int8_sdpa_score_basic_pattern,
Arg(),
_users=2,
)


def _get_int8_sdpa_exp_pattern(
has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool
):
int8_sdpa_score_pattern = _get_int8_sdpa_score_pattern(
has_mask, is_batch_size_1, is_reduced_type, has_convert
)
int8_sdpa_exp_basic_pattern = CallFunction(
aten.sub.Tensor,
int8_sdpa_score_pattern,
CallFunction(
aten.amax.default,
int8_sdpa_score_pattern,
Arg(),
Arg(),
),
)
if has_mask:
return CallFunction(
aten.exp.default,
int8_sdpa_exp_basic_pattern,
_users=2,
)
else:
return CallFunction(
aten.exp.default,
CallFunction(
aten.div.Tensor,
int8_sdpa_exp_basic_pattern,
KeywordArg("inv_scale"),
),
_users=2,
)


def _get_int8_sdpa_attn_pattern(
has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool
):
int8_sdpa_exp_pattern = _get_int8_sdpa_exp_pattern(
has_mask, is_batch_size_1, is_reduced_type, has_convert
)
int8_sdpa_div_pattern = CallFunction(
aten.div.Tensor,
int8_sdpa_exp_pattern,
CallFunction(
aten.sum.dim_IntList,
int8_sdpa_exp_pattern,
Arg(),
Arg(),
),
)
int8_sdpa_softmax_pattern = CallFunction(
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
CallFunction(
torch.ops.quantized_decomposed.quantize_per_tensor.default,
int8_sdpa_div_pattern,
KeywordArg("a_scale"),
KeywordArg("a_zp"),
Arg(),
Arg(),
Arg(),
),
KeywordArg("a_scale"),
KeywordArg("a_zp"),
Arg(),
Arg(),
Arg(),
)
if is_reduced_type:
if has_mask:
int8_sdpa_softmax_pattern = CallFunction(
torch.ops.prims.convert_element_type.default,
int8_sdpa_softmax_pattern,
Arg(),
)
else:
int8_sdpa_softmax_pattern = CallFunction(
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
CallFunction(
torch.ops.quantized_decomposed.quantize_per_tensor.default,
CallFunction(
torch.ops.prims.convert_element_type.default,
int8_sdpa_div_pattern,
Arg(),
),
KeywordArg("a_scale"),
KeywordArg("a_zp"),
Arg(),
Arg(),
Arg(),
),
KeywordArg("a_scale"),
KeywordArg("a_zp"),
Arg(),
Arg(),
Arg(),
)
if has_convert:
int8_sdpa_softmax_pattern = CallFunction(
torch.ops.prims.convert_element_type.default,
int8_sdpa_softmax_pattern,
Arg(),
)
return CallFunction(
aten.reshape.default,
CallFunction(
aten.expand.default,
int8_sdpa_softmax_pattern,
Arg(),
),
Arg(),
)


# Parameters used to generate the pattern variants:
# has_mask: whether the SDPA has an attention mask
# is_batch_size_1: whether the batch size is 1
# is_reduced_type: whether autocast (reduced precision) is enabled
# has_convert: whether an extra type conversion is present when the dequant output dtype is explicitly assigned
def _get_int8_sdpa_final_pattern(
has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool
):
int8_sdpa_v_pattern = _get_int8_sdpa_qkv_pattern(
is_batch_size_1, has_convert, "value"
)
int8_sdpa_attn_pattern = _get_int8_sdpa_attn_pattern(
has_mask, is_batch_size_1, is_reduced_type, has_convert
)
return CallFunction(
torch.ops.quantized_decomposed.quantize_per_tensor.default,
CallFunction(
aten.clone.default,
CallFunction(
aten.permute.default,
CallFunction(
aten.reshape.default,
CallFunction(
aten.bmm.default,
int8_sdpa_attn_pattern,
int8_sdpa_v_pattern,
),
Arg(),
),
Arg(),
),
memory_format=Arg(),
),
KeywordArg("o_scale"),
KeywordArg("o_zp"),
Arg(),
Arg(),
Arg(),
)


def _register_int8_sdpa_lowerings():
for has_mask, is_batch_size_1, is_reduced_type, has_convert in itertools.product(
[True, False], [True, False], [True, False], [True, False]
):
_register_int8_sdpa_pattern(
_get_int8_sdpa_final_pattern(
has_mask=has_mask,
is_batch_size_1=is_batch_size_1,
is_reduced_type=is_reduced_type,
has_convert=has_convert,
)
)


@functools.lru_cache(None)
def _int8_sdpa_init():
_register_int8_sdpa_lowerings()
config.post_grad_custom_pre_pass = patterns.apply