diff --git a/examples/dynamo/llm/cache_utils.py b/examples/dynamo/llm/cache_utils.py new file mode 100644 index 0000000000..714d1b5b72 --- /dev/null +++ b/examples/dynamo/llm/cache_utils.py @@ -0,0 +1,188 @@ +import torch +from torch.fx import Graph, GraphModule, Node +from typing import Optional, Union, Iterable, List, Tuple +from torch._ops import OpOverloadPacket +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.utils._pytree import _LEAF_SPEC +from torch._export.utils import _detect_fake_mode_from_gm +import torch_tensorrt +import tensorrt +from typing import Any, Dict, Sequence +from torch.fx.node import Target + +@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(torch.ops.higher_order.cond, enabled=True, supports_dynamic_shapes=True) +def cond_converter( + ctx: torch_tensorrt.dynamo.conversion.ConversionContext, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: str, +) -> Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: + """ + Converter for torch.ops.higher_order.cond operation to TensorRT. + + This function handles the conversion of PyTorch's conditional operation to TensorRT. + The conditional operation selects between two tensors based on a boolean predicate. + + Args: + ctx (torch_tensorrt.dynamo.conversion.ConversionCtx): The conversion context + target (Target): The target operation to convert + args (Tuple[Argument, ...]): The arguments to the operation + kwargs (Dict[str, Argument]): The keyword arguments to the operation + name (str): The name to give to the TensorRT layer + + Returns: + Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: The converted TensorRT tensor(s) + """ + if_layer = ctx.net.add_if_conditional() + condition, true_branch, false_branch = args[0], args[1], args[2] + if_layer.set_condition(condition) + output_layer = if_layer.add_output(true_branch, false_branch) + output = output_layer.get_output(0) + + return output + +def get_kv_nodes(gm): + """ + Get the key and value nodes from the graph. + """ + kv_nodes = [] + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention: + q_node, k_node, v_node = node.args[:3] + kv_nodes.append((k_node, v_node)) + return kv_nodes + +def get_random_tensor_from_node(node: Node) -> torch.Tensor: + """ + Creates a random tensor based on the shape information in a node's metadata. + For symbolic dimensions, extracts the maximum value from the shape environment. 
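+    (In this implementation the concrete size of a symbolic dimension is taken from the
+    SymInt's hint, dim.node.hint.)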
+ + Args: + node: A torch.fx.Node object with metadata containing tensor information + + Returns: + A random tensor with shape matching the node's metadata, or None if no valid + tensor information is found + """ + if "val" not in node.meta: + raise ValueError(f"No tensor information found in node metadata for node: {node}") + + fake_tensor = node.meta["val"] + shape = [] + + # Iterate through each dimension and handle symbolic dimensions + for dim in fake_tensor.shape: + if isinstance(dim, torch.SymInt): + # Extract the maximum value from the shape environment + max_val = dim.node.hint + shape.append(max_val) + else: + shape.append(dim) + + # Create a random tensor with the determined shape + dtype = fake_tensor.dtype + device = fake_tensor.device + random_tensor = torch.rand(shape, dtype=dtype, device=device) + + return random_tensor + +def create_random_output_tensors(nodes: List[Node]) -> List[torch.Tensor]: + """ + Creates random tensors based on the shape information in node metadata. + For symbolic dimensions, extracts the maximum value from the shape environment. + + Args: + nodes: List of torch.fx.Node objects with metadata + + Returns: + List of random tensors with shapes matching the nodes' metadata + """ + random_tensors = [] + + for node in nodes: + if isinstance(node, Node): + node_tensor = get_random_tensor_from_node(node) + elif isinstance(node, tuple): + node_tensor_list = [] + for n in node: + random_tensor = get_random_tensor_from_node(n) + node_tensor_list.append(random_tensor) + node_tensor = tuple(node_tensor_list) + + random_tensors.append(node_tensor) + + return random_tensors + +def add_graph_input( + gm: GraphModule, name: str, val: Optional[torch.Tensor] = None, dynamic_shape=None +) -> Node: + """Add a graph input to the given GraphModule and return the newly created node. + + NOTE: function does NOT do any graph canonicalization. This is left to the user! + + Args: + gm (GraphModule): The GraphModule to add the input to. + name (str): The name of the input. + val (torch.Tensor): An example tensor to use for the input. + dynamic_shape: The dynamic shape of the input tensor [NOT SUPPORTED YET] + """ + # check that no dynamic shape is provided... 
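+    # A minimal usage sketch (the name and shape below are illustrative, not taken from this example):
+    #
+    #     k_input = add_graph_input(gm, "key_cache_0", torch.zeros(1, 8, 2176, 64))
+    #
+    # The new placeholder is appended after the last existing input, and the pytree
+    # in_spec / orig_args are updated so the GraphModule accepts the extra argument.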
+ if dynamic_shape: + raise NotImplementedError("Dynamic shape not supported for adding graph inputs") + + # extract graph and input spec + graph: Graph = gm.graph + + in_spec = graph._codegen.pytree_info.in_spec + in_spec_for_args = in_spec.children_specs[0] + orig_args = graph._codegen.pytree_info.orig_args + assert in_spec_for_args.type is tuple + + # insert input node after currently last input node + node_last_input = graph.find_nodes(op="placeholder", sort=True)[-1] + with graph.inserting_after(node_last_input): + in_node = graph.placeholder(name) + in_spec_for_args.children_specs.append(_LEAF_SPEC) + orig_args.append(f"arg_{name}") + + # update pytree info recursively with __post_init__ starting at leaves + def call_post_init(spec): + for child_spec in spec.children_specs: + call_post_init(child_spec) + spec.__post_init__() + + call_post_init(in_spec) + + # set fake tensor information if all required information is available + fake_mode: Optional[FakeTensorMode] = _detect_fake_mode_from_gm(gm) + if fake_mode and val is not None and isinstance(val, torch.Tensor): + if isinstance(val, FakeTensor): + fake_tensor = val + else: + fake_tensor: FakeTensor = fake_mode.from_tensor(val, static_shapes=True) + in_node.meta["val"] = fake_tensor + in_node.meta["tensor_meta"] = _extract_tensor_metadata(fake_tensor) + + # return new node... + return in_node + +def is_op(node: Node, ops: Union[OpOverloadPacket, Iterable[OpOverloadPacket]]) -> bool: + """Check if the node is a call to one of the ops.""" + if node.op != "call_function": + return False + # check if it's a single op that's provided + if isinstance(ops, OpOverloadPacket): + ops = [ops] + + # check if it's the op itself instead of an overload + if any(node.target == op for op in ops): + return True + + return False + +def get_all_input_output_nodes(graph: Graph) -> Tuple[List[Node], List[Node]]: + input_nodes: List[Node] = graph.find_nodes(op="placeholder") + output_nodes: List[Node] = graph.find_nodes(op="output") + return (input_nodes, output_nodes) \ No newline at end of file diff --git a/examples/dynamo/llm/dynamic_cache.py b/examples/dynamo/llm/dynamic_cache.py new file mode 100644 index 0000000000..e31939fa99 --- /dev/null +++ b/examples/dynamo/llm/dynamic_cache.py @@ -0,0 +1,203 @@ +import logging +from typing import List, Tuple + +import torch +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + _aten_lowering_pass, +) +from torch_tensorrt.dynamo.utils import extract_var_range_info +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes, is_op +import torch.utils._pytree as pytree +logger = logging.getLogger(__name__) + + +def add_kv_as_outputs(gm): + """ + Modifies the graph to add query, key, and value tensors as outputs. + + This function identifies all scaled dot-product attention (SDPA) operations + in the graph, creates copies of their query, key, and value inputs, and adds + these copies to the graph's outputs. This allows for accessing these tensors + externally, which is useful for operations like key-value caching. + + Args: + graph: The torch.fx.Graph to modify + + Returns: + None. The graph is modified in-place. 
+ """ + # list of MHA kernels we would want to detect and replace + mha_ops = { + torch._C._nn.scaled_dot_product_attention, + } + + # Find all SDPA nodes in the graph + mha_nodes = [] + for node in gm.graph.nodes: + if is_op(node, mha_ops): + mha_nodes.append(node) + + # Iterate through each MHA node to extract shape information + for mha_node in mha_nodes: + if "val" in mha_node.meta and len(mha_node.args) >= 3: + # Get the input nodes (query, key, value) + q_node, k_node, v_node = mha_node.args[:3] + + # Add the copy nodes as outputs to the graph + output_node = next(node for node in gm.graph.nodes if node.op == "output") + + # Get the current output args (typically a tuple) + current_outputs = output_node.args[0] + + # If the current output is a tuple, extend it with our new outputs + if isinstance(current_outputs, tuple): + new_outputs = current_outputs + ((k_node, v_node),) + else: + # If there's only one output or it's not a tuple, create a new tuple + new_outputs = (current_outputs, (k_node, v_node)) + + gm.graph.output(new_outputs) + gm.graph.erase_node(output_node) + + return new_outputs + + + + +def add_kv_and_indices_as_inputs(gm, fixed_kv: bool = True): + """ + Add key-value tensors and index parameters as inputs to the graph. + + Args: + gm: The GraphModule to modify + fixed_kv: Boolean indicating whether to use static tensors for KV cache + + Returns: + A tuple containing: + - List of (k_input, v_input) node pairs for each SDPA operation + - start_idx input node for slicing operations + - end_idx input node for slicing operations + """ + + def get_static_tensor(tensor: torch.Tensor): + key_shape = [] + for dim in tensor.shape: + if isinstance(dim, torch.SymInt): + min_max_opt = extract_var_range_info(dim) + key_shape.append(min_max_opt["max"]) + else: + key_shape.append(dim) + + static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) + return static_tensor + + keys_values = get_kv_nodes(gm) + + kv_inputs = [] + for idx, key_value in enumerate(keys_values): + k_val = key_value[0].meta["val"] + v_val = key_value[1].meta["val"] + if fixed_kv: + k_val = get_static_tensor(k_val) + v_val = get_static_tensor(v_val) + + # Add new inputs using add_graph_input + k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val) + v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val) + kv_inputs.append((k_input, v_input)) + + + # Add is_generate as input + is_generate_input = add_graph_input(gm, "is_generate", True) + is_generate_input.meta["val"] = torch.tensor(True) + + return kv_inputs, is_generate_input + + +def insert_torch_cond_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], is_generate_input: torch.Tensor): + """ + Insert a torch.cond operation before each scaled_dot_product_attention operation. 
+ + Args: + gm: The FX GraphModule to modify + + Returns: + The modified GraphModule + """ + # Find all nodes with scaled_dot_product_attention + sdpa_nodes = [] + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch._C._nn.scaled_dot_product_attention: + sdpa_nodes.append(node) + + # For each SDPA node, insert a torch.cond operation before it + for idx, sdpa_node in enumerate(sdpa_nodes): + + with gm.graph.inserting_before(sdpa_node): + # pred_node = add_graph_input(gm, "is_generate", torch.tensor(False, dtype=torch.bool)) + q_node, k_node, v_node = sdpa_node.args[:3] + incoming_key, incoming_value = incoming_keys_values[idx] + # Create nodes for concatenating k with incoming_key and v with incoming_value + concatenated_k_node = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=([incoming_key, k_node], 2), # Concatenate along sequence length dimension + kwargs={} + ) + concatenated_v_node = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=([incoming_value, v_node], 2), # Concatenate along sequence length dimension + kwargs={} + ) + + # Create the torch.cond node + cond_k_node = gm.graph.create_node( + "call_function", + torch.ops.higher_order.cond, + args=(is_generate_input, concatenated_k_node, k_node), + ) + + cond_v_node = gm.graph.create_node( + "call_function", + torch.ops.higher_order.cond, + args=(is_generate_input, concatenated_v_node, v_node), + ) + + sdpa_node.args = (q_node, cond_k_node, cond_v_node) + sdpa_node.args[3:] + + return gm + + + +@_aten_lowering_pass +def insert_dynamic_kv_cache( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Insert FlashInfer MHA + KV cache ops in the graph""" + """Perform insertion of kv-caches and attention kernel.""" + + # Add static key and value as inputs to the graph + kv_inputs, is_generate_input = add_kv_and_indices_as_inputs(gm, fixed_kv=True) + + # Call the function to add KV as outputs + logits_keys_values = add_kv_as_outputs(gm) + + # Insert torch.cond before each SDPA node which acts toggles between prefill and generate phases + gm = insert_torch_cond_before_sdpa(gm, kv_inputs, is_generate_input) + + gm = clean_up_graph_after_modifications(gm) + + new_output_tensors = create_random_output_tensors(logits_keys_values) + new_out_spec = pytree.tree_flatten(new_output_tensors)[1] + gm._out_spec = new_out_spec + + logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) + return gm + + diff --git a/examples/dynamo/llm/llm_pyt_benchmark.py b/examples/dynamo/llm/llm_pyt_benchmark.py new file mode 100644 index 0000000000..9ae60576a5 --- /dev/null +++ b/examples/dynamo/llm/llm_pyt_benchmark.py @@ -0,0 +1,78 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +import timeit + +USE_CACHE = True +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" +# MODEL_NAME = "Qwen/Qwen3-0.6B" +MAX_NEW_TOKENS = 128 + + +def main(): + # Initialize model and tokenizer + print("Loading model and tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + torch_dtype=torch.float16, + use_cache=USE_CACHE, + device_map="auto" + ) + # model.generation_config.cache_implementation = "static" + # model.forward = torch.compile(model.forward) + + # Prepare input prompt + word = "What" + # Tokenize the word + word_ids = tokenizer(word, return_tensors="pt").input_ids[0] # Get the first (and only) sequence + # Repeat the token 2048 
times + input_ids = word_ids.repeat(1024).unsqueeze(0).to(model.device) # Add batch dimension and move to device + print(f"Input tensor shape: {input_ids.shape}") + + # # Warm-up pass + print("Running warm-up pass...") + output_ids = model.generate( + input_ids, + max_new_tokens=MAX_NEW_TOKENS, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + use_cache=USE_CACHE + ) + + # Benchmark loop + print("Running benchmark...") + num_iterations = 10 + total_time = 0 + timings = [] + + for i in range(num_iterations): + start_time = timeit.default_timer() + output_ids = model.generate( + input_ids, + max_new_tokens=MAX_NEW_TOKENS, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + use_cache=USE_CACHE + ) + end_time = timeit.default_timer() + generation_time = end_time - start_time + total_time += generation_time + timings.append(generation_time) + + # Decode and print first iteration output + # if i == 0: + # output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) + # print("\nFirst generation output:") + # print(output_text) + + # Calculate and print statistics + average_time = total_time / num_iterations + print(f"\nPerformance Statistics:") + print(f"Average generation time over {num_iterations} iterations: {average_time*1000:.2f} milliseconds") + print(f"Average tokens per second: {100/average_time:.2f}") + print("\nIndividual timings (ms):") + for i, t in enumerate(timings): + print(f"Iteration {i+1}: {t*1000:.2f}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/dynamo/llm/run_llm.py b/examples/dynamo/llm/run_llm.py new file mode 100644 index 0000000000..d536cd12e4 --- /dev/null +++ b/examples/dynamo/llm/run_llm.py @@ -0,0 +1,308 @@ +""" +.. _torch_export_gpt2: + +Compiling GPT2 using the dynamo backend +========================================================== + +This script illustrates Torch-TensorRT workflow with dynamo backend on popular GPT2 model. +""" + +import argparse +import copy +import os +import timeit + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +import torch +import torch_tensorrt +from transformers import AutoModelForCausalLM, AutoTokenizer +from contextlib import nullcontext +from utils import export_llm, generate, recordStats, time_generate, generate_with_static_cache, generate_with_dynamic_cache +import sys +import os + +# Register SDPA as a standalone operator. 
Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from register_sdpa import * + +DEVICE = torch.device("cuda:0") + +def get_model(args): + with torch.no_grad(): + # Supported list of models: + # - meta-llama/Llama-3.2-1B-Instruct + # - meta-llama/Llama-3.2-3B-Instruct + # - meta-llama/Llama-3.1-8B-Instruct + # - Qwen/Qwen2.5-1.5B-Instruct + model = ( + AutoModelForCausalLM.from_pretrained( + args.model, + use_cache=False, + attn_implementation="sdpa", + # num_hidden_layers=1 + ) + .eval() + .cuda() + ) + if args.precision == "FP16": + model = model.to(torch.float16) + elif args.precision == "BF16": + model = model.to(torch.bfloat16) + else: + model = model.to(torch.float32) + + return model + + +def compile_torchtrt(model, input_ids, args): + max_seq_len = input_ids.shape[1] + args.num_tokens + ep = export_llm(model, input_ids, max_seq_len=max_seq_len) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[input_ids, position_ids], + enabled_precisions=enabled_precisions, + # truncate_double=True, + use_explicit_typing=use_explicit_typing, + use_fp32_acc=use_fp32_acc, + device=DEVICE, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + offload_module_to_cpu=True, + min_block_size=args.min_block_size, + ) + + return trt_model + + +def print_outputs(backend_name, gen_tokens, tokenizer): + print(f"========= {backend_name} =========") + print( + f"{backend_name} model generated text: ", + tokenizer.decode(gen_tokens[0], skip_special_tokens=True), + ) + print("===================================") + + + +def measure_perf(trt_model, input_signature, backend_name): + # Measure average time for 10 iterations + import timeit + import numpy as np + + total_time = 0 + iterations = 10 + + print("Running warmup iteration...") + # Warmup run + _ = trt_model(*input_signature) + torch.cuda.synchronize() + + print(f"Measuring performance over {iterations} iterations...") + for i in range(iterations): + start_time = timeit.default_timer() + _ = trt_model(*input_signature) + torch.cuda.synchronize() + end_time = timeit.default_timer() + iter_time = end_time - start_time + total_time += iter_time + # print(f"Iteration {i+1}: {iter_time:.4f} seconds") + + avg_time = total_time / iterations + print(f"Backend: {backend_name} Average time per iteration: {avg_time*1000:.4f} milliseconds") + print(f"Backend: {backend_name} Average throughput: {1.0/avg_time:.2f} iterations/second") + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run inference on a model with random input values" + ) + arg_parser.add_argument( + "--model", type=str, default="meta-llama/Llama-3.2-1B-Instruct", help="Name of LLM model" + ) + arg_parser.add_argument( + "--tokenizer", + type=str, + default="", + help="Name of LLM model tokenizer", + ) + arg_parser.add_argument( + "--prompt", type=str, default="What is parallel programming ?", help="Prompt" + ) + arg_parser.add_argument("--precision", type=str, default="FP16", 
help="Precision to use in the model. Options: FP16, BF16, FP32") + arg_parser.add_argument( + "--iterations", type=int, default=5, help="no. of iterations to run" + ) + arg_parser.add_argument( + "--min_block_size", type=int, default=1, help="no. of iterations to run" + ) + arg_parser.add_argument( + "--num_tokens", type=int, default=128, help="no. of output tokens to be generated" + ) + arg_parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size used for benchmarking" + ) + arg_parser.add_argument( + "--isl", type=int, default=2048, help="Input sequence length used for benchmarking" + ) + arg_parser.add_argument( + "--enable_pytorch_run", + action="store_true", + help="Enable pytorch run (default: False)" + ) + arg_parser.add_argument( + "--cache", + type=str, + default="", + help="Type of KV cache to use. Options: static_v1, static_v2, dynamic", + ) + arg_parser.add_argument( + "--cudagraph", + action="store_true", + help="Enable cudagraphs (default: False)" + ) + arg_parser.add_argument( + "--debug", + action="store_true", + help="Enable debug (default: False)" + ) + arg_parser.add_argument( + "--benchmark", + action="store_true", + help="Enable benchmark (default: False)" + ) + + args = arg_parser.parse_args() + with torch.inference_mode(): + model = get_model(args) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model) + + # Prepare input for benchmarking or evaluation + if args.benchmark: + input_ids = torch.randint(1, 10000, (args.batch_size, args.isl), dtype=torch.int64).to(model.device) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) + else: + model_inputs = tokenizer(args.prompt, return_tensors="pt") + input_ids = model_inputs["input_ids"].to(DEVICE) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) + + MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.num_tokens + # Pyt + pyt_gen_tokens = None + pyt_timings = None + pyt_stats = None + + if args.enable_pytorch_run: + pyt_gen_tokens = generate( + model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id + ) + if args.benchmark: + pyt_timings = time_generate( + generate, + model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + pyt_stats = recordStats( + "PyTorch", pyt_timings, args.precision, batch_size=args.batch_size, compile_time_s=None + ) + + if args.cache == "static_v1": + # This import is required to register static v1 KV cache transformations as lowering passes + import static_cache_v1 + if args.cache == "static_v2": + # This import is required to register static v2 KV cache transformations as lowering passes + import static_cache_v2 + elif args.cache == "dynamic": + import dynamic_cache + + # Compile the model with Torch-TensorRT + trt_model = compile_torchtrt(model, input_ids, args) + + if args.cache == "static_v1" or args.cache == "static_v2": + if args.cudagraph: + # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. 
+ # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) + torch_tensorrt.runtime.set_cudagraphs_mode(True) + + trt_gen_tokens = generate_with_static_cache( + trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, + ) + + if args.benchmark: + trt_timings = time_generate( + generate_with_static_cache, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + elif args.cache == "dynamic": + trt_gen_tokens = generate_with_dynamic_cache( + trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, + ) + if args.benchmark: + trt_timings = time_generate( + generate_with_dynamic_cache, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + else: + trt_gen_tokens = generate( + trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, + ) + if args.benchmark: + trt_timings = time_generate( + generate, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + + if args.benchmark: + trt_stats = recordStats( + "TensorRT", trt_timings, args.precision, batch_size=args.batch_size, compile_time_s=None + ) + + + if not args.benchmark: + if args.enable_pytorch_run: + print_outputs("PyTorch", pyt_gen_tokens, tokenizer) + + print_outputs("TensorRT", trt_gen_tokens, tokenizer) + + if args.enable_pytorch_run: + print(f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}") + + if args.benchmark: + if args.enable_pytorch_run: + print("=========PyTorch PERFORMANCE============ \n") + print(pyt_stats) + print("===================== \n") + print("=========TensorRT PERFORMANCE============ \n") + print(trt_stats) diff --git a/examples/dynamo/llm/static_cache_v1.py b/examples/dynamo/llm/static_cache_v1.py new file mode 100644 index 0000000000..943718de2e --- /dev/null +++ b/examples/dynamo/llm/static_cache_v1.py @@ -0,0 +1,266 @@ +import logging +from typing import List, Tuple + +import torch +from torch.fx import Node + +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + _aten_lowering_pass, +) +from torch_tensorrt.dynamo.utils import extract_var_range_info +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) +import torch.utils._pytree as pytree +from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes +logger = logging.getLogger(__name__) + +SDPA_OP = torch._C._nn.scaled_dot_product_attention + +def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]): + """ + Modifies the graph to add query, key, and value tensors as outputs. + + This function identifies all scaled dot-product attention (SDPA) operations + in the graph, creates copies of their query, key, and value inputs, and adds + these copies to the graph's outputs. This allows for accessing these tensors + externally, which is useful for operations like key-value caching. + + Args: + graph: The torch.fx.Graph to modify + + Returns: + None. The graph is modified in-place. 
+ """ + output_node = next(node for node in gm.graph.nodes if node.op == "output") + + # Get the current output args (typically a tuple) + current_outputs = output_node.args[0] + + # If the current output is a tuple, extend it with our new outputs + if isinstance(current_outputs, tuple): + new_outputs = current_outputs + tuple(kv_cache_for_graph) + else: + # If there's only one output or it's not a tuple, create a new tuple + new_outputs = (current_outputs,) + tuple(kv_cache_for_graph) + + gm.graph.output(new_outputs) + gm.graph.erase_node(output_node) + + return new_outputs + + +def add_kv_cache_inputs(gm, fixed_kv: bool = True): + """ + Add key-value tensors, index parameters as inputs to the graph. + + Args: + gm: The GraphModule to modify + fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True. + + Returns: + A tuple containing: + - List of (k_input, v_input) node pairs for each SDPA operation + - start_idx input node for slicing operations + - end_idx input node for slicing operations + """ + + def get_static_tensor(tensor: torch.Tensor): + key_shape = [] + for dim in tensor.shape: + if isinstance(dim, torch.SymInt): + min_max_opt = extract_var_range_info(dim) + key_shape.append(min_max_opt["max"]) + else: + key_shape.append(dim) + + static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) + return static_tensor + + keys_values = get_kv_nodes(gm) + + kv_inputs = [] + for idx, key_value in enumerate(keys_values): + k_val = key_value[0].meta["val"] + v_val = key_value[1].meta["val"] + if fixed_kv: + k_val = get_static_tensor(k_val) + v_val = get_static_tensor(v_val) + + # Add new inputs using add_graph_input + k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val) + v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val) + kv_inputs.append((k_input, v_input)) + + # Add start_idx and end_idx as inputs + start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0)) + end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1)) + + # Get the max sequence length from the first key_cache node. The order of nodes is: input_ids, is_causal, key_cache1, value_cache1, key_cache2, value_cache2, .. + input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] + input_ids_meta = input_nodes[0].meta["val"] + seq_len = input_ids_meta.shape[1] + min_max_opt = extract_var_range_info(seq_len) + max_seq_len = min_max_opt["max"] + + from torch.fx.experimental.symbolic_shapes import ShapeEnv + shape_env = ShapeEnv() + # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive + start_idx_unbacked_symint = shape_env.create_unbacked_symint() + torch._check(start_idx_unbacked_symint >= 0) + torch._check(start_idx_unbacked_symint <= max_seq_len) + + end_idx_unbacked_symint = shape_env.create_unbacked_symint() + torch._check(end_idx_unbacked_symint >= 0) + torch._check(end_idx_unbacked_symint <= max_seq_len) + # Set the symbolic ints as the metadata for start_idx and end_idx inputs + start_idx_input.meta["val"] = start_idx_unbacked_symint + end_idx_input.meta["val"] = end_idx_unbacked_symint + + return kv_inputs, start_idx_input, end_idx_input + + + +def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node): + """ + Insert slicing operations before each scaled_dot_product_attention operation. 
+ """ + # Find all nodes with scaled_dot_product_attention + sdpa_nodes = [] + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == SDPA_OP: + sdpa_nodes.append(node) + kv_cache_for_graph = [] + for idx, sdpa_node in enumerate(sdpa_nodes): + assert len(sdpa_node.args) == 6, f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments" + q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args + incoming_key, incoming_value = incoming_keys_values[idx] + kv_cache_for_sdpa_node = [] + new_keys_values = [] + for key_or_value, current_key_or_value_node in zip([incoming_key, incoming_value], [k_node, v_node]): + # Create a slice node for key_cache[:,:,:start_idx,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim + with gm.graph.inserting_before(sdpa_node): + slice_1 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(key_or_value,), + kwargs={} + ) + slice_2 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_1, 1), + kwargs={} + ) + slice_3 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_2, 2, None, start_idx_input), + kwargs={} + ) + slice_4 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_3, 3), + kwargs={} + ) + # =============================================== # + # Create a slice node for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim + slice_5 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(key_or_value,), + kwargs={} + ) + slice_6 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_5, 1), + kwargs={} + ) + slice_7 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_6, 2, end_idx_input), + kwargs={} + ) + slice_8 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_7, 3), + kwargs={} + ) + # =============================================== # + # Concatenate the sliced tensors to build KV cache + cat = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=([slice_4, current_key_or_value_node, slice_8], 2), + kwargs={} + ) + # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph + cat.meta.update(key_or_value.meta) + kv_cache_for_sdpa_node.append(cat) + # =============================================== # + # Get the current key and value by indexing the KV cache + slice_9 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(cat,), + kwargs={} + ) + slice_10 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_9, 1), + kwargs={} + ) + slice_11 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_10, 2, None, end_idx_input), + kwargs={} + ) + slice_12 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_11, 3), + kwargs={} + ) + new_keys_values.append(slice_12) + + kv_cache_for_graph.extend(kv_cache_for_sdpa_node) + + sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + (attn_mask, dropout_p, True) + + return gm, kv_cache_for_graph + + +@_aten_lowering_pass +def insert_static_cache_v1( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Insert KV cache ops in the graph""" + """Perform insertion of kv-caches and 
attention kernel.""" + # Add static key and value as inputs to the graph + kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True) + + # Build and update the KV cache using computed KV inputs for current token and + # incoming keys and values from previous tokens (which were added as inputs) + gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input) + + # Call the function to add KV as outputs + logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph) + + gm = clean_up_graph_after_modifications(gm) + + new_output_tensors = create_random_output_tensors(logits_keys_values) + + new_out_spec = pytree.tree_flatten(new_output_tensors)[1] + gm._out_spec = new_out_spec + logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) + + return gm + + diff --git a/examples/dynamo/llm/static_cache_v2.py b/examples/dynamo/llm/static_cache_v2.py new file mode 100644 index 0000000000..e2a40d39f7 --- /dev/null +++ b/examples/dynamo/llm/static_cache_v2.py @@ -0,0 +1,275 @@ +import logging +from typing import List, Tuple + +import torch +from torch.fx import Node + +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + _aten_lowering_pass, +) +from torch_tensorrt.dynamo.utils import extract_var_range_info +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) +import torch.utils._pytree as pytree +from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes +logger = logging.getLogger(__name__) + +SDPA_OP = torch._C._nn.scaled_dot_product_attention + +def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]): + """ + Modifies the graph to add query, key, and value tensors as outputs. + + This function identifies all scaled dot-product attention (SDPA) operations + in the graph, creates copies of their query, key, and value inputs, and adds + these copies to the graph's outputs. This allows for accessing these tensors + externally, which is useful for operations like key-value caching. + + Args: + graph: The torch.fx.Graph to modify + + Returns: + None. The graph is modified in-place. + """ + output_node = next(node for node in gm.graph.nodes if node.op == "output") + + # Get the current output args (typically a tuple) + current_outputs = output_node.args[0] + + # If the current output is a tuple, extend it with our new outputs + if isinstance(current_outputs, tuple): + new_outputs = current_outputs + tuple(kv_cache_for_graph) + else: + # If there's only one output or it's not a tuple, create a new tuple + new_outputs = (current_outputs,) + tuple(kv_cache_for_graph) + + gm.graph.output(new_outputs) + gm.graph.erase_node(output_node) + + return new_outputs + + +def add_kv_cache_inputs(gm, fixed_kv: bool = True): + """ + Add key-value tensors, index parameters as inputs to the graph. + + Args: + gm: The GraphModule to modify + fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True. 
+ + Returns: + A tuple containing: + - List of (k_input, v_input) node pairs for each SDPA operation + - start_idx input node for slicing operations + - end_idx input node for slicing operations + """ + + def get_static_tensor(tensor: torch.Tensor): + key_shape = [] + for dim in tensor.shape: + if isinstance(dim, torch.SymInt): + min_max_opt = extract_var_range_info(dim) + key_shape.append(min_max_opt["max"]) + else: + key_shape.append(dim) + + static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) + return static_tensor + + keys_values = get_kv_nodes(gm) + + kv_inputs = [] + for idx, key_value in enumerate(keys_values): + k_val = key_value[0].meta["val"] + v_val = key_value[1].meta["val"] + if fixed_kv: + k_val = get_static_tensor(k_val) + v_val = get_static_tensor(v_val) + + # Add new inputs using add_graph_input + k_input = add_graph_input(gm, key_value[0].name+"_k_input", k_val) + v_input = add_graph_input(gm, key_value[1].name+"_v_input", v_val) + kv_inputs.append((k_input, v_input)) + + # Add start_idx and end_idx as inputs + start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0)) + end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1)) + + # Get the max sequence length from the first key_cache node. The order of input nodes is: input_ids, key_cache1, value_cache1, key_cache2, value_cache2, start_idx, end_idx + input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] + # Get the third last input which should be the last value cache node and store the max_seq_len + input_ids_meta = input_nodes[-3].meta["val"] + seq_len = input_ids_meta.shape[2] + + if isinstance(seq_len, torch.SymInt): + min_max_opt = extract_var_range_info(seq_len) + max_seq_len = min_max_opt["max"] + else: + max_seq_len = seq_len + + from torch.fx.experimental.symbolic_shapes import ShapeEnv + shape_env = ShapeEnv() + # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive + start_idx_unbacked_symint = shape_env.create_unbacked_symint() + torch._check(start_idx_unbacked_symint >= 0) + torch._check(start_idx_unbacked_symint <= max_seq_len) + + end_idx_unbacked_symint = shape_env.create_unbacked_symint() + torch._check(end_idx_unbacked_symint >= 0) + torch._check(end_idx_unbacked_symint <= max_seq_len) + # Set the symbolic ints as the metadata for start_idx and end_idx inputs + start_idx_input.meta["val"] = start_idx_unbacked_symint + end_idx_input.meta["val"] = end_idx_unbacked_symint + + return kv_inputs, start_idx_input, end_idx_input + +def create_kv_cache_update_nodes(gm, sdpa_node, current_kv_node, incoming_kv_node, start_idx_input, end_idx_input): + """ + Create slicing and concatenation nodes for KV cache update. + + This function creates the necessary slicing and concatenation nodes to update the KV cache + during the generation process. It takes the SDPA node, the current KV cache node, and the + incoming KV cache node as input. + Returns: + for a particular SDPA node, a tuple containing: + - List of new current KV nodes + - List of updated incoming KV cache nodes + + """ + + # Create a slice node for key_cache[:,:,:start_idx,:]. 
The shape of key_cache is batch_size x num_heads x seq_len x head_dim + with gm.graph.inserting_before(sdpa_node): + slice_1 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(incoming_kv_node,), + kwargs={} + ) + slice_2 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_1, 1), + kwargs={} + ) + slice_3 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_2, 2, None, start_idx_input), + kwargs={} + ) + slice_4 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_3, 3), + kwargs={} + ) + # Concat key_cache[:,:,:start_idx,:] with current key (k) + concat_keys_or_values = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=([slice_4, current_kv_node], 2), + kwargs={} + ) + + # =============================================== # + # Create nodes for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim + slice_5 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(incoming_kv_node,), + kwargs={} + ) + slice_6 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_5, 1), + kwargs={} + ) + slice_7 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_6, 2, end_idx_input), + kwargs={} + ) + slice_8 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_7, 3), + kwargs={} + ) + # =============================================== # + # Concatenate the sliced tensors to build KV cache + new_incoming_keys_or_values = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=([concat_keys_or_values, slice_8], 2), + kwargs={} + ) + # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph + new_incoming_keys_or_values.meta.update(incoming_kv_node.meta) + + return concat_keys_or_values, new_incoming_keys_or_values + +def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node): + """ + Insert slicing and concatenation operations before each scaled_dot_product_attention operation as per the following KV cache update logic: + concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2) + concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2) + new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2) + new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2) + out = torch._C._nn.scaled_dot_product_attention(q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal) + """ + # Find all nodes with scaled_dot_product_attention + sdpa_nodes = [] + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == SDPA_OP: + sdpa_nodes.append(node) + kv_cache_for_graph = [] + for idx, sdpa_node in enumerate(sdpa_nodes): + assert len(sdpa_node.args) == 6, f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments" + q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args + incoming_key, incoming_value = incoming_keys_values[idx] + # For keys + new_current_key_node, new_incoming_key_cache_node = create_kv_cache_update_nodes(gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input) + # For values + new_current_value_node, new_incoming_value_cache_node = create_kv_cache_update_nodes(gm, sdpa_node, 
v_node, incoming_value, start_idx_input, end_idx_input) + + # Store the KV cache nodes for the current SDPA node + kv_cache_for_graph.extend([new_incoming_key_cache_node, new_incoming_value_cache_node]) + + # Update the SDPA node arguments with current key and value nodes + sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (attn_mask, dropout_p, True) + + # kv_cache_for_graph.extend([k_node, v_node]) + return gm, kv_cache_for_graph + + +@_aten_lowering_pass +def insert_static_cache_v2( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Insert KV cache ops in the graph""" + """Perform insertion of kv-caches and attention kernel.""" + # Add static key and value as inputs to the graph + kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True) + + # Build and update the KV cache using computed KV inputs for current token and + # incoming keys and values from previous tokens (which were added as inputs) + gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input) + + # Call the function to add KV as outputs + logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph) + + gm = clean_up_graph_after_modifications(gm) + + new_output_tensors = create_random_output_tensors(logits_keys_values) + + new_out_spec = pytree.tree_flatten(new_output_tensors)[1] + gm._out_spec = new_out_spec + + logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) + return gm + + diff --git a/examples/dynamo/llm/test_gemma.py b/examples/dynamo/llm/test_gemma.py new file mode 100644 index 0000000000..dc665ce61b --- /dev/null +++ b/examples/dynamo/llm/test_gemma.py @@ -0,0 +1,258 @@ +import torch + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import TestCase +from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention, Gemma3DecoderLayer +from transformers.models.gemma3.configuration_gemma3 import Gemma3Config +from transformers import AutoModelForCausalLM +import torch_tensorrt +from contextlib import nullcontext +import argparse +import sys +import os + +# Register SDPA as a standalone operator. 
Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from register_sdpa import * + + +ATOL = 1e-5 +RTOL = 1e-5 + + +gemma3_model_name = "google/gemma-3-1b-it" +gemma3_model = AutoModelForCausalLM.from_pretrained( + gemma3_model_name, + use_cache=False, + attn_implementation="sdpa", + num_hidden_layers=1, + ).eval().cuda() +GEMMA3_CONFIG = gemma3_model.config + +def print_diff(tensor1, tensor2, prefix=""): + """ + Print the diff between two tensors + """ + print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}") + + +def test_gemma3_attention(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = gemma3_model.model.layers[0].self_attn.to(DTYPE) + + # gemma3 + hidden_states = torch.randn((1, 5, 1152), dtype=DTYPE).cuda() + position_embeddings = (torch.randn((1, 5, 256), dtype=DTYPE).cuda(), torch.randn((1, 5, 256), dtype=DTYPE).cuda()) + + pyt_output = model(hidden_states, position_embeddings, None) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[hidden_states, position_embeddings, None], + enabled_precisions=enabled_precisions, + disable_tf32=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + debug=args.debug) + trt_output = trt_model(hidden_states, position_embeddings, None) + + if isinstance(pyt_output, tuple): + print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt") + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + else: + print_diff(pyt_output, trt_output, "Diff b/w pyt and trt") + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + +def test_gemma3_attention_with_static_cache(args): + + import static_cache_v2 + DTYPE = torch.float32 + model = gemma3_model.model.layers[0].self_attn.to(DTYPE) + + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + hidden_states = torch.randn((1, ISL, 1152), dtype=DTYPE).cuda() + position_embeddings = (torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), torch.randn((1, ISL, 256), dtype=DTYPE).cuda()) + key_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE) + value_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE) + start_idx = 0 + end_idx = ISL + is_causal = True + + pyt_output = model(hidden_states, position_embeddings, None) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes) + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal], 
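+                                            # key_cache / value_cache / start_idx / end_idx correspond to the extra
+                                            # graph inputs added by the static_cache_v2 lowering pass imported above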
+ enabled_precisions={torch.float32}, + disable_tf32=True, + debug=args.debug, + # offload_module_to_cpu=True, + use_python_runtime=True) + + # Test Prefill + trt_output, _, key_cache, value_cache = trt_model(hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal) + print_diff(pyt_output[0], trt_output[0], "pyt_output[0] vs trt_output[0] [Prefill]") + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + hidden_states_curr = torch.randn((1, 1, 1152), dtype=DTYPE).cuda() + position_embeddings_curr = (torch.randn((1, 1, 256), dtype=DTYPE).cuda(), torch.randn((1, 1, 256), dtype=DTYPE).cuda()) + # Concatenate the current hidden_states with the previous ones + hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1) + position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1)) + + is_causal = False + out_no_cache, _ = model(hidden_states_full, position_embeddings_full, None) + out_trt, _, key_cache, value_cache = trt_model(hidden_states_curr, position_embeddings_curr, None, key_cache, value_cache, start_idx, end_idx, is_causal) + out_pyt = out_no_cache[:, -1:, :] + print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") + + hidden_states = hidden_states_full + position_embeddings = position_embeddings_full + +def test_gemma3_decoder(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + model = gemma3_model.model.layers[0].to(DTYPE) + # model.self_attn.is_sliding = False + + # gemma3 + hidden_states = torch.randn((1, 6, 1152), dtype=DTYPE).cuda() + position_embeddings_global = (torch.randn((1, 6, 256), dtype=DTYPE).cuda(), torch.randn((1, 6, 256), dtype=DTYPE).cuda()) + position_embeddings_local = (torch.randn((1, 6, 256), dtype=DTYPE).cuda(), torch.randn((1, 6, 256), dtype=DTYPE).cuda()) + + pyt_output = model(hidden_states, position_embeddings_global, position_embeddings_local) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), ({1: seq_len}, {1: seq_len})) + ep = torch.export.export(model, (hidden_states, position_embeddings_global, position_embeddings_local), dynamic_shapes=dynamic_shapes) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[hidden_states, position_embeddings_global, position_embeddings_local], + enabled_precisions={torch.float32}, + debug=args.debug) + trt_output = trt_model(hidden_states, position_embeddings_global, position_embeddings_local) + + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") + # breakpoint() + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + +def test_gemma3_decoder_with_static_cache(args): + + class Gemma3DecoderLayerBlock(nn.Module): + def __init__(self, model): + super().__init__() + self.config = GEMMA3_CONFIG + self.decoder = Gemma3DecoderLayer( + config=self.config, + layer_idx=0) + self.model = model + def forward(self, hidden_states, position_embeddings): + return self.model(hidden_states, position_embeddings=position_embeddings) + + DTYPE = torch.float32 + model = Gemma3DecoderLayerBlock(gemma3_model.model.layers[0].to(DTYPE)) + + import static_cache_v2 + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + 
NUM_TOKENS + hidden_states = torch.randn((1, ISL, 1152), dtype=DTYPE).cuda() + position_embeddings_global = (torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), torch.randn((1, ISL, 256), dtype=DTYPE).cuda()) + position_embeddings_local = (torch.randn((1, NUM_TOKENS, 256), dtype=DTYPE).cuda(), torch.randn((1, NUM_TOKENS, 256), dtype=DTYPE).cuda()) + key_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE) + value_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE) + start_idx = 0 + end_idx = ISL + is_causal = True + + pyt_output = model(hidden_states, position_embeddings_global, position_embeddings_local) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len})) + ep = torch.export.export(model, args=(hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + arg_inputs=[hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal], + enabled_precisions={torch.float32}, + disable_tf32=True, + debug=args.debug, + # offload_module_to_cpu=True, + use_python_runtime=True) + + # Test Prefill + trt_output, key_cache, value_cache = trt_model(hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal) + print_diff(pyt_output[0], trt_output, "pyt_output vs trt_output [Prefill]") + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + hidden_states_curr = torch.randn((1, 1, 1152), dtype=DTYPE).cuda() + position_embeddings_curr = (torch.randn((1, 1, 256), dtype=DTYPE).cuda(), torch.randn((1, 1, 256), dtype=DTYPE).cuda()) + # Concatenate the current hidden_states with the previous ones + hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1) + position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1)) + + is_causal = False + out_no_cache = model(hidden_states_full, position_embeddings_full) + + out_trt, key_cache, value_cache = trt_model(hidden_states_curr, position_embeddings_curr, key_cache, value_cache, start_idx, end_idx, is_causal) + out_pyt = out_no_cache[0][:, -1:, :] + print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") + hidden_states = hidden_states_full + position_embeddings = position_embeddings_full + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run test cases for llama attention and decoder" + ) + arg_parser.add_argument( + "--debug", + action="store_true", + help="Enable debug (default: False)" + ) + arg_parser.add_argument("--precision", type=str, default="FP16", help="Precision to use in the model. 
Options: FP16, BF16, FP32") + args = arg_parser.parse_args() + with torch.inference_mode(): + # test_gemma3_attention(args) + # test_gemma3_attention_with_static_cache(args) + test_gemma3_decoder(args) + # test_gemma3_decoder_with_static_cache(args) \ No newline at end of file diff --git a/examples/dynamo/llm/test_llama_components.py b/examples/dynamo/llm/test_llama_components.py new file mode 100644 index 0000000000..c0445e1590 --- /dev/null +++ b/examples/dynamo/llm/test_llama_components.py @@ -0,0 +1,476 @@ +import torch + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import TestCase +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers import AutoModelForCausalLM +import torch_tensorrt +from contextlib import nullcontext +import argparse +import sys +import os + +# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from register_sdpa import * +ATOL = 1e-5 +RTOL = 1e-5 + + +# llama2_model_name = "meta-llama/Llama-2-7b-hf" +llama3_model_name = "meta-llama/Llama-3.2-1B-Instruct" +llama_model = AutoModelForCausalLM.from_pretrained( + llama3_model_name, + use_cache=False, + attn_implementation="sdpa", + num_hidden_layers=1, + ).eval().cuda() +LLAMA_CONFIG = llama_model.config + +def test_llama_attention(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + # model = LlamaAttentionBlock().eval().cuda().to(DTYPE) + model = llama_model.model.layers[0].self_attn.to(DTYPE) + # llama3 + hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda() + position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda()) + + pyt_output = model(hidden_states, position_embeddings, None) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + from torch.export._trace import _export + # ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes, strict=False) + ep = _export( + model, + args=(hidden_states, position_embeddings, None), + dynamic_shapes=dynamic_shapes, + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[hidden_states, position_embeddings, None], + enabled_precisions=enabled_precisions, + disable_tf32=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + debug=args.debug) + trt_output = trt_model(hidden_states, position_embeddings, None) + if isinstance(pyt_output, tuple): + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") + assert torch.allclose(pyt_output[0], trt_output[0], 
atol=ATOL, rtol=RTOL) + else: + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}") + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + + +def print_diff(tensor1, tensor2, prefix=""): + """ + Print the diff between two tensors + """ + print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}") + +def test_llama_attention_with_static_cache(args): + class LlamaAttentionBlock(nn.Module): + def __init__(self): + super().__init__() + self.config = LLAMA_CONFIG + self.attn = LlamaAttention( + config=self.config, + layer_idx=0 + ) + def forward(self, hidden_states, position_embeddings): + attn_output, attn_weights = self.attn(hidden_states, position_embeddings, None) + return attn_output + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + model = llama_model.model.layers[0].self_attn.to(DTYPE) + + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + hidden_states = torch.randn((1, ISL, 2048), dtype=DTYPE).cuda() + position_embeddings = (torch.randn((1, ISL, 64), dtype=DTYPE).cuda(), torch.randn((1, ISL, 64), dtype=DTYPE).cuda()) + key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + start_idx = 0 + end_idx = ISL + is_causal = True + + pyt_output = model(hidden_states, position_embeddings, None) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes) + import static_cache_v2 + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal], + enabled_precisions=enabled_precisions, + disable_tf32=True, + debug=args.debug, + # offload_module_to_cpu=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + use_python_runtime=True) + + # Test Prefill + trt_output, _, key_cache, value_cache = trt_model(hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal) + print_diff(pyt_output[0], trt_output[0], "pyt_output[0] vs trt_output[0] [Prefill]") + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + hidden_states_curr = torch.randn((1, 1, 2048), dtype=DTYPE).cuda() + position_embeddings_curr = (torch.randn((1, 1, 64), dtype=DTYPE).cuda(), torch.randn((1, 1, 64), dtype=DTYPE).cuda()) + # Concatenate the current hidden_states with the previous ones + hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1) + position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1)) + + is_causal = False + out_no_cache, _ = model(hidden_states_full, position_embeddings_full, None) + out_trt, _, key_cache, value_cache = trt_model(hidden_states_curr, 
position_embeddings_curr, None, key_cache, value_cache, start_idx, end_idx, is_causal) + out_pyt = out_no_cache[:, -1:, :] + print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") + + hidden_states = hidden_states_full + position_embeddings = position_embeddings_full + + +def test_llama_decoder(args): + + class LlamaDecoderLayerBlock(nn.Module): + def __init__(self, model): + super().__init__() + self.config = LLAMA_CONFIG + self.decoder = LlamaDecoderLayer( + config=self.config, + layer_idx=0) + self.model = model + + def forward(self, hidden_states, position_embeddings): + return self.model(hidden_states, position_embeddings=position_embeddings) + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = LlamaDecoderLayerBlock(llama_model.model.layers[0].to(DTYPE)) + # llama3 + hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda() + position_embeddings = (torch.randn((1, 6, 64), dtype=DTYPE).cuda(), torch.randn((1, 6, 64), dtype=DTYPE).cuda()) + + pyt_output = model(hidden_states, position_embeddings) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len})) + ep = torch.export.export(model, (hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[hidden_states, position_embeddings], + enabled_precisions=enabled_precisions, + debug=args.debug, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing) + trt_output = trt_model(hidden_states, position_embeddings) + + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + +def test_llama_decoder_with_static_cache(args): + + class LlamaDecoderLayerBlock(nn.Module): + def __init__(self, model): + super().__init__() + self.config = LLAMA_CONFIG + self.decoder = LlamaDecoderLayer( + config=self.config, + layer_idx=0) + self.model = model + def forward(self, hidden_states, position_embeddings): + return self.model(hidden_states, position_embeddings=position_embeddings) + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + model = LlamaDecoderLayerBlock(llama_model.model.layers[0].to(DTYPE)) + + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + hidden_states = torch.randn((1, ISL, 2048), dtype=DTYPE).cuda() + position_embeddings = (torch.randn((1, ISL, 64), dtype=DTYPE).cuda(), torch.randn((1, ISL, 64), dtype=DTYPE).cuda()) + key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + value_cache = 
torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + start_idx = 0 + end_idx = ISL + is_causal = True + + pyt_output = model(hidden_states, position_embeddings) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len})) + ep = torch.export.export(model, args=(hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes) + import static_cache_v2 + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + arg_inputs=[hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal], + enabled_precisions=enabled_precisions, + disable_tf32=True, + debug=args.debug, + # offload_module_to_cpu=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + use_python_runtime=True) + + # Test Prefill + trt_output, key_cache, value_cache = trt_model(hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal) + print_diff(pyt_output[0], trt_output, "pyt_output vs trt_output [Prefill]") + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + hidden_states_curr = torch.randn((1, 1, 2048), dtype=DTYPE).cuda() + position_embeddings_curr = (torch.randn((1, 1, 64), dtype=DTYPE).cuda(), torch.randn((1, 1, 64), dtype=DTYPE).cuda()) + # Concatenate the current hidden_states with the previous ones + hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1) + position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1)) + + is_causal = False + out_no_cache = model(hidden_states_full, position_embeddings_full) + + out_trt, key_cache, value_cache = trt_model(hidden_states_curr, position_embeddings_curr, key_cache, value_cache, start_idx, end_idx, is_causal) + out_pyt = out_no_cache[0][:, -1:, :] + print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") + hidden_states = hidden_states_full + position_embeddings = position_embeddings_full + + +def test_llama_model(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = llama_model.model.to(DTYPE) + + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + input_ids = torch.randint(1, 20, (1, ISL), dtype=torch.int64).cuda() + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).cuda() + + pyt_output = model(input_ids, position_ids) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, {1: seq_len}) + kwarg_inputs = {"position_ids":position_ids} + from torch.export._trace import _export + ep = _export(model, args=(input_ids,), kwargs=kwarg_inputs, dynamic_shapes=dynamic_shapes, strict=False, allow_complex_guards_as_runtime_asserts=True) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + arg_inputs=[], + kwarg_inputs=kwarg_inputs, + enabled_precisions=enabled_precisions, + disable_tf32=True, + debug=args.debug, + 
offload_module_to_cpu=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + use_python_runtime=True) + + trt_output = trt_model(input_ids, position_ids) + + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") + # print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[1] - trt_output[1]))}") + breakpoint() + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + +def test_llama_model_with_static_cache(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + model = llama_model.model.to(DTYPE) + + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + input_ids = torch.randint(1, 20, (1, ISL), dtype=torch.int64).cuda() + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).cuda() + key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + start_idx = 0 + end_idx = ISL + is_causal = True + + pyt_output = model(input_ids) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, {1: seq_len}) + kwarg_inputs = {"input_ids":input_ids, "position_ids":position_ids} + ep = torch.export.export(model, args=(), kwargs=kwarg_inputs, dynamic_shapes=dynamic_shapes) + + import static_cache_v2 + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + arg_inputs=[], + kwarg_inputs=kwarg_inputs, + enabled_precisions=enabled_precisions, + disable_tf32=True, + debug=args.debug, + # offload_module_to_cpu=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + use_python_runtime=True) + + # Test Prefill + trt_output, key_cache, value_cache = trt_model(input_ids, position_ids, key_cache, value_cache, start_idx, end_idx, is_causal) + pyt_output = pyt_output.last_hidden_state + print_diff(pyt_output, trt_output, "pyt_output vs trt_output [Prefill]") + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + input_ids_curr = torch.randint(1, 20, (1, 1), dtype=torch.int64).cuda() + position_ids_curr = torch.tensor([[start_idx]], dtype=torch.int64).cuda() + + # Concatenate the current hidden_states with the previous ones + input_ids_full = torch.cat((input_ids, input_ids_curr), dim=1) + position_ids_full = torch.cat((position_ids, position_ids_curr), dim=1) + is_causal = False + kwarg_inputs = {"input_ids":input_ids_full, "position_ids":position_ids_full} + out_no_cache = model(**kwarg_inputs) + + out_trt, key_cache, value_cache = trt_model(input_ids_curr, position_ids_curr, key_cache, value_cache, start_idx, end_idx, is_causal) + out_pyt = out_no_cache.last_hidden_state[:, -1:, :] + print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") + input_ids = input_ids_full + position_ids = position_ids_full + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run test cases for llama attention and decoder" + ) + arg_parser.add_argument( + "--debug", + action="store_true", + help="Enable debug (default: False)" + ) + 
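All of the *_with_static_cache tests above share one calling convention: prefill writes cache rows [0, ISL) in a single call (start_idx=0, end_idx=ISL, is_causal=True), and each generate step writes one new row (end_idx = start_idx + 1, is_causal=False) while attending over the populated prefix of the cache. The following is a minimal pure-PyTorch sketch of that cache update for reference only — shapes are illustrative, and it mirrors the StaticCacheModel defined later in test_static_cache.py rather than the TensorRT lowering itself.

import torch

def static_cache_sdpa(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal):
    # Write the new key/value rows into the pre-allocated cache slice [start_idx, end_idx).
    new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2)
    new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2)
    # Attend only over the populated prefix of the cache.
    out = torch.nn.functional.scaled_dot_product_attention(
        q, new_key_cache[:, :, :end_idx, :], new_value_cache[:, :, :end_idx, :],
        dropout_p=0.0, is_causal=is_causal,
    )
    return out, new_key_cache, new_value_cache

key_cache = torch.zeros(1, 32, 2176, 64)
value_cache = torch.zeros(1, 32, 2176, 64)
# Prefill: 2048 tokens at once.
q = k = v = torch.randn(1, 32, 2048, 64)
out, key_cache, value_cache = static_cache_sdpa(q, k, v, key_cache, value_cache, 0, 2048, True)
# Generate: one token per step, appended at the end of the populated region.
q1 = k1 = v1 = torch.randn(1, 32, 1, 64)
out1, key_cache, value_cache = static_cache_sdpa(q1, k1, v1, key_cache, value_cache, 2048, 2049, False)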
arg_parser.add_argument( + "--precision", + type=str, + default="FP16", + help="Precision (default: FP16)" + ) + args = arg_parser.parse_args() + with torch.inference_mode(): + # test_llama_attention(args) + # test_llama_decoder(args) + test_llama_model(args) + # test_llama_attention_with_static_cache(args) + # test_llama_decoder_with_static_cache(args) + # test_llama_model_with_static_cache(args) \ No newline at end of file diff --git a/examples/dynamo/llm/test_qwen2.5_components.py b/examples/dynamo/llm/test_qwen2.5_components.py new file mode 100644 index 0000000000..37ffbc5dd5 --- /dev/null +++ b/examples/dynamo/llm/test_qwen2.5_components.py @@ -0,0 +1,173 @@ +import torch + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import TestCase +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers import AutoModelForCausalLM +import torch_tensorrt +from contextlib import nullcontext +import argparse +import sys +import os + +# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from register_sdpa import * + +ATOL = 1e-5 +RTOL = 1e-5 + + +qwen2_5_model_name = "Qwen/Qwen2.5-1.5B-Instruct" +qwen2_5_model = AutoModelForCausalLM.from_pretrained( + qwen2_5_model_name, + use_cache=False, + attn_implementation="sdpa", + num_hidden_layers=1, + ).eval().cuda() +QWEN_CONFIG = qwen2_5_model.config + +def print_diff(tensor1, tensor2, prefix=""): + """ + Print the diff between two tensors + """ + print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}") + +def test_qwen_apply_rotary_pos_emb(args): + class QwenApplyRotaryPosEmb(nn.Module): + def __init__(self): + super().__init__() + + def rotate_half(self, x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (self.rotate_half(q) * sin) + k_embed = (k * cos) + (self.rotate_half(k) * sin) + return q_embed, k_embed + + def forward(self, q, k, cos, sin, unsqueeze_dim=1): + return self.apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim) + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = QwenApplyRotaryPosEmb().eval().cuda().to(DTYPE) + # Shapes for Qwen 2.5 + q = torch.randn((1, 12, 5, 128), dtype=DTYPE).cuda() + k = torch.randn((1, 12, 5, 128), dtype=DTYPE).cuda() + cos = torch.randn((1, 5, 128), dtype=DTYPE).cuda() + sin = torch.randn((1, 5, 128), dtype=DTYPE).cuda() + + pyt_output = model(q, k, cos, sin) + + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({2: seq_len}, {2: seq_len}, {1: seq_len}, {1: seq_len}) + ep = torch.export.export(model, (q, k, cos, sin), dynamic_shapes=dynamic_shapes) + with 
(torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[q, k, cos, sin], + enabled_precisions=enabled_precisions, + disable_tf32=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + debug=args.debug) + trt_output = trt_model(q, k, cos, sin) + + if isinstance(pyt_output, tuple): + print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt") + # print_diff(pyt_output[1], trt_output[1], "Diff b/w pyt and trt") + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + else: + print_diff(pyt_output, trt_output, "Diff b/w pyt and trt") + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + + +def test_qwen_attention(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = qwen2_5_model.model.layers[0].self_attn.to(DTYPE) + # qwen2.5 + hidden_states = torch.randn((1, 5, 1536), dtype=DTYPE).cuda() + position_embeddings = (torch.randn((1, 5, 128), dtype=DTYPE).cuda(), torch.randn((1, 5, 128), dtype=DTYPE).cuda()) + + pyt_output = model(hidden_states, position_embeddings, None) + + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[hidden_states, position_embeddings, None], + enabled_precisions=enabled_precisions, + disable_tf32=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + debug=args.debug) + trt_output = trt_model(hidden_states, position_embeddings, None) + + if isinstance(pyt_output, tuple): + print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt") + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + else: + print_diff(pyt_output, trt_output, "Diff b/w pyt and trt") + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run test cases for llama attention and decoder" + ) + arg_parser.add_argument( + "--debug", + action="store_true", + help="Enable debug (default: False)" + ) + arg_parser.add_argument("--precision", type=str, default="FP16", help="Precision to use in the model. 
Options: FP16, BF16, FP32") + args = arg_parser.parse_args() + with torch.inference_mode(): + # test_qwen_apply_rotary_pos_emb(args) + test_qwen_attention(args) diff --git a/examples/dynamo/llm/test_qwen3.py b/examples/dynamo/llm/test_qwen3.py new file mode 100644 index 0000000000..e83419b717 --- /dev/null +++ b/examples/dynamo/llm/test_qwen3.py @@ -0,0 +1,175 @@ +import torch + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import TestCase +from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention, Qwen3DecoderLayer +from transformers.models.qwen3.configuration_qwen3 import Qwen3Config +from transformers import AutoModelForCausalLM +import torch_tensorrt +from contextlib import nullcontext +import argparse +import sys +import os + +# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from register_sdpa import * + +ATOL = 1e-5 +RTOL = 1e-5 + + +qwen3_model_name = "Qwen/Qwen3-0.6B" +qwen3_model = AutoModelForCausalLM.from_pretrained( + qwen3_model_name, + use_cache=False, + attn_implementation="sdpa", + num_hidden_layers=1, + ).eval().cuda() +QWEN_CONFIG = qwen3_model.config + +def print_diff(tensor1, tensor2, prefix=""): + """ + Print the diff between two tensors + """ + print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}") + + +def test_qwen_attention(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = qwen3_model.model.layers[0].self_attn.to(DTYPE) + # qwen2.5 + hidden_states = torch.randn((1, 5, 1024), dtype=DTYPE).cuda() + position_embeddings = (torch.randn((1, 5, 128), dtype=DTYPE).cuda(), torch.randn((1, 5, 128), dtype=DTYPE).cuda()) + + pyt_output = model(hidden_states, position_embeddings, None) + + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[hidden_states, position_embeddings, None], + enabled_precisions=enabled_precisions, + disable_tf32=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + debug=args.debug) + trt_output = trt_model(hidden_states, position_embeddings, None) + + if isinstance(pyt_output, tuple): + print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt") + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + else: + print_diff(pyt_output, trt_output, "Diff b/w pyt and trt") + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + +def test_qwen3_decoder(args): + + class QwenDecoderLayerBlock(nn.Module): + def __init__(self, model): + super().__init__() + self.config = QWEN_CONFIG + self.model = 
model + def forward(self, hidden_states, position_ids, position_embeddings): + return self.model(hidden_states, position_ids=position_ids, position_embeddings=position_embeddings) + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + model = QwenDecoderLayerBlock(qwen3_model.model.layers[0].to(DTYPE)) + # qwen3 + hidden_states = torch.randn((1, 5, 1024), dtype=DTYPE).cuda() + position_ids = torch.randint(0, 5, (1, 5), dtype=torch.int64).cuda() + position_embeddings = (torch.randn((1, 5, 128), dtype=DTYPE).cuda(), torch.randn((1, 5, 128), dtype=DTYPE).cuda()) + + pyt_output = model(hidden_states, position_ids, position_embeddings) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, {1: seq_len}, ({1: seq_len}, {1: seq_len})) + ep = torch.export.export(model, (hidden_states, position_ids, position_embeddings), dynamic_shapes=dynamic_shapes) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[hidden_states, position_ids, position_embeddings], + enabled_precisions={torch.float32}, + debug=args.debug) + trt_output = trt_model(hidden_states, position_ids, position_embeddings) + + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + +def test_qwen3_model(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + model = qwen3_model.model.to(DTYPE) + # qwen3 + input_ids = torch.randint(0, 5, (1, 5), dtype=torch.int64).cuda() + position_ids = torch.arange(input_ids.shape[1], dtype=torch.int64).cuda().unsqueeze(0) + + pyt_output = model(input_ids, position_ids) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, {1: seq_len}) + ep = torch.export.export(model, (input_ids, position_ids), dynamic_shapes=dynamic_shapes) + + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile(ep, + inputs=[input_ids, position_ids], + enabled_precisions={torch.float32}, + use_python_runtime=True, + disable_tf32=True, + debug=args.debug) + # breakpoint() + trt_output = trt_model(input_ids, position_ids) + + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}") + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[1] - trt_output[1]))}") + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[2] - trt_output[2]))}") + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run test cases for llama attention and decoder" + ) + arg_parser.add_argument( + "--debug", + action="store_true", + help="Enable debug (default: False)" + ) + arg_parser.add_argument("--precision", type=str, default="FP32", help="Precision to use in the model. 
Options: FP16, BF16, FP32") + args = arg_parser.parse_args() + with torch.inference_mode(): + # test_qwen_attention(args) + # test_qwen3_decoder(args) + test_qwen3_model(args) diff --git a/examples/dynamo/llm/test_static_cache.py b/examples/dynamo/llm/test_static_cache.py new file mode 100644 index 0000000000..52807f5e93 --- /dev/null +++ b/examples/dynamo/llm/test_static_cache.py @@ -0,0 +1,385 @@ +import torch +import torch.nn as nn +from torch.export import export +import torch_tensorrt +from contextlib import nullcontext +import argparse +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers import AutoModelForCausalLM +from torch_tensorrt.dynamo.lowering import ( + get_decompositions, + post_lowering, + pre_export_lowering, +) +import sys +import os + +# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +import register_sdpa + +ATOL = 1e-5 +RTOL = 1e-5 +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + +class DynamicCacheModel(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v, k1, v1, flag): + def true_fn(q, k, v, k1, v1): + k_new = torch.cat((k, k1), dim=2) + v_new = torch.cat((v, v1), dim=2) + return torch._C._nn.scaled_dot_product_attention(q, k_new, v_new) + + def false_fn(q, k, v, k1, v1): + return torch._C._nn.scaled_dot_product_attention(q, k, v) + + out = torch.cond(flag, true_fn, false_fn, (q, k, v, k1, v1)) + + return 2 * out + +class ModelNoCache(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v): + return torch._C._nn.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True) + +class StaticCacheModel(nn.Module): + def __init__(self): + super().__init__() + + # def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True): + # new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2) + # new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2) + # out = torch._C._nn.scaled_dot_product_attention(q, new_key_cache[:, :, :end_idx, :], new_value_cache[:, :, :end_idx, :], dropout_p=0.0, is_causal=is_causal) + + # return out, new_key_cache, new_value_cache + + def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True): + concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2) # key_cache[:, :, :6, :] + curr_keys + key_cache[:, : 7: ,: ] + concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2) + new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2) + new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2) + out = torch._C._nn.scaled_dot_product_attention(q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal) + + return out, new_key_cache, new_value_cache + + +def eager_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, + is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor: + """ + Eager implementation of SDPA + """ + import math + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, 
dtype=torch.bool).tril(diagonal=0).cuda() + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias = attn_mask + attn_bias + + if enable_gqa: + key = key.repeat_interleave(query.size(-3)//key.size(-3), -3) + value = value.repeat_interleave(query.size(-3)//value.size(-3), -3) + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return attn_weight @ value + +def print_diff(tensor1, tensor2, prefix=""): + """ + Print the diff between two tensors + """ + print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}") + + +def test_no_cache_model_with_torch_tensorrt(args): + """ + Test the no cache model + """ + with torch.inference_mode(): + model_no_cache = ModelNoCache().eval().cuda() + # q = torch.randn(1, 32, 6, 64).cuda() + # k = torch.randn(1, 32, 6, 64).cuda() + # v = torch.randn(1, 32, 6, 64).cuda() + q = torch.load("query.pt") + k = torch.load("key.pt") + v = torch.load("value.pt") + out_no_cache = model_no_cache(q, k, v) + out_eager = eager_sdpa(q, k, v, is_causal=True) + q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176) + # Export the model + exported_program = torch.export.export( + model_no_cache, + args=(q, k, v), + dynamic_shapes=({2 : q_seq_len}, {2 : q_seq_len}, {2 : q_seq_len}), + strict=False + ) + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile( + exported_program, + inputs=[q, k, v], + enabled_precisions={torch.float32}, + disable_tf32=True, + debug=args.debug, + min_block_size=1, + ) + out_trt = trt_model(q, k, v) + + print_diff(out_no_cache, out_eager, "out_no_cache vs out_eager") + print_diff(out_no_cache, out_trt, "out_no_cache vs out_trt") + print_diff(out_eager, out_trt, "out_eager vs out_trt") + breakpoint() + + +def test_static_cache_model(args): + """ + Test the static cache model + """ + with torch.inference_mode(): + model_no_cache = ModelNoCache().eval().cuda() + model_static_cache = StaticCacheModel().eval().cuda() + q = torch.randn(1, 32, 2048, 64).cuda() + k = torch.randn(1, 32, 2048, 64).cuda() + v = torch.randn(1, 32, 2048, 64).cuda() + key_cache = torch.zeros(1, 32, 2176, 64).cuda() + value_cache = torch.zeros(1, 32, 2176, 64).cuda() + + # Test Prefill + start_idx = 0 + end_idx = 2048 + out_no_cache = model_no_cache(q, k, v) + out_static_cache, new_key_cache, new_value_cache = model_static_cache(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True) + assert torch.allclose(out_no_cache, out_static_cache, atol=ATOL, rtol=RTOL) + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + q_curr = torch.randn(1, 32, 1, 64).cuda() + k_curr = torch.randn(1, 32, 1, 64).cuda() + v_curr = torch.randn(1, 32, 1, 64).cuda() + + # Concatenate the current query, key, and value with the previous ones + q_full = torch.cat((q, q_curr), dim=2) + k_full = torch.cat((k, k_curr), dim=2) + v_full = torch.cat((v, v_curr), dim=2) + + out_no_cache = model_no_cache(q_full, k_full, v_full) + out_static_cache, new_key_cache, new_value_cache = model_static_cache(q_curr, k_curr, v_curr, new_key_cache, new_value_cache, start_idx, end_idx, is_causal=False) + + assert torch.allclose(out_no_cache[:, :, 
-1:, :], out_static_cache, atol=ATOL, rtol=RTOL) + q = q_full + k = k_full + v = v_full + print("============== test_static_cache passed ==============") + +def transform_gm_with_kv_cache(exported_program: torch.export.ExportedProgram, args): + """ + Transform the graph module by adding key and value cache to the graph + """ + gm = exported_program.module() + # Post lower the model + settings = torch_tensorrt.dynamo.conversion.CompilationSettings( + enabled_precisions={torch.float32}, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + min_block_size=1, + ) + exported_program = pre_export_lowering(exported_program, settings) + exported_program = exported_program.run_decompositions( + get_decompositions(False) + ) + + gm = exported_program.module() + gm = post_lowering(gm, settings) + + return gm + +def test_static_cache_lowering(args): + """ + Test static cache lowering pass applied to the model with no cache and run the graph module + and compare the output with the model with no cache + """ + import static_cache2 + + model_no_cache = ModelNoCache().eval().cuda() + q = torch.randn(1, 32, 2, 64).cuda() + k = torch.randn(1, 32, 2048, 64).cuda() + v = torch.randn(1, 32, 2048, 64).cuda() + key_cache = torch.zeros(1, 32, 2176, 64).cuda() + value_cache = torch.zeros(1, 32, 2176, 64).cuda() + + # Export the model + q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176) + kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176) + exported_program = export( + model_no_cache, + args=(q, k, v), + dynamic_shapes=({2 : q_seq_len}, {2 : kv_seq_len}, {2 : kv_seq_len}), + strict=False + ) + + gm = transform_gm_with_kv_cache(exported_program, args) + + # Test Prefill + start_idx = 0 + end_idx = 2048 + is_causal = True + q = torch.randn(1, 32, 2048, 64).cuda() + out_no_cache = model_no_cache(q, k, v) + out_pyt_cache, key_cache, value_cache = gm(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal) + assert torch.allclose(out_no_cache, out_pyt_cache, atol=ATOL, rtol=RTOL) + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + is_causal = False + q_curr = torch.randn(1, 32, 1, 64).cuda() + k_curr = torch.randn(1, 32, 1, 64).cuda() + v_curr = torch.randn(1, 32, 1, 64).cuda() + # Concatenate the current query, key, and value with the previous ones + q_full = torch.cat((q, q_curr), dim=2) + k_full = torch.cat((k, k_curr), dim=2) + v_full = torch.cat((v, v_curr), dim=2) + + out_no_cache = model_no_cache(q_full, k_full, v_full) + out_pyt_static_cache, key_cache, value_cache = gm(q_curr, k_curr, v_curr, key_cache, value_cache, start_idx, end_idx, is_causal) + assert torch.allclose(out_no_cache[:, :, -1:, :], out_pyt_static_cache, atol=ATOL, rtol=RTOL) + q = q_full + k = k_full + v = v_full + + print("============== test_static_cache_lowering passed ==============") + +def test_static_cache_export(args): + """ + Test the static cache model export + """ + model_static_cache = StaticCacheModel().eval().cuda() + q = torch.randn(1, 32, 2048, 64).cuda() + k = torch.randn(1, 32, 2048, 64).cuda() + v = torch.randn(1, 32, 2048, 64).cuda() + key_cache = torch.zeros(1, 32, 2176, 64).cuda() + value_cache = torch.zeros(1, 32, 2176, 64).cuda() + # Test Prefill + start_idx = 0 + end_idx = 2048 + is_causal = True + # Export the model + seq_len = torch.export.Dim("seq_len", min=2, max=2048) + seq_len_dyn_dim = {2 : seq_len} + exported_program = export( + model_static_cache, + args=(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal), + 
dynamic_shapes=(seq_len_dyn_dim, seq_len_dyn_dim, seq_len_dyn_dim, None, None, torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC, None), + strict=False + ) + + +def test_static_cache_with_torch_tensorrt(args): + """ + Test the static cache model with torch_tensorrt + """ + import static_cache_v2 + + model_no_cache = ModelNoCache().eval().cuda() + q = torch.randn(1, 32, 2, 64).cuda() + k = torch.randn(1, 32, 2048, 64).cuda() + v = torch.randn(1, 32, 2048, 64).cuda() + key_cache = torch.zeros(1, 32, 2176, 64).cuda() + value_cache = torch.zeros(1, 32, 2176, 64).cuda() + + # Export the model + q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176) + kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176) + exported_program = export( + model_no_cache, + args=(q, k, v), + dynamic_shapes=({2 : q_seq_len}, {2 : kv_seq_len}, {2 : kv_seq_len}), + strict=False + ) + with (torch_tensorrt.logging.debug() if args.debug else nullcontext()): + trt_model = torch_tensorrt.dynamo.compile( + exported_program, + inputs=[q, k, v], + enabled_precisions={torch.float32}, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + min_block_size=1, + ) + + start_idx = 0 + end_idx = 2048 + is_causal = True + q = torch.randn(1, 32, 2048, 64).cuda() + # out_eager = eager_sdpa(q, k, v, is_causal=is_causal) + out_no_cache = model_no_cache(q, k, v) + out_trt, trt_key_cache, trt_value_cache = trt_model(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal) + + assert torch.allclose(out_no_cache, out_trt, atol=ATOL, rtol=RTOL), "Prefill TRT logits don't match" + assert torch.allclose(trt_key_cache[:, :, :end_idx, :], k, atol=ATOL, rtol=RTOL), "Prefill TRT key cache don't match" + assert torch.allclose(trt_value_cache[:, :, :end_idx, :], v, atol=ATOL, rtol=RTOL), "Prefill TRT value cache don't match" + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + q_curr = torch.randn(1, 32, 1, 64).cuda() + k_curr = torch.randn(1, 32, 1, 64).cuda() + v_curr = torch.randn(1, 32, 1, 64).cuda() + # Concatenate the current query, key, and value with the previous ones + q_full = torch.cat((q, q_curr), dim=2) + k_full = torch.cat((k, k_curr), dim=2) + v_full = torch.cat((v, v_curr), dim=2) + is_causal = True + out_no_cache = model_no_cache(q_full, k_full, v_full) + out_trt, trt_key_cache, trt_value_cache = trt_model(q_curr, k_curr, v_curr, trt_key_cache, trt_value_cache, start_idx, end_idx, is_causal) + # breakpoint() + # print_diff(out_no_cache[:, :, -1:, :], out_trt, f"out_no_cache[:, :, -1:, :] vs out_trt for idx {start_idx}") + # print_diff(trt_key_cache[:, :, :end_idx, :], k_full, f"trt_key_cache[:, :, :end_idx, :] vs k_full for idx {start_idx}") + # print_diff(trt_value_cache[:, :, :end_idx, :], v_full, f"trt_value_cache[:, :, :end_idx, :] vs v_full for idx {start_idx}") + assert torch.allclose(out_no_cache[:, :, -1:, :], out_trt, atol=ATOL, rtol=RTOL), f"Generate TRT logits don't match for idx {start_idx}" + assert torch.allclose(trt_key_cache[:, :, :end_idx, :], k_full, atol=ATOL, rtol=RTOL), f"Generate TRT key cache don't match for idx {start_idx}" + assert torch.allclose(trt_value_cache[:, :, :end_idx, :], v_full, atol=ATOL, rtol=RTOL), f"Generate TRT value cache don't match for idx {start_idx}" + q = q_full + k = k_full + v = v_full + + print("============== test_static_cache_with_torch_tensorrt passed ==============") + + +def main(): + arg_parser = argparse.ArgumentParser( + description="Run test cases for llama attention and decoder" + ) + arg_parser.add_argument( + 
"--debug", + action="store_true", + help="Enable debug (default: False)" + ) + args = arg_parser.parse_args() + with torch.inference_mode(): + # test_no_cache_model_with_torch_tensorrt(args) + # test_static_cache_model(args) + # test_static_cache_lowering(args) + test_static_cache_with_torch_tensorrt(args) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/dynamo/llm/utils.py b/examples/dynamo/llm/utils.py new file mode 100644 index 0000000000..c43f90acc5 --- /dev/null +++ b/examples/dynamo/llm/utils.py @@ -0,0 +1,216 @@ +import torch +from transformers import StoppingCriteriaList +from transformers.generation.stopping_criteria import ( + EosTokenCriteria, + MaxLengthCriteria, +) +import numpy as np +import copy +import timeit + +def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): + """ + Exports the LLM model into an ExportedProgram with dynamic shapes. + In the case of guard failures due to some PyTorch kernel implements, we also + try to re-export the graph by expressing them as runtime assert nodes + """ + with torch.no_grad(): + # max=1024 has contraint violation error. https://github.com/pytorch/pytorch/issues/125604 + seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len) + position_ids = torch.arange(inputs.shape[1]).unsqueeze(0).to(inputs.device) + try: + print("Trying to export the model using torch.export.export()..") + # strict=False only enables aotautograd tracing and excludes dynamo. + ep = torch.export.export( + model, args=(inputs,), kwargs={"position_ids":position_ids}, dynamic_shapes=({1: seq_len}, {1: seq_len}), strict=False + ) + except: + print( + "Trying torch.export._trace._export to trace the graph since torch.export.export() failed" + ) + # This API is used to express the constraint violation guards as asserts in the graph. + ep = torch.export._trace._export( + model, + args=(inputs,), + kwargs={"position_ids":position_ids}, + dynamic_shapes=({1: seq_len}, {1: seq_len}), + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + + return ep + +def get_zeroed_static_cache_inputs(model: torch.fx.GraphModule): + """ + Extracts and returns zeroed static KV cache tensors from a torch.fx.GraphModule. This should only be used for static cache_v1 and static cache_v2. + + This function identifies placeholder nodes in the graph that represent KV cache tensors, + and creates zeroed tensors with the same shape, dtype, and device as the original placeholders. + + Args: + model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders + + Returns: + tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph + """ + # placeholder nodes are expected to be in the following order: + # input_ids, kv_cache_key, kv_cache_value, start_idx, end_idx + placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"] + # The first two inputs are input_ids, position_ids. The last two inputs are start_idx, end_idx. In between are the KV cache tensors. + kv_cache_inputs = placeholder_nodes[2:-2] + zeroed_kv_cache_inputs = [] + for input in kv_cache_inputs: + zeroed_kv_cache_inputs.append(torch.zeros(input.meta["val"].shape, dtype=input.meta["val"].dtype, device=torch.device("cuda:0"))) + + return tuple(zeroed_kv_cache_inputs) + +def get_zeroed_dynamic_cache_inputs(model: torch.fx.GraphModule): + """ + Extracts and returns zeroed KV cache tensors from a torch.fx.GraphModule. This should only be used for dynamic cache. 
+ + This function identifies placeholder nodes in the graph that represent KV cache tensors, + and creates zeroed tensors with the same shape, dtype, and device as the original placeholders. + + Args: + model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders + + Returns: + tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph + """ + # placeholder nodes are expected to be in the following order: + # input_ids, kv_cache_key, kv_cache_value, start_idx, end_idx + placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"] + # The first two inputs are input_ids, position_ids. The last input is is_generate. In between are the KV cache tensors. + kv_cache_inputs = placeholder_nodes[2:-1] + zeroed_kv_cache_inputs = [] + for input in kv_cache_inputs: + zeroed_kv_cache_inputs.append(torch.zeros(input.meta["val"].shape, dtype=input.meta["val"].dtype, device=torch.device("cuda:0"))) + + return tuple(zeroed_kv_cache_inputs) + + +def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=True): + """ + Greedy decoding of the model. This generates up to max_tokens. + """ + stopping_criteria = StoppingCriteriaList( + [ + MaxLengthCriteria(max_length=max_output_seq_length), + EosTokenCriteria(eos_token_id=eos_token_id), + ] + ) + isl = input_seq.shape[1] + osl = max_output_seq_length - isl + + num_tokens_generated = 0 + while num_tokens_generated < osl: + position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda() + outputs = model(input_seq, position_ids=position_ids) + logits = outputs.logits + next_token_logits = logits[:, -1, :] + next_tokens = torch.argmax(next_token_logits, dim=-1) + input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1) + num_tokens_generated += 1 + # TODO: Handle batch in this check + if not benchmark and stopping_criteria(input_seq, logits).item(): + break + + return input_seq + +def generate_with_static_cache(model, input_seq, max_output_seq_length, eos_token_id): + """ + Greedy decoding of the model with static KV cache. + """ + start_idx = 0 + end_idx = input_seq.shape[1] + position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda() + output_seq = input_seq.clone() + # TODO: Confirm this: When end_idx = max_output_seq_length-1, number of tokens generated = OSL + num_tokens_generated = 0 + kv_cache = get_zeroed_static_cache_inputs(model) + while end_idx < max_output_seq_length: + position_ids = torch.tensor([[start_idx]], dtype=torch.int64).cuda() if input_seq.shape[1] == 1 else position_ids + input_signature = (input_seq, position_ids, *kv_cache, start_idx, end_idx) + logits_keys_values = model(*input_signature) + num_tokens_generated += 1 + logits = logits_keys_values[0] + kv_cache = logits_keys_values[1:] + next_token_logits = logits[:, -1, :] + next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True) + output_seq = torch.cat([output_seq, next_tokens], dim=-1) + input_seq = next_tokens + start_idx = end_idx + end_idx = start_idx + 1 + return output_seq + +def generate_with_dynamic_cache(model, input_seq, max_output_seq_length, eos_token_id): + """ + Greedy decoding of the model with dynamic KV cache. 
+ """ + position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda() + output_seq = input_seq.clone() + num_output_tokens = max_output_seq_length - input_seq.shape[1] + num_tokens_generated = 0 + kv_cache = get_zeroed_dynamic_cache_inputs(model) + last_position_id = position_ids[-1, -1].item() + breakpoint() + while num_tokens_generated < num_output_tokens: + is_generate = False if input_seq.shape[1] > 1 else True + position_ids = torch.tensor([[last_position_id+1]], dtype=torch.int64).cuda() if input_seq.shape[1] == 1 else position_ids + input_signature = (input_seq, position_ids, *kv_cache, is_generate) + logits_keys_values = model(*input_signature) + num_tokens_generated += 1 + logits = logits_keys_values[0] + kv_cache = logits_keys_values[1:] + next_token_logits = logits[:, -1, :] + next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True) + output_seq = torch.cat([output_seq, next_tokens], dim=-1) + input_seq = next_tokens + last_position_id += 1 + return output_seq + + +def time_generate( + generate_fn, model, inputs, output_seq_length, eos_token_id, iterations=10 +): + """ + Measure the time for generating a sentence over certain number of iterations + """ + timings = [] + for _ in range(iterations): + start_time = timeit.default_timer() + _ = generate_fn( + model, inputs, output_seq_length, eos_token_id + ) + torch.cuda.synchronize() + end_time = timeit.default_timer() + timings.append(end_time - start_time) + + return timings + + +def recordStats(backend, timings, precision, batch_size=1, compile_time_s=None): + """ + Records different timing stats and adds it to the result + """ + times = np.array(timings) + speeds = batch_size / times + time_mean = np.mean(times).item() + time_med = np.median(times).item() + time_99th = np.percentile(times, 99).item() + time_std = np.std(times, ddof=0).item() + speed_mean = np.mean(speeds).item() + speed_med = np.median(speeds).item() + + stats = { + "Backend": backend, + "Precision": precision, + "Batch size": batch_size, + "Median(FPS)": speed_med, + "Mean(FPS)": speed_mean, + "Median-Latency(ms)": time_med * 1000, + "Mean-Latency(ms)": time_mean * 1000, + "Latency-StdDev(ms)": time_std * 1000, + "Compile Time(s)": compile_time_s, + } + return stats \ No newline at end of file diff --git a/examples/dynamo/register_sdpa.py b/examples/dynamo/register_sdpa.py index 7436f31939..906673a806 100644 --- a/examples/dynamo/register_sdpa.py +++ b/examples/dynamo/register_sdpa.py @@ -19,11 +19,13 @@ # Remove decompositions for aten.scaled_dot_product_attention, aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention # This is because we want to have SDPA as a standalone operator in the graph and invoke the custom converter for it. 
-TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default) +TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default, None) TORCH_TRT_DECOMPOSITIONS.pop( - torch.ops.aten._scaled_dot_product_efficient_attention.default + torch.ops.aten._scaled_dot_product_efficient_attention.default, None +) +TORCH_TRT_DECOMPOSITIONS.pop( + torch.ops.aten._scaled_dot_product_flash_attention.default, None ) -TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten._scaled_dot_product_flash_attention.default) REPLACEABLE_ATEN_OPS = { torch.ops.aten._scaled_dot_product_efficient_attention.default, @@ -59,6 +61,7 @@ def replace_variants_of_sdpa( elif len(node.args) == 5: query, key, value, attn_mask, is_causal = node.args dropout_p = 0.0 + else: raise ValueError( f"Unexpected number of arguments for {node.target} in the graph" @@ -71,6 +74,10 @@ def replace_variants_of_sdpa( query, key, value, dropout_p, is_causal, return_debug_mask = ( node.args ) + if len(node.args) == 5: + query, key, value, dropout_p, is_causal = ( + node.args + ) elif len(node.args) == 3: query, key, value = node.args dropout_p = 0.0 @@ -79,20 +86,16 @@ def replace_variants_of_sdpa( raise ValueError( f"Unexpected number of arguments for {node.target} in the graph" ) - if attn_mask is not None: - logger.warning( - f"This current version of SDPA converter does not support attn_mask for {node.target} in the graph. Ignoring it and using is_causal=True configuration." - ) - - modified_input_args = (query, key, value, None, dropout_p, is_causal) - + + logger.warning(f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations.") + modified_input_args = (query, key, value, None, dropout_p, True) # Create a new node with torch.nn.functional.scaled_dot_product_attention # The input args is (query, key, value, is_causal). kwargs has scale with gm.graph.inserting_after(node): new_node = gm.graph.call_function( torch.nn.functional.scaled_dot_product_attention, args=modified_input_args, - kwargs={"scale": node.kwargs.get("scale", None)}, + kwargs={"scale": node.kwargs.get("scale", None), "use_fp32_acc": settings.use_fp32_acc}, ) # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead. @@ -113,7 +116,7 @@ def replace_variants_of_sdpa( # Clean up the graph clean_up_graph_after_modifications(gm) - logger.info( + logger.debug( "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention" ) return gm diff --git a/examples/dynamo/sdpa_converter.py b/examples/dynamo/sdpa_converter.py index 903324dff5..c60ad915dd 100644 --- a/examples/dynamo/sdpa_converter.py +++ b/examples/dynamo/sdpa_converter.py @@ -62,25 +62,15 @@ def scaled_dot_product_attention( ) -> TRTTensor: # TODO: Handle attn_mask and is_causal arguments in the future query, key, value, attn_mask, dropout_p, is_causal = args - logger.info( - "Ignoring attn_mask and is_causal arguments provided by the original graph. " - "This converter expects is_causal to be an input to the graph. 
For prefill phase, is_causal=True " - "and for generate phase, is_causal=False since we pass only 1 input token at a time" - ) # TODO: remove this once we have a better way to handle the causal mask scale = kwargs.get("scale", None) source_ir = SourceIR.ATEN + is_causal = True # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html - mm = impl.matmul.matrix_multiply( - ctx, - target, - source_ir, - name + "_mm", - query, - key, - other_matrix_op=trt.MatrixOperation.TRANSPOSE, - ) + use_fp32_acc = kwargs.get("use_fp32_acc", False) + query_dtype = query.dtype + if scale is None: scale = query.shape[-1] if scale < 0: @@ -90,80 +80,110 @@ def scaled_dot_product_attention( else: # static shape sqrt_scaled = math.sqrt(scale) - scaled = impl.elementwise.div( + key = impl.elementwise.div( ctx, target, source_ir, name + "_scale", - mm, + key, sqrt_scaled, ) else: - scaled = impl.elementwise.mul( + key = impl.elementwise.mul( ctx, target, source_ir, name + "_scale", - mm, + key, scale, ) - - # If is_causal is True, we need to generate a causal mask - if is_causal: - L, S = query.shape[-2], key.shape[-2] - if L >= 0 and S >= 0: - # static shape - attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype)) - temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) - attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) - attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") - else: - # if any of the L or S is dynamic shape - if L < 0: - L = impl.shape.shape( - ctx, target, source_ir, name + "_shape_0", query, 2 - ) - if S < 0: - S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) - - # generate the mask tensor - tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) - - temp_mask = impl.unary.logical_not( - ctx, target, source_ir, name + "_logical_not", tril_tensor + + if use_fp32_acc and query_dtype == trt.float16: + query = cast_trt_tensor( + ctx, query, trt.float32, name + "_query_cast_to_fp32", target, source_ir ) - temp_mask_casted = cast_trt_tensor( - ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir + key = cast_trt_tensor( + ctx, key, trt.float32, name + "_key_cast_to_fp32", target, source_ir ) - one_minus_temp_mask = impl.elementwise.sub( - ctx, - target, - source_ir, - name + "_one_minus_temp_mask", - 1.0, - temp_mask_casted, + + mm = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_mm", + query, + key, + other_matrix_op=trt.MatrixOperation.TRANSPOSE, + ) + + if use_fp32_acc: + mm = cast_trt_tensor( + ctx, mm, query_dtype, name + "_mm_cast_to_fp16", target, source_ir ) - attn_bias = impl.unary.log( - ctx, target, source_ir, name + "_log", one_minus_temp_mask + + L, S = query.shape[-2], key.shape[-2] + if L >= 0 and S >= 0: + # static shape + attn_bias = np.zeros((L, S), dtype=dtype._from(query_dtype).to(np.dtype)) + temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) + attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) + attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") + else: + # if any of the L or S is dynamic shape + if L < 0: + L = impl.shape.shape( + ctx, target, source_ir, name + "_shape_0", query, 2 ) + if S < 0: + S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) + + # generate the mask tensor + tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) + + temp_mask = impl.unary.logical_not( + 
ctx, target, source_ir, name + "_logical_not", tril_tensor + ) + + # This need_mask determines if we want to use the causal mask or not + # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. + # So need_mask will be all False values in this case. + # TODO: Implement more general case where L != 1 and S != L + need_mask = impl.elementwise.eq( + ctx, target, source_ir, name + "_eq", L, S + ) + temp_mask = impl.elementwise.logical_and( + ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask + ) + temp_mask_casted = cast_trt_tensor( + ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir + ) - scaled_add_attn_bias = impl.elementwise.add( - ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias + one_minus_temp_mask = impl.elementwise.sub( + ctx, + target, + source_ir, + name + "_one_minus_temp_mask", + 1.0, + temp_mask_casted, + ) + attn_bias = impl.unary.log( + ctx, target, source_ir, name + "_log", one_minus_temp_mask ) - else: - scaled_add_attn_bias = scaled - # Create a if condition to check if is_causal is True - if isinstance(is_causal, TRTTensor): - if_layer = ctx.net.add_if_conditional() - condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, scaled - if_layer.set_condition(condition) - output_layer = if_layer.add_output(true_branch, false_branch) - scaled_add_attn_bias = output_layer.get_output(0) + scaled_add_attn_bias = impl.elementwise.add( + ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias + ) softmax = impl.normalization.softmax( ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False ) + if use_fp32_acc: + softmax = cast_trt_tensor( + ctx, softmax, trt.float32, name + "_softmax_cast_to_fp32", target, source_ir + ) + value = cast_trt_tensor( + ctx, value, trt.float32, name + "_value_cast_to_fp32", target, source_ir + ) out = impl.matmul.matrix_multiply( ctx, target, @@ -172,5 +192,9 @@ def scaled_dot_product_attention( softmax, value, ) + if use_fp32_acc: + out = cast_trt_tensor( + ctx, out, query_dtype, name + "_out_cast_to_fp16", target, source_ir + ) return out diff --git a/examples/dynamo/utils.py b/examples/dynamo/utils.py deleted file mode 100644 index 25ad99c12d..0000000000 --- a/examples/dynamo/utils.py +++ /dev/null @@ -1,63 +0,0 @@ -import torch -from transformers import StoppingCriteriaList -from transformers.generation.stopping_criteria import ( - EosTokenCriteria, - MaxLengthCriteria, -) - - -def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): - """ - Exports the LLM model into an ExportedProgram with dynamic shapes. - In the case of guard failures due to some PyTorch kernel implements, we also - try to re-export the graph by expressing them as runtime assert nodes - """ - with torch.no_grad(): - # max=1024 has contraint violation error. https://github.com/pytorch/pytorch/issues/125604 - seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len) - try: - print("Trying to export the model using torch.export.export()..") - # strict=False only enables aotautograd tracing and excludes dynamo. - ep = torch.export.export( - model, (inputs,), dynamic_shapes=({1: seq_len},), strict=False - ) - except: - print( - "Trying torch.export._trace._export to trace the graph since torch.export.export() failed" - ) - # This API is used to express the constraint violation guards as asserts in the graph. 
- ep = torch.export._trace._export( - model, - (inputs,), - dynamic_shapes=({1: seq_len},), - strict=False, - allow_complex_guards_as_runtime_asserts=True, - ) - - return ep - - -def generate(model, input_seq, max_tokens, eos_token_id): - """ - Greedy decoding of the model. This generates up to max_tokens. - """ - # Max length of output seq = current input_seq length + max_tokens allowed to generate - max_output_seq_length = input_seq.shape[1] + max_tokens - stopping_criteria = StoppingCriteriaList( - [ - MaxLengthCriteria(max_length=max_output_seq_length), - EosTokenCriteria(eos_token_id=eos_token_id), - ] - ) - - while True: - outputs = model(input_seq) - logits = outputs.logits - next_token_logits = logits[:, -1, :] - next_tokens = torch.argmax(next_token_logits, dim=-1) - input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1) - # TODO: Handle batch in this check - if stopping_criteria(input_seq, logits).item(): - break - - return input_seq diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 831ce37305..3700057fd7 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -788,6 +788,29 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: "Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments." ) + + # Store the original input spec for later use + original_in_spec = getattr(gm, '_in_spec', None) + original_out_spec = getattr(gm, '_out_spec', None) + + # Function to preserve and restore module specs + def preserve_module_specs(in_spec, out_spec, target_module): + """ + Applies input and output specs to the target module. + + Args: + in_spec: The input spec to apply + out_spec: The output spec to apply + target_module: The module to apply specs to + """ + # Apply specs to target module + if in_spec is not None: + target_module._in_spec = in_spec + if out_spec is not None: + target_module._out_spec = out_spec + + return target_module + # Partition module into components that can be TRT-accelerated fast_partitioner_failed = False # If specified, try using the fast partitioner and fall back to the global one on failure @@ -835,6 +858,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: continue submodule_node_dict[node.name] = node + preserve_module_specs(original_in_spec, original_out_spec, partitioned_module) # Store TRT replicas of Torch subgraphs trt_modules = {} # Iterate over all components that can be accelerated @@ -1171,7 +1195,7 @@ def convert_exported_program_to_serialized_trt_engine( "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, } - + settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 1fed1f9a1f..4046f5c54d 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1896,7 +1896,7 @@ def aten_ops_minimum( args[1], ) - +@dynamo_tensorrt_converter(operator.sub, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar, supports_dynamic_shapes=True) def aten_ops_sub( diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py 
b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py index 825be75076..6df05f6940 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py +++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py @@ -171,6 +171,7 @@ aten.upsample_bilinear2d.vec, aten.upsample_trilinear3d.vec, aten.upsample_bicubic2d.vec, + aten.linear, } diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 2ecc45ecf3..6e2019ad71 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -6,6 +6,7 @@ from torch_tensorrt.dynamo.utils import is_tegra_platform from .accumulate_fp32_matmul import accumulate_fp32_matmul +from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention from .constant_folding import constant_fold from .fuse_distributed_ops import fuse_distributed_ops from .fuse_prims_broadcast import fuse_prims_broadcast @@ -25,7 +26,7 @@ replace_max_pool_with_indices, remove_assert_nodes, accumulate_fp32_matmul, - remove_num_users_is_0_nodes, + # remove_num_users_is_0_nodes, ] if not is_tegra_platform(): diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 6ebefc5509..172d902a40 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -55,7 +55,6 @@ def constant_fold( del cf logger.debug(f"Graph after constant folding:\n{gm.graph}") - return gm diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py new file mode 100644 index 0000000000..89558acade --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py @@ -0,0 +1,167 @@ +import copy +import logging +import operator +from typing import Callable, Sequence, Tuple + +import torch +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = logging.getLogger(__name__) +REPLACEABLE_ATEN_OPS = { + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, +} + + +def lower_scaled_dot_product_attention( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Replace specific versions of scaled_dot_product_attention with an equivalent + implementation which can be easily converted to TRT + """ + original_fns, replacement = scaled_dot_product_attention_replacement() + replaced_nodes = [] + # For each original function, search for it in the graph and replace + for original in original_fns: + replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters( + gm, + original, + replacement, + ignore_literals=True, + ) + if replaced_nodes: + # Repair instances which use the kwargs field (specifically the "scale" kwarg) + # Also repair instances which specified the is_causal or attn_bias fields + for match in replaced_nodes:
attention_node_replaced = None + # Seek the attention operator being replaced + for node in match.nodes_map: + if node.target in REPLACEABLE_ATEN_OPS: + attention_node_replaced = match.nodes_map[node] + break + + assert attention_node_replaced is not None + assert len(match.replacements) == 1 + + new_attention_node = match.replacements[0] + + assert ( + new_attention_node.target + == torch.nn.functional.scaled_dot_product_attention + ) + + # Copy the metadata of the replaced attention node to the new node + # TODO: Investigate why there are multiple FakeTensors in the metadata. + # We only use the first one as it contains the output shape information for this node. + if "val" in attention_node_replaced.meta: + new_attention_node.meta["val"] = copy.copy( + attention_node_replaced.meta["val"][0] + ) + + # If the attention operator had keyword-args, copy them to the new node + if attention_node_replaced.kwargs: + new_attention_node.kwargs = {**attention_node_replaced.kwargs} + + # Set default args in new node: + # Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False + new_attention_node.args = new_attention_node.args + (None, 0.0, False) + # The `is_causal` argument was specified + if ( + ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_flash_attention.default + ) + and args_bounds_check(attention_node_replaced.args, 4, False) + ) or ( + ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_efficient_attention.default + ) + and args_bounds_check(attention_node_replaced.args, 6, False) + ): + new_attention_node.args = ( + new_attention_node.args[:5] + (True,) + new_attention_node.args[6:] + ) + + # The `attn_bias` argument was specified + if ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_efficient_attention.default + ) and args_bounds_check(attention_node_replaced.args, 3) is not None: + new_attention_node.args = ( + new_attention_node.args[:3] + + (attention_node_replaced.args[3],) + + new_attention_node.args[4:] + ) + + gm = clean_up_graph_after_modifications(gm) + logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}") + + return gm + + +def scaled_dot_product_attention_replacement() -> Tuple[ + Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]], + Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], +]: + """Constructs the original and replacement functions for efficient attention""" + + # Efficient Attention original graph + def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, + k, + v, + None, + False, + ) + out = operator.getitem(outputs, 0) + return out + + # Flash Attention original graph + def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_flash_attention.default( + q, + k, + v, + ) + out = operator.getitem(outputs, 0) + return out + + # Efficient Attention w/Scale original graph + def efficient_scale( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, + k, + v, + None, + False, + scale=1.0, + ) + out = operator.getitem(outputs, 0) + return out + + # Flash Attention w/Scale original graph + def flash_scale(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs =
torch.ops.aten._scaled_dot_product_flash_attention.default( + q, + k, + v, + scale=1.0, + ) + out = operator.getitem(outputs, 0) + return out + + # Replacement graph consists of the functional version of scaled_dot_product_attention + def replacement( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + return torch.nn.functional.scaled_dot_product_attention(query, key, value) + + return (efficient, flash, efficient_scale, flash_scale), replacement \ No newline at end of file diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonCUDAGraphModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonCUDAGraphModule.py new file mode 100644 index 0000000000..9aac192316 --- /dev/null +++ b/py/torch_tensorrt/dynamo/runtime/_PythonCUDAGraphModule.py @@ -0,0 +1,771 @@ +from __future__ import annotations + +import logging +from contextlib import nullcontext +from tempfile import tempdir +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import tensorrt as trt +import torch +import torch_tensorrt +from torch.nn import Module +from torch_tensorrt._Device import Device +from torch_tensorrt._enums import Platform, dtype +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.utils import DYNAMIC_DIM +from torch_tensorrt.logging import TRT_LOGGER +from torch_tensorrt.runtime._utils import ( + _is_switch_required, + _select_rt_device, + multi_gpu_device_check, +) + +logger = logging.getLogger(__name__) + + +class DynamicOutputAllocator(trt.IOutputAllocator): # type: ignore[misc] + def __init__(self, output_dtypes: Dict[str, torch.dtype]) -> None: + trt.IOutputAllocator.__init__(self) + self.buffers: Dict[str, torch.Tensor] = {} + self.shapes: Dict[str, Tuple[int, ...]] = {} + self.dtypes: Dict[str, torch.dtype] = output_dtypes + + def reallocate_output_async( + self, + tensor_name: str, + memory: int, + size: int, + alignment: int, + stream: torch.cuda.Stream, + ) -> Any: + shape = (size,) + if tensor_name not in self.buffers: + self.buffers[tensor_name] = torch.empty( + shape, + dtype=self.dtypes[tensor_name], + device=torch.cuda.current_device(), + ) + else: + if self.buffers[tensor_name].shape != shape: + self.buffers[tensor_name] = torch.empty( + shape, + dtype=self.dtypes[tensor_name], + device=torch.cuda.current_device(), + ) + return self.buffers[tensor_name].data_ptr() + + def notify_shape(self, tensor_name: str, shape: Tuple[int, ...]) -> None: + self.shapes[tensor_name] = tuple(shape) + + +class TorchTRTRuntimeStates: + def __init__(self, new_cudagraphs: bool): + # Indicates whether CUDAGraphs were enabled in the previous execute_engine + self.old_cudagraphs = new_cudagraphs + # Indicates whether pre-allocated output was enabled in the previous execute_engine + self.old_pre_allocated_outputs = False + # Indicates whether context has changed + self.context_changed = False + + def set_runtime_states( + self, + new_cudagraphs: bool, + new_pre_allocated_output: bool, + shape_changed: bool, + ) -> Tuple[bool, bool, bool]: + # Evaluates whether certain conditions are met to enable CUDA Graph recording or to use pre-allocated outputs + # based on the current and previous states, as well as input shape has changed + need_cudagraphs_record = False + can_use_pre_allocated_outputs = False + need_cudagraphs_reset = False + + # CUDA Graph recording is needed if CUDA graphs is enabled and: + # - CUDA graphs were previously disabled + # - or the shape has changed + # - or the execution context has changed (e.g., weight streaming) + if 
new_cudagraphs and ( + not self.old_cudagraphs or shape_changed or self.context_changed + ): + need_cudagraphs_record = True + + # Pre-allocated output can be used when previous and current state are true without shape change + if ( + self.old_pre_allocated_outputs + and new_pre_allocated_output + and (not shape_changed) + ): + can_use_pre_allocated_outputs = True + + if not new_cudagraphs or shape_changed or self.context_changed: + need_cudagraphs_reset = True + + self.old_cudagraphs = new_cudagraphs + self.old_pre_allocated_outputs = new_pre_allocated_output + # reset flag + self.context_changed = False + + return ( + need_cudagraphs_record, + can_use_pre_allocated_outputs, + need_cudagraphs_reset, + ) + + +class PythonTorchTensorRTModule(Module): # type: ignore[misc] + """PythonTorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine. + + This module is backed by the Torch-TensorRT runtime and is only compatible with + FX / Dynamo / Python deployments. This module cannot be serialized to torchscript via torch.jit.trace for C++ deployment. + """ + + def __init__( + self, + serialized_engine: Optional[bytes] = None, + input_binding_names: Optional[List[str]] = None, + output_binding_names: Optional[List[str]] = None, + *, + name: str = "", + settings: CompilationSettings = CompilationSettings(), + weight_name_map: Optional[dict[Any, Any]] = None, + requires_output_allocator: bool = False, + ): + """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs + a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine + + Arguments: + serialized_engine (bytes): Serialized TensorRT engine in the form of a bytearray + input_binding_names (List[str]): List of input TensorRT engine binding names in the order they would be passed to the TRT modules + output_binding_names (List[str]): List of output TensorRT engine binding names in the order they should be returned + + Keyword Arguments: + name (str): Name for module + settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed + weight_name_map (dict): Mapping of engine weight name to state_dict weight name + requires_output_allocator (bool): Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators) + + Example: + + .. 
code-block:: py + + trt_module = PythonTorchTensorRTModule( + engine_str, + input_binding_names=["x"], + output_binding_names=["output"], + name="my_module", + settings=CompilationSettings(device=torch.cuda.current_device) + ) + + """ + self.context: Any + super(PythonTorchTensorRTModule, self).__init__() + self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict) + + # Run multi-gpu device check to validate engine instantiation + multi_gpu_device_check() + + self.name = name + self._input_buffers: Dict[str, List[torch.Tensor]] = {} + self._output_buffers: Dict[str, List[torch.Tensor]] = {} + self.cudagraph: Optional[torch.cuda.CUDAGraph] = None + self._caller_stream: Optional[torch.cuda.Stream] = None + self._engine_stream: Optional[torch.cuda.Stream] = None + + # TODO: Make the below a Dictionary {shape: cudagraph} + self.shape_key_to_cudagraph: Dict[str, torch.cuda.CUDAGraph] = {} + + # See https://github.com/pytorch/pytorch/blob/acfe237a71af609e837a34bb38048aa8acb8eb4d/torch/cuda/graphs.py#L92-L98 + # Unused currently - to be used by Dynamic Shape support implementation + self.memory_pool = None + + self.serialized_engine = serialized_engine + self.input_names = ( + input_binding_names if input_binding_names is not None else [] + ) + self.output_names = ( + output_binding_names if output_binding_names is not None else [] + ) + self.initialized = False + self.target_device_id = ( + settings.device.gpu_id + if settings.device is not None + else Device._current_device().gpu_id + ) + self.target_device_properties = torch.cuda.get_device_properties( + self.target_device_id + ) + self.profiling_enabled = settings.debug if settings.debug is not None else False + self.settings = settings + self.engine = None + self.weight_name_map = weight_name_map + self.target_platform = Platform.current_platform() + self.runtime_states = TorchTRTRuntimeStates( + torch_tensorrt.runtime.get_cudagraphs_mode() + ) + + self.cudagraphs_enabled = False + self.pre_allocated_outputs: List[torch.Tensor] = [] + self.use_pre_allocated_outputs = False + + self.requires_output_allocator = requires_output_allocator + self.output_allocator: Optional[DynamicOutputAllocator] = None + self.use_output_allocator_outputs = False + + if self.serialized_engine is not None and not self.settings.lazy_engine_init: + self.setup_engine() + + def get_streamable_device_memory_budget(self) -> Any: + return self.engine.streamable_weights_size + + def get_automatic_device_memory_budget(self) -> Any: + return self.engine.get_weight_streaming_automatic_budget() + + def get_device_memory_budget(self) -> Any: + return self.engine.weight_streaming_budget_v2 + + def set_device_memory_budget(self, budget_bytes: int) -> int: + # Recreating the context because weight streaming budget cannot be modified while there are active context. 
+ if self.context is not None: + del self.context + budget_bytes = self._set_device_memory_budget(budget_bytes) + self.context = self.engine.create_execution_context() + self.runtime_states.context_changed = True + return budget_bytes + + def _set_device_memory_budget(self, budget_bytes: int) -> int: + # Disable weight streaming for invalid budget size + if budget_bytes < 0: + budget_bytes = self.get_streamable_device_memory_budget() + self.engine.weight_streaming_budget_v2 = budget_bytes + if self.engine.weight_streaming_budget_v2 != budget_bytes: + logger.error(f"Failed to set weight streaming budget to {budget_bytes}") + budget_bytes = self.engine.weight_streaming_budget_v2 + if self.get_streamable_device_memory_budget() == budget_bytes: + logger.warning("Weight streaming is disabled") + + return budget_bytes + + def set_default_device_memory_budget(self) -> int: + budget_bytes = self.get_automatic_device_memory_budget() + # Set automatic weight streaming budget as default when context is created + logger.debug(f"Weight streaming budget set to {budget_bytes}B") + return self._set_device_memory_budget(budget_bytes) + + def setup_engine(self) -> None: + assert ( + self.target_platform == Platform.current_platform() + ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})" + + self.initialized = True + runtime = trt.Runtime(TRT_LOGGER) + self.engine = runtime.deserialize_cuda_engine(self.serialized_engine) + if self.settings.enable_weight_streaming: + self.set_default_device_memory_budget() + self.context = self.engine.create_execution_context() + assert self.engine.num_io_tensors == ( + len(self.input_names) + len(self.output_names) + ) + + self.input_dtypes = [ + dtype._from(self.engine.get_tensor_dtype(input_name)) + for input_name in self.input_names + ] + self.input_shapes = [ + self.engine.get_tensor_shape(input_name) for input_name in self.input_names + ] + self.output_dtypes = [ + dtype._from(self.engine.get_tensor_dtype(output_name)).to(torch.dtype) + for output_name in self.output_names + ] + self.output_shapes = [ + self.engine.get_tensor_shape(output_name) + for output_name in self.output_names + ] + + if self.requires_output_allocator: + self.create_output_allocator() + + if torch_tensorrt.runtime.get_cudagraphs_mode(): + self.cudagraph = torch.cuda.CUDAGraph() + + def _check_initialized(self) -> None: + if not self.initialized: + raise RuntimeError("PythonTorchTensorRTModule is not initialized.") + + def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None: + state_dict[prefix + "engine"] = self.serialized_engine + state_dict[prefix + "input_names"] = self.input_names + state_dict[prefix + "output_names"] = self.output_names + state_dict[prefix + "platform"] = self.target_platform + + def _load_from_state_dict( + self, + state_dict: Dict[str, Any], + prefix: str, + local_metadata: Any, + strict: Any, + missing_keys: Any, + unexpected_keys: Any, + error_msgs: Any, + ) -> None: + self.serialized_engine = state_dict[prefix + "engine"] + self.input_names = state_dict[prefix + "input_names"] + self.output_names = state_dict[prefix + "output_names"] + self.target_platform = state_dict[prefix + "platform"] + + # Run multi-gpu device check to validate engine instantiation + multi_gpu_device_check() + self.setup_engine() + + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + state.pop("engine", None) + state.pop("context", None) + return state + + def 
__setstate__(self, state: Dict[str, Any]) -> None: + self.__dict__.update(state) + self.setup_engine() + + def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule: + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + result.__setstate__(self.__getstate__()) + return result + + def _reset_captured_graph(self, inputs_shape_key: str = None) -> None: + if inputs_shape_key in self.shape_key_to_cudagraph: + self.shape_key_to_cudagraph[inputs_shape_key].reset() + self.shape_key_to_cudagraph.pop(inputs_shape_key) + + def __del__(self) -> None: + self._reset_captured_graph() + + def setup_input_tensors( + self, + contiguous_inputs: List[torch.Tensor], + cudagraphs_enabled: bool, + need_cudagraphs_record: bool, + inputs_shape_key: str = None, + ) -> None: + for i, input_name in enumerate(self.input_names): + if not contiguous_inputs[i].is_cuda: + logger.warning( + f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. " + "This tensor is being moved by the runtime but for performance considerations, " + "ensure your inputs are all on GPU and open an issue here " + "(https://github.com/pytorch/TensorRT/issues) if this warning persists." + ) + contiguous_inputs = ( + contiguous_inputs[:i] + + [contiguous_inputs[i].cuda()] + + contiguous_inputs[i + 1 :] + ) + + assert ( + contiguous_inputs[i].dtype == self.input_dtypes[i] + ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}." + + is_shape_tensor_input = self.engine.is_shape_inference_io(input_name) + if need_cudagraphs_record: + # If cudagraphs is enabled, this memory is reserved for future cudagraph runs + # Clone is required to avoid re-using user-provided GPU memory + if is_shape_tensor_input: + self._input_buffers[inputs_shape_key][i] = contiguous_inputs[i].cpu().clone() + else: + self._input_buffers[inputs_shape_key][i] = contiguous_inputs[i].clone() + + # For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers + # as per TensorRT requirements + if is_shape_tensor_input: + # Shape tensor inputs are casted to int64 explicitly + # Currently Torch CPU pointers are not working; numpy pointers are used instead + # to refer to underlying memory + inputs_cpu = contiguous_inputs[i].cpu().to(torch.int64) + inputs_cpu_numpy = contiguous_inputs[i].cpu().to(torch.int64).numpy().copy() + # if cudagraphs_enabled: + # self._input_buffers[inputs_shape_key][i].copy_(inputs_cpu) + # self.context.set_tensor_address(input_name, self._input_buffers[inputs_shape_key][i].numpy().copy().ctypes.data) + # else: + self.context.set_tensor_address(input_name, inputs_cpu_numpy.ctypes.data) + else: + self.context.set_input_shape( + input_name, tuple(contiguous_inputs[i].shape) + ) + if cudagraphs_enabled: + self._input_buffers[inputs_shape_key][i].copy_(contiguous_inputs[i]) + self.context.set_tensor_address( + input_name, self._input_buffers[inputs_shape_key][i].data_ptr() + ) + else: + self.context.set_tensor_address( + input_name, contiguous_inputs[i].data_ptr() + ) + + def create_output_tensors(self) -> List[torch.Tensor]: + # create output tensors + outputs: List[torch.Tensor] = [] + + for o, _ in enumerate(self.output_names): + output = torch.empty( + size=self.output_shapes[o], + dtype=self.output_dtypes[o], + device=torch.cuda.current_device(), + ) + outputs.append(output) + return outputs + + def set_pre_allocated_outputs(self, enable: bool) -> None: + self.use_pre_allocated_outputs = enable + + def 
set_use_output_allocator(self, enable: bool) -> None: + self.use_output_allocator_outputs = enable + + def create_output_allocator(self) -> None: + if self.output_allocator is None: + output_dtypes_dict = {} + for o, output_name in enumerate(self.output_names): + output_dtypes_dict[output_name] = self.output_dtypes[o] + self.output_allocator = DynamicOutputAllocator(output_dtypes_dict) + + def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: + + def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: + # print(f"**************** first key cache shape: {inputs[1].shape}") + shape_changed, inputs_shape_key = self.validate_input_shapes(inputs) + ( + need_cudagraphs_record, + can_use_pre_allocated_outputs, + need_cudagraphs_reset, + ) = self.runtime_states.set_runtime_states( + self.cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed + ) + + if need_cudagraphs_reset: + self._reset_captured_graph(inputs_shape_key) + + if need_cudagraphs_record: + self._input_buffers[inputs_shape_key] = [None] * len(self.input_names) + self._output_buffers[inputs_shape_key] = [None] * len(self.output_names) + + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:ProcessInputs" + ) + if self.profiling_enabled + else nullcontext() + ): + assert len(contiguous_inputs) == len( + self.input_names + ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}." + + self.setup_input_tensors( + contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record, inputs_shape_key + ) + + if shape_changed: + # Check if input shapes can be inferred. + uninferred_input_names = self.context.infer_shapes() + if uninferred_input_names: + logger.warning( + f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \ + This could happen if the input tensor addresses/shapes haven't been configured correctly" + ) + + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:ProcessOutputs" + ) + if self.profiling_enabled + else nullcontext() + ): + if can_use_pre_allocated_outputs: + outputs = self.pre_allocated_outputs + else: + self.output_shapes = [ + tuple(self.context.get_tensor_shape(output_name)) + for output_name in self.output_names + ] + if DYNAMIC_DIM in self.output_shapes: + raise ValueError( + "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported." 
+ ) + outputs = self.create_output_tensors() + + for o, output_name in enumerate(self.output_names): + if need_cudagraphs_record: + self._output_buffers[inputs_shape_key][o] = outputs[o].clone() + + if self.cudagraphs_enabled: + self.context.set_tensor_address( + output_name, self._output_buffers[inputs_shape_key][o].data_ptr() + ) + else: + self.context.set_tensor_address( + output_name, outputs[o].data_ptr() + ) + + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:TensorRTRuntime" + ) + if self.profiling_enabled + else nullcontext() + ): + self._caller_stream = torch.cuda.current_stream() + if ( + self._engine_stream == torch.cuda.default_stream() + or self._engine_stream is None + ): + self._engine_stream = torch.cuda.Stream() + + self._engine_stream.wait_stream(self._caller_stream) + + with torch.cuda.stream(self._engine_stream): + if self.cudagraphs_enabled: + if need_cudagraphs_record: + + self.shape_key_to_cudagraph[inputs_shape_key] = torch.cuda.CUDAGraph() + + if self.profiling_enabled: + self.shape_key_to_cudagraph[inputs_shape_key].enable_debug_mode() + + with torch.cuda.graph( + self.shape_key_to_cudagraph[inputs_shape_key], stream=self._engine_stream + ): + self.context.execute_async_v3( + self._engine_stream.cuda_stream + ) + + if self.profiling_enabled: + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + self.shape_key_to_cudagraph[inputs_shape_key].debug_dump( + f"{tempdir}/{self.name}_cudagraph.dot" + ) + + self.shape_key_to_cudagraph[inputs_shape_key].replay() # type: ignore + + else: + self.context.execute_async_v3(self._engine_stream.cuda_stream) + + self._caller_stream.wait_stream(self._engine_stream) + + if self.use_pre_allocated_outputs: + self.pre_allocated_outputs = self.create_output_tensors() + + if self.cudagraphs_enabled: + for idx, o in enumerate(outputs): + o.copy_(self._output_buffers[inputs_shape_key][idx]) + + if len(outputs) == 1: + return outputs[0] + + return outputs + + def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: + assert ( + not torch_tensorrt.runtime.get_cudagraphs_mode() + ), "CUDA Graphs are not compatible with OutputAllocator." + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:ProcessInputs" + ) + if self.profiling_enabled + else nullcontext() + ): + assert len(contiguous_inputs) == len( + self.input_names + ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}." 
+ + self.setup_input_tensors(contiguous_inputs, False, False) + + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:SetupOutputAllocator" + ) + if self.profiling_enabled + else nullcontext() + ): + self.create_output_allocator() + # need to set output allocator every run + for output_name in self.output_names: + if not self.context.set_output_allocator( + output_name, self.output_allocator + ): + raise RuntimeError( + f"Failed to set output allocator for {output_name}" + ) + + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:TensorRTRuntime" + ) + if self.profiling_enabled + else nullcontext() + ): + self._caller_stream = torch.cuda.current_stream() + if ( + self._engine_stream == torch.cuda.default_stream() + or self._engine_stream is None + ): + self._engine_stream = torch.cuda.Stream() + + self._engine_stream.wait_stream(self._caller_stream) + + with torch.cuda.stream(self._engine_stream): + self.context.execute_async_v3( + self._engine_stream.cuda_stream + ) # The OutputAllocator is called by execute_async_v3() + + self._caller_stream.wait_stream(self._engine_stream) + + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:ProcessOutputs" + ) + if self.profiling_enabled + else nullcontext() + ): + outputs = [] + assert self.output_allocator is not None + for o, output_name in enumerate(self.output_names): + shape = self.output_allocator.shapes.get(output_name, None) + dtype = self.output_dtypes[o] + output = ( + self.output_allocator.buffers.get(output_name, None) + .clone() + .detach() + ) + prod = int(torch.prod(torch.tensor(shape))) + # When using the OutputAllocator, the allocated buffer might be larger than the size of the output, + # so we need to reshape the buffer to the output shape + output = output.reshape(-1).view(dtype)[:prod].reshape(shape) + outputs.append(output) + + if len(outputs) == 1: + return outputs[0] + + return outputs + + self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() + + # Run forward function + contiguous_inputs: List[torch.Tensor] = [ + (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda()) + for i in inputs + ] + with ( + torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward") + if self.profiling_enabled + else nullcontext() + ): + self._check_initialized() + + # If in safe mode, check at each iteration for whether a switch is required + if ( + torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE + ): + curr_device_id = torch.cuda.current_device() + curr_device_properties = torch.cuda.get_device_properties( + curr_device_id + ) + logger.debug(f"Current Device: cuda:{curr_device_id}") + + # If a switch is required, move all inputs to new device and set as active device + if _is_switch_required( + curr_device_id, + self.target_device_id, + curr_device_properties, + self.target_device_properties, + ): + device_id, _ = _select_rt_device( + curr_device_id, + self.target_device_id, + self.target_device_properties, + ) + + # Update current device + device = torch.device(device_id) + torch.cuda.set_device(device_id) + + contiguous_inputs = [ + tensor.to(device) for tensor in contiguous_inputs + ] + logger.warning(f"Moved all input Tensors to cuda:{device_id}") + + if self.requires_output_allocator: # engine requires OA + if self.cudagraphs_enabled: + raise RuntimeError( + "The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. 
Please disable CUDA Graphs." + ) + logger.debug("Using the dynamic allocator runtime mode.") + return run_output_allocator() + else: + if self.use_output_allocator_outputs: # users call OA context manager + if self.cudagraphs_enabled: + raise RuntimeError( + "Both CUDA Graphs and dynamic output allocation are enabled, which are incompatible runtime modes. Please disable one of the two." + ) + logger.debug("Using the dynamic allocator runtime mode.") + return run_output_allocator() + else: + logger.debug( + f"Using the standard execution runtime mode with cudagraphs={self.cudagraphs_enabled}." + ) + return run_standard_execution() + + def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None: + """ + Enable TensorRT profiling. After calling this function, TensorRT will report + time spent on each layer in stdout for each forward run. + """ + self._check_initialized() + + if not self.context.profiler: + self.context.profiler = trt.Profiler() if profiler is None else profiler + + self.profiling_enabled = True + + def disable_profiling(self) -> None: + """ + Disable TensorRT profiling. + """ + self._check_initialized() + torch.cuda.synchronize() + del self.context + self.context = self.engine.create_execution_context() + self.profiling_enabled = False + + def get_layer_info(self) -> str: + """ + Get layer info of the engine. Only supported for TRT > 8.2. + """ + inspector = self.engine.create_engine_inspector() + engine_json: str = inspector.get_engine_information( + trt.LayerInformationFormat.JSON + ) + return engine_json + + def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> Tuple[bool, str]: + """ + Checks whether the input shapes of the forward call have changed, returning (shape_changed, shape_key) + """ + # Representation of input shapes to a given model + # Shapes are concatenated as so: + # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5) + tensor_inputs = [ + t if isinstance(t, torch.Tensor) else torch.tensor(t) + for t in inputs + ] + new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs) + + # If the new shape key differs from the existing one, + # invalidate the old shape key and remove the CUDAGraph + if new_shape_key not in self.shape_key_to_cudagraph: + logger.debug(f"The user provided input shape {new_shape_key} is not found in recorded CUDAGraph input shapes. A new CUDAGraph will be recorded with this input shape.") + # self.shape_key = new_shape_key + return True, new_shape_key + + return False, new_shape_key diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 6415ce11c3..fe4b781505 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -743,7 +743,11 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: # Representation of input shapes to a given model # Shapes are concatenated as so: # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5) - new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in inputs) + tensor_inputs = [ + t if isinstance(t, torch.Tensor) else torch.tensor(t) + for t in inputs + ] + new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs) # If the new shape key differs from the existing one, + # invalidate the old shape key and remove the CUDAGraph
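Note on the hunk above (illustration only, not part of the patch): both runtime modules now coerce non-tensor inputs to tensors before building the per-shape CUDA graph key, so that Python scalars contribute a stable "()" entry instead of failing on ".shape". A minimal sketch of that key construction, using a hypothetical make_shape_key helper and illustrative prefill/decode inputs:

import torch

def make_shape_key(inputs):
    # Mirror the runtime's key construction: wrap non-tensor inputs so that
    # every element exposes a .shape attribute.
    tensor_inputs = [
        t if isinstance(t, torch.Tensor) else torch.tensor(t) for t in inputs
    ]
    return "".join(str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs)

# Prefill-style call: a full prompt plus two integer position arguments
prefill_key = make_shape_key([torch.zeros(1, 8, dtype=torch.int64), 0, 8])
# Decode-style call: a single new token plus updated positions
decode_key = make_shape_key([torch.zeros(1, 1, dtype=torch.int64), 8, 9])

print(prefill_key)  # (1,8)()()
print(decode_key)   # (1,1)()()
# Distinct keys mean a separate CUDA graph is recorded per input-shape combination.
assert prefill_key != decode_key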