feat: caching attempts #3527


Status: Draft. Wants to merge 12 commits into base: main.
152 changes: 152 additions & 0 deletions examples/dynamo/cache_utils.py
@@ -0,0 +1,152 @@
import torch
from torch.fx import Graph, GraphModule, Node
from typing import Optional, Union, Iterable, List, Tuple
from torch._ops import OpOverloadPacket
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._pytree import _LEAF_SPEC
from torch._export.utils import _detect_fake_mode_from_gm

def get_kv_nodes(gm):
"""
    Return the (key, value) input node pairs of every scaled_dot_product_attention call in the graph.
"""
kv_nodes = []
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention:
q_node, k_node, v_node = node.args[:3]
kv_nodes.append((k_node, v_node))
return kv_nodes

def get_random_tensor_from_node(node: Node) -> torch.Tensor:
"""
Creates a random tensor based on the shape information in a node's metadata.
    For symbolic dimensions, the example (hint) value recorded in the shape environment is used.

Args:
node: A torch.fx.Node object with metadata containing tensor information

Returns:
        A random tensor with shape matching the node's metadata

    Raises:
        ValueError: If the node's metadata contains no tensor information
"""
if "val" not in node.meta:
raise ValueError(f"No tensor information found in node metadata for node: {node}")

fake_tensor = node.meta["val"]
shape = []

# Iterate through each dimension and handle symbolic dimensions
for dim in fake_tensor.shape:
if isinstance(dim, torch.SymInt):
            # Use the example (hint) value recorded for this symbolic dimension
max_val = dim.node.hint
shape.append(max_val)
else:
shape.append(dim)

# Create a random tensor with the determined shape
dtype = fake_tensor.dtype
device = fake_tensor.device
random_tensor = torch.rand(shape, dtype=dtype, device=device)

return random_tensor

def create_random_output_tensors(nodes: List[Node]) -> List[torch.Tensor]:
"""
Creates random tensors based on the shape information in node metadata.
    For symbolic dimensions, the example (hint) value recorded in the shape environment is used.

Args:
nodes: List of torch.fx.Node objects with metadata

Returns:
List of random tensors with shapes matching the nodes' metadata
"""
random_tensors = []

for node in nodes:
if isinstance(node, Node):
node_tensor = get_random_tensor_from_node(node)
        elif isinstance(node, tuple):
            node_tensor_list = []
            for n in node:
                random_tensor = get_random_tensor_from_node(n)
                node_tensor_list.append(random_tensor)
            node_tensor = tuple(node_tensor_list)
        else:
            raise TypeError(f"Expected a Node or a tuple of Nodes, got {type(node)}")

random_tensors.append(node_tensor)

return random_tensors

def add_graph_input(
gm: GraphModule, name: str, val: Optional[torch.Tensor] = None, dynamic_shape=None
) -> Node:
"""Add a graph input to the given GraphModule and return the newly created node.

NOTE: function does NOT do any graph canonicalization. This is left to the user!

Args:
gm (GraphModule): The GraphModule to add the input to.
name (str): The name of the input.
val (torch.Tensor): An example tensor to use for the input.
dynamic_shape: The dynamic shape of the input tensor [NOT SUPPORTED YET]
"""
# check that no dynamic shape is provided...
if dynamic_shape:
raise NotImplementedError("Dynamic shape not supported for adding graph inputs")

# extract graph and input spec
graph: Graph = gm.graph

in_spec = graph._codegen.pytree_info.in_spec
in_spec_for_args = in_spec.children_specs[0]
orig_args = graph._codegen.pytree_info.orig_args
assert in_spec_for_args.type is tuple

# insert input node after currently last input node
node_last_input = graph.find_nodes(op="placeholder", sort=True)[-1]
with graph.inserting_after(node_last_input):
in_node = graph.placeholder(name)
in_spec_for_args.children_specs.append(_LEAF_SPEC)
orig_args.append(f"arg_{name}")

# update pytree info recursively with __post_init__ starting at leaves
def call_post_init(spec):
for child_spec in spec.children_specs:
call_post_init(child_spec)
spec.__post_init__()

call_post_init(in_spec)

# set fake tensor information if all required information is available
fake_mode: Optional[FakeTensorMode] = _detect_fake_mode_from_gm(gm)
if fake_mode and val is not None and isinstance(val, torch.Tensor):
if isinstance(val, FakeTensor):
fake_tensor = val
else:
fake_tensor: FakeTensor = fake_mode.from_tensor(val, static_shapes=True)
in_node.meta["val"] = fake_tensor
in_node.meta["tensor_meta"] = _extract_tensor_metadata(fake_tensor)

# return new node...
return in_node
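
def _example_add_graph_input():
    """Illustrative sketch only (not part of this PR): shows how a new
    placeholder could be appended to an exported GraphModule. The toy model,
    shapes, and the input name "kv_cache_0" are assumptions made purely for
    demonstration."""

    class _ToyModel(torch.nn.Module):
        def forward(self, x):
            return x * 2

    example_input = torch.randn(2, 4)
    gm = torch.export.export(_ToyModel(), (example_input,)).module()

    # Append an extra graph input and record fake-tensor metadata for it.
    cache_input = add_graph_input(gm, "kv_cache_0", torch.zeros(2, 4))

    gm.graph.lint()
    gm.recompile()
    return cache_input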

def is_op(node: Node, ops: Union[OpOverloadPacket, Iterable[OpOverloadPacket]]) -> bool:
"""Check if the node is a call to one of the ops."""
if node.op != "call_function":
return False
# check if it's a single op that's provided
if isinstance(ops, OpOverloadPacket):
ops = [ops]

# check if it's the op itself instead of an overload
if any(node.target == op for op in ops):
return True

return False
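
# Example (illustrative): in this repo is_op is used to detect SDPA calls, e.g.
#   is_op(node, {torch._C._nn.scaled_dot_product_attention})
# A single op can also be passed directly instead of an iterable.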

def get_all_input_output_nodes(graph: Graph) -> Tuple[List[Node], List[Node]]:
    """Return the placeholder (input) nodes and the output nodes of the graph."""
    input_nodes: List[Node] = graph.find_nodes(op="placeholder")
output_nodes: List[Node] = graph.find_nodes(op="output")
return (input_nodes, output_nodes)
236 changes: 236 additions & 0 deletions examples/dynamo/dynamic_cache.py
@@ -0,0 +1,236 @@
import logging
from typing import Dict, List, Tuple, Union, Sequence, Any

import torch
from torch.fx.node import Target

import torch_tensorrt
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
_aten_lowering_pass,
)
from torch_tensorrt.dynamo.utils import extract_var_range_info
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)

from cache_utils import (
    add_graph_input,
    create_random_output_tensors,
    get_kv_nodes,
    is_op,
)
import tensorrt
import torch.utils._pytree as pytree

logger = logging.getLogger(__name__)

@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(torch.ops.higher_order.cond, enabled=True, supports_dynamic_shapes=True)
def cond_converter(
ctx: torch_tensorrt.dynamo.conversion.ConversionContext,
target: Target,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
name: str,
) -> Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]:
"""
Converter for torch.ops.higher_order.cond operation to TensorRT.

This function handles the conversion of PyTorch's conditional operation to TensorRT.
The conditional operation selects between two tensors based on a boolean predicate.

Args:
        ctx (torch_tensorrt.dynamo.conversion.ConversionContext): The conversion context
target (Target): The target operation to convert
args (Tuple[Argument, ...]): The arguments to the operation
kwargs (Dict[str, Argument]): The keyword arguments to the operation
name (str): The name to give to the TensorRT layer

Returns:
Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: The converted TensorRT tensor(s)
"""
if_layer = ctx.net.add_if_conditional()
condition, true_branch, false_branch = args[0], args[1], args[2]
if_layer.set_condition(condition)
output_layer = if_layer.add_output(true_branch, false_branch)
output = output_layer.get_output(0)

return output
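
# Note (illustrative): unlike torch.cond, whose branches are callables traced as
# subgraphs, this converter receives the two branch *tensors* already present in
# the graph, so the TensorRT if-conditional reduces to a tensor select:
#   output = true_branch if condition else false_branch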

def add_kv_as_outputs(gm):
"""
Modifies the graph to add query, key, and value tensors as outputs.

This function identifies all scaled dot-product attention (SDPA) operations
in the graph, creates copies of their query, key, and value inputs, and adds
these copies to the graph's outputs. This allows for accessing these tensors
externally, which is useful for operations like key-value caching.

Args:
graph: The torch.fx.Graph to modify

Returns:
None. The graph is modified in-place.
"""
# list of MHA kernels we would want to detect and replace
mha_ops = {
torch._C._nn.scaled_dot_product_attention,
}

# Find all SDPA nodes in the graph
mha_nodes = []
for node in gm.graph.nodes:
if is_op(node, mha_ops):
mha_nodes.append(node)

# Iterate through each MHA node to extract shape information
for mha_node in mha_nodes:
if "val" in mha_node.meta and len(mha_node.args) >= 3:
# Get the input nodes (query, key, value)
q_node, k_node, v_node = mha_node.args[:3]

            # Locate the current output node of the graph
output_node = next(node for node in gm.graph.nodes if node.op == "output")

# Get the current output args (typically a tuple)
current_outputs = output_node.args[0]

# If the current output is a tuple, extend it with our new outputs
if isinstance(current_outputs, tuple):
new_outputs = current_outputs + ((k_node, v_node),)
else:
# If there's only one output or it's not a tuple, create a new tuple
new_outputs = (current_outputs, (k_node, v_node))

gm.graph.output(new_outputs)
gm.graph.erase_node(output_node)

return new_outputs
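
# Illustrative effect on the graph outputs (node names are placeholders):
#   before: output = (logits,)
#   after:  output = (logits, (k_0, v_0), (k_1, v_1), ...)
# One (key, value) pair is appended per detected SDPA node so that the cache
# contents can be read back after each forward pass.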




def add_kv_and_indices_as_inputs(gm, fixed_kv: bool = True):
"""
Add key-value tensors and index parameters as inputs to the graph.

Args:
gm: The GraphModule to modify
fixed_kv: Boolean indicating whether to use static tensors for KV cache

Returns:
A tuple containing:
- List of (k_input, v_input) node pairs for each SDPA operation
- start_idx input node for slicing operations
- end_idx input node for slicing operations
"""

def get_static_tensor(tensor: torch.Tensor):
key_shape = []
for dim in tensor.shape:
if isinstance(dim, torch.SymInt):
min_max_opt = extract_var_range_info(dim)
key_shape.append(min_max_opt["max"])
else:
key_shape.append(dim)

static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
return static_tensor

keys_values = get_kv_nodes(gm)

kv_inputs = []
for idx, key_value in enumerate(keys_values):
k_val = key_value[0].meta["val"]
v_val = key_value[1].meta["val"]
if fixed_kv:
k_val = get_static_tensor(k_val)
v_val = get_static_tensor(v_val)

# Add new inputs using add_graph_input
k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val)
v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val)
kv_inputs.append((k_input, v_input))

return kv_inputs
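
# Illustrative effect on the graph inputs (node names are placeholders):
#   new placeholders: <key_node_name>_k_input, <value_node_name>_v_input, ...
# With fixed_kv=True, symbolic dimensions are frozen to their maximum values
# from the shape environment, so each cache input has a fixed, worst-case shape.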


def insert_torch_cond_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]]):
"""
Insert a torch.cond operation before each scaled_dot_product_attention operation.

    Args:
        gm: The FX GraphModule to modify
        incoming_keys_values: List of (key, value) cache input node pairs, one
            per SDPA operation, as returned by add_kv_and_indices_as_inputs

    Returns:
        The modified GraphModule
"""
# Find all nodes with scaled_dot_product_attention
sdpa_nodes = []
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention:
sdpa_nodes.append(node)

# Get the is_causal input node
is_causal_node = next((node for node in gm.graph.nodes if node.op == "placeholder" and node.name == "is_causal"), None)

# For each SDPA node, insert a torch.cond operation before it
for idx, sdpa_node in enumerate(sdpa_nodes):

with gm.graph.inserting_before(sdpa_node):
# pred_node = add_graph_input(gm, "is_generate", torch.tensor(False, dtype=torch.bool))
q_node, k_node, v_node = sdpa_node.args[:3]
incoming_key, incoming_value = incoming_keys_values[idx]
# Create nodes for concatenating k with incoming_key and v with incoming_value
concatenated_k_node = gm.graph.create_node(
"call_function",
torch.ops.aten.cat.default,
args=([incoming_key, k_node], 2), # Concatenate along sequence length dimension
kwargs={}
)
concatenated_v_node = gm.graph.create_node(
"call_function",
torch.ops.aten.cat.default,
args=([incoming_value, v_node], 2), # Concatenate along sequence length dimension
kwargs={}
)

# Create the torch.cond node
cond_k_node = gm.graph.create_node(
"call_function",
torch.ops.higher_order.cond,
args=(is_causal_node, concatenated_k_node, k_node),
)

cond_v_node = gm.graph.create_node(
"call_function",
torch.ops.higher_order.cond,
args=(is_causal_node, concatenated_v_node, v_node),
)

sdpa_node.args = (q_node, cond_k_node, cond_v_node) + sdpa_node.args[3:]

return gm
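
# Illustrative per-SDPA rewrite (pseudocode, names are placeholders):
#   k_cat = torch.ops.aten.cat([incoming_key, k], 2)
#   v_cat = torch.ops.aten.cat([incoming_value, v], 2)
#   k_eff = torch.ops.higher_order.cond(is_causal, k_cat, k)
#   v_eff = torch.ops.higher_order.cond(is_causal, v_cat, v)
#   out   = scaled_dot_product_attention(q, k_eff, v_eff, ...)
# The existing `is_causal` placeholder doubles as the prefill/generate toggle.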



@_aten_lowering_pass
def insert_dynamic_kv_cache(
gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
"""Insert FlashInfer MHA + KV cache ops in the graph"""
"""Perform insertion of kv-caches and attention kernel."""

# Add static key and value as inputs to the graph
kv_inputs = add_kv_and_indices_as_inputs(gm, fixed_kv=True)

# Call the function to add KV as outputs
logits_keys_values = add_kv_as_outputs(gm)

    # Insert torch.cond before each SDPA node to toggle between the prefill and generate phases
gm = insert_torch_cond_before_sdpa(gm, kv_inputs)

gm = clean_up_graph_after_modifications(gm)

new_output_tensors = create_random_output_tensors(logits_keys_values)
new_out_spec = pytree.tree_flatten(new_output_tensors)[1]
gm._out_spec = new_out_spec

logger.debug("After inserting KV cache into the graph: " + str(gm.graph))
return gm
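
# Illustrative usage (assumptions, not verified in this PR): the
# @_aten_lowering_pass decorator registers insert_dynamic_kv_cache at import
# time, so importing this module before compiling is enough for the pass to run:
#
#   import dynamic_cache  # registers the lowering pass
#   trt_module = torch_tensorrt.compile(model, ir="dynamo", inputs=[...])
#
# The model is expected to expose an `is_causal` boolean input, which the pass
# uses as the predicate of the inserted torch.cond nodes.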

