diff --git a/docsrc/index.rst b/docsrc/index.rst index 67fbdc56f5..4d28d77640 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -140,11 +140,10 @@ Model Zoo * :ref:`torch_compile_resnet` * :ref:`torch_compile_transformer` * :ref:`torch_compile_stable_diffusion` +* :ref:`compile_hf_models` * :ref:`torch_compile_gpt2` * :ref:`torch_export_gpt2` -* :ref:`torch_export_llama2` * :ref:`torch_export_sam2` -* :ref:`torch_export_flux_dev` * :ref:`notebooks` .. toctree:: @@ -155,11 +154,10 @@ Model Zoo tutorials/_rendered_examples/dynamo/torch_compile_resnet_example tutorials/_rendered_examples/dynamo/torch_compile_transformers_example tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion + tutorials/compile_hf_models tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2 tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion tutorials/_rendered_examples/dynamo/torch_compile_gpt2 - tutorials/_rendered_examples/dynamo/torch_export_gpt2 - tutorials/_rendered_examples/dynamo/torch_export_llama2 tutorials/_rendered_examples/dynamo/torch_export_sam2 tutorials/_rendered_examples/dynamo/torch_export_flux_dev tutorials/notebooks diff --git a/docsrc/tutorials/compile_hf_models.rst b/docsrc/tutorials/compile_hf_models.rst new file mode 100644 index 0000000000..f6da87b145 --- /dev/null +++ b/docsrc/tutorials/compile_hf_models.rst @@ -0,0 +1,218 @@ +.. _compile_hf_models: + +Compiling LLM models from Huggingface +====================================== + +This tutorial walks you through how to compile LLM models from Huggingface using Torch-TensorRT. We also introduce KV caching in Torch-TensorRT which can greatly improve the performance of LLM inference. +The code is available in the `tools/llm `_ directory. We use the ``run_llm.py`` script to compile the model, generate outputs, and measure the performance. + +.. note:: + This is an **experimental release** and APIs may change in future versions. + +.. note:: + The compilation scripts and tutorials for Llama-2-7b-chat-hf and gpt2 models have been consolidated into the unified ``run_llm.py`` script located in the `tools/llm `_ directory. + +Overview of tools/llm Directory +------------------------------- + +The ``tools/llm`` directory provides the following tools to compile LLM models from Huggingface: + +* **run_llm.py**: Main entry point for model compilation, generating outputs, and benchmarking +* **Static Cache Utilities**: ``static_cache_v1.py`` and ``static_cache_v2.py`` for KV cache optimization +* **SDPA Attention**: ``sdpa_converter.py`` and ``register_sdpa.py`` for registering scaled dot-product attention converter and lowering pass. +* **Testing Components**: Model-specific test files for validation +* **Utility Functions**: ``utils.py`` and ``cache_utils.py`` for common operations + +Supported Models +---------------- +We have officially verified support for the following LLM families: + +.. list-table:: + :widths: 20 40 20 20 + :header-rows: 1 + + * - Model Series + - HuggingFace Model Card + - Precision + - KV Cache Support ? + * - GPT-2 + - gpt2 + - FP16, FP32 + - Yes + * - LLaMA 2 + - meta-llama/Llama-2-7b-chat-hf + - FP16, FP32 + - Yes + * - LLaMA 3.1 + - meta-llama/Llama-3.1-8B-Instruct + - FP16, FP32 + - Yes + * - LLaMA 3.2 + - | meta-llama/Llama-3.2-1B-Instruct + | meta-llama/Llama-3.2-3B-Instruct + - FP16, FP32 + - Yes + * - Qwen 2.5 + - | Qwen/Qwen2.5-0.5B-Instruct + | Qwen/Qwen2.5-1.5B-Instruct + | Qwen/Qwen2.5-3B-Instruct + | Qwen/Qwen2.5-7B-Instruct + - FP16, FP32 + - Yes + +Getting Started with run_llm.py +------------------------------- + +The main entry point is ``run_llm.py``, which provides a complete workflow for model compilation and benchmarking. + +Basic Usage +^^^^^^^^^^^ + +.. code-block:: bash + + python tools/llm/run_llm.py \ + --model meta-llama/Llama-3.2-1B-Instruct \ + --prompt "What is parallel programming?" \ + --precision FP16 \ + --num_tokens 128 \ + --cache static_v2 \ + --benchmark + +Key Arguments +^^^^^^^^^^^^^ + +* ``--model``: Name or path of the HuggingFace LLM +* ``--tokenizer``: (Optional) Tokenizer name; defaults to model name +* ``--prompt``: Input prompt for text generation +* ``--precision``: Precision mode (``FP16``, ``FP32``) +* ``--num_tokens``: Number of output tokens to generate +* ``--cache``: KV cache type (``static_v1``, ``static_v2``, or empty for no KV caching) +* ``--benchmark``: Enable benchmarking mode for performance comparison +* ``--enable_pytorch_run``: Also run and compare PyTorch baseline + + +Other Usage Examples +^^^^^^^^^^^^^^^^^^^^ +.. code-block:: bash + + # Compare different models performance + python tools/llm/run_llm.py --model gpt2 --benchmark --enable_pytorch_run + python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --benchmark --enable_pytorch_run + + # Generate the outputs (disable benchmarking) by specifying the number of tokens to generate. Default = 128 + python tools/llm/run_llm.py --model gpt2 --prompt "What is parallel programming?" --num_tokens 128 + python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --num_tokens 128 + + # Test different caching approaches + python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --cache static_v1 + python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --cache static_v2 + + # Compare FP16 vs FP32 performance + python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP16 --benchmark + python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP32 --benchmark + + +KV Caching in Torch-TensorRT +--------------------------------- + +We provide two versions of static KV caching: `static_cache_v1 `_ and `static_cache_v2 `_. +In both implementations, we add static KV cache tensors as model inputs/outputs without storing them as external memory. +The length of KV cache = input sequence length + output sequence length (specified by ``--num_tokens``). The number of heads and head dimension are determined by the model config. + +Static Cache v1 +^^^^^^^^^^^^^^^^ + +The ``static_cache_v1.py`` implements KV cache in the model graph as follows: + +.. code-block:: python + + class StaticCacheV1Model(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True): + # Concatenate new key/value pairs with existing cache + new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2) + new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2) + + # Compute attention using the updated cache + attn_output = torch._C._nn.scaled_dot_product_attention( + q, + new_key_cache[:, :, :end_idx, :], + new_value_cache[:, :, :end_idx, :], + dropout_p=0.0, + is_causal=is_causal + ) + + return attn_output, new_key_cache, new_value_cache + +In the above code, we concatenate the new key/value pairs with the existing cache and update it. To compute the attention, we use the updated cache and gather the corresponding keys/values from the cache up until and including the current token index. +The above code is actually implemented as a FX graph transformation pass. We register it as a Torch-TensorRT lowering pass using the decorator ``@_aten_lowering_pass`` when we import the ``static_cache_v1.py`` module. + +.. note:: + The ``start_idx`` and ``end_idx`` are the start and end indices of the current token in the cache. For prefill phase, ``start_idx`` is 0 and ``end_idx`` is the input sequence length. + For decode phase, ``start_idx`` begins at the input sequence length and ``end_idx`` equals ``start_idx + 1``. The ``start_idx`` is incremented by 1 until the end of the sequence or we reach the maximum number of tokens to generate. + + +Static Cache v2 +^^^^^^^^^^^^^^^^ + +The ``static_cache_v2.py`` is similar to ``static_cache_v1.py`` but it uses less number of slice operations. It implements KV cache in the model graph as follows: + +.. code-block:: python + + class StaticCacheV2Model(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True): + concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2) + concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2) + new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2) + new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2) + attn_output = torch._C._nn.scaled_dot_product_attention( + q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal + ) + + return attn_output, new_key_cache, new_value_cache + +In the above code, we concatenate the existing key/value cache with current key/value of the token. We use this to directly compute the attention and update the key/value cache inserting the current key/value. +The above code is actually implemented as a FX graph transformation pass. We register it as a Torch-TensorRT lowering pass using the decorator ``@_aten_lowering_pass`` when we import the ``static_cache_v1.py`` module. +The definitons of ``start_idx`` and ``end_idx`` are the same as ``static_cache_v1.py``. + +After the model is compiled with static KV cache, the input signature of the model is changed. The new input signature is ``(input_ids, position_ids, key_cache_0, value_cache_0, ..., start_idx, end_idx)``. +The number of key/value cache tensors is equal to the number of attention heads in the model. We can use the ``generate_with_static_cache`` function to generate the outputs. + +Generating Outputs +------------------- +We use custom `generate `_ function to generate the outputs. This function performs standard autoregressive decoding without KV caching. +There is also a `generate_with_static_cache `_ function that performs autoregressive decoding with KV caching. + +The ``generate_with_static_cache`` function takes care of preparing the inputs to the model compiled with static KV cache. +The model inputs are ``input_ids``, ``position_ids``, ``key_cache_0``, ``value_cache_0``, ...., ``start_idx``, ``end_idx``. +We initialize the key/value cache tensors with zeros and for every token generated, the new key/value cache tensors are the outputs of the model. + +SDPA Converter (sdpa_converter.py) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* Converts scaled dot-product attention operation using TRT Python API. +* Supports causal and standard self-attention. + +SDPA Registration (register_sdpa.py) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* This is a Torch-TensorRT lowering pass that replaces variants of SDPA with ``torch.nn.functional.scaled_dot_product_attention``. +* Registers the SDPA converter which is used for converting ``torch.nn.functional.scaled_dot_product_attention`` operation. + + +Limitations and Known Issues +---------------------------- + +* Sliding window attention (used in Gemma3 and Qwen 3 models) is not yet supported +* Some model architectures (e.g. Phi-4) have issues with exporting the torch model. + +Requirements +^^^^^^^^^^^^ + +* Torch-TensorRT 2.8.0 or later +* Transformers v4.52.3 \ No newline at end of file diff --git a/examples/dynamo/aot_plugin.py b/examples/dynamo/aot_plugin.py index 7e8204c165..4aa49e4eca 100644 --- a/examples/dynamo/aot_plugin.py +++ b/examples/dynamo/aot_plugin.py @@ -1,3 +1,12 @@ +""" +.. _aot_plugin: + +AOT Plugin +========== + +This example demonstrates how to use an AOT plugin in Torch-TensorRT. +""" + import argparse from typing import Tuple, Union diff --git a/examples/dynamo/torch_export_gpt2.py b/examples/dynamo/torch_export_gpt2.py deleted file mode 100644 index 4d34c58de4..0000000000 --- a/examples/dynamo/torch_export_gpt2.py +++ /dev/null @@ -1,98 +0,0 @@ -""" -.. _torch_export_gpt2: - -Compiling GPT2 using the dynamo backend -========================================================== - -This script illustrates Torch-TensorRT workflow with dynamo backend on popular GPT2 model. -""" - -# %% -# Imports and Model Definition -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -import torch -import torch_tensorrt -from transformers import AutoModelForCausalLM, AutoTokenizer -from utils import export_llm, generate - -# %% - -# Define the parameters and initialize the model -MAX_TOKENS = 32 -DEVICE = torch.device("cuda:0") - -# Define the GPT2 model from hugging face -# kv_cache is not supported in Torch-TRT currently. -# CPU is used here so that GPU memory is reserved for TRT compilation. -with torch.no_grad(): - tokenizer = AutoTokenizer.from_pretrained("gpt2") - model = ( - AutoModelForCausalLM.from_pretrained( - "gpt2", - pad_token_id=tokenizer.eos_token_id, - use_cache=False, - attn_implementation="eager", - ) - .eval() - .half() - ) - -# %% -# Tokenize a sample input prompt and get pytorch model outputs -prompt = "I enjoy walking with my cute dog" -model_inputs = tokenizer(prompt, return_tensors="pt") -input_ids = model_inputs["input_ids"] - -# Auto-regressive generation loop for greedy decoding using PyTorch model -# We use a custom generate function which is very similar to the huggingface one. -pyt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id) - - -# %% -# Compilation with `Torch-TensorRT` using dynamo backend and generate TensorRT outputs -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -# Export the GPT2 model into an ExportedProgram which is input of TRT compilation -# To compile the model in FP16, we do the following -# 1) Cast the model to FP16 via model.half() -# 2) Enable use_explicit_typing=True. Certain layers are explicitly casted to FP32 within the pytorch model and this flag respects this behavior during TRT compilation -# 3) Enable use_fp32_acc=True. This ensures all the matmuls are accumulated in FP32 precision (similar to PyTorch) -gpt2_ep = export_llm(model, input_ids, max_seq_len=1024) -trt_model = torch_tensorrt.dynamo.compile( - gpt2_ep, - inputs=[input_ids], - enabled_precisions={torch.float32}, - truncate_double=True, - device=DEVICE, - disable_tf32=True, - use_explicit_typing=True, - use_fp32_acc=True, -) - -# Auto-regressive generation loop for greedy decoding using TensorRT model -# We use a custom generate function which is very similar to the huggingface one. -# Move inputs to GPU -input_ids = input_ids.to(DEVICE) -trt_gen_tokens = generate(trt_model, input_ids, MAX_TOKENS, tokenizer.eos_token_id) - -# %% -# Decode the output sentences of PyTorch and TensorRT -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -print("=============================") -print( - "Pytorch model generated text: ", - tokenizer.decode(pyt_gen_tokens[0], skip_special_tokens=True), -) -print("=============================") -print( - "TensorRT model generated text: ", - tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True), -) - -# Prompt : What is parallel programming ? - -# ============================= -# Pytorch model generated text: The parallel programming paradigm is a set of programming languages that are designed to be used in parallel. The main difference between parallel programming and parallel programming is that - -# ============================= -# TensorRT model generated text: The parallel programming paradigm is a set of programming languages that are designed to be used in parallel. The main difference between parallel programming and parallel programming is that diff --git a/examples/dynamo/torch_export_llama2.py b/examples/dynamo/torch_export_llama2.py deleted file mode 100644 index 2f3e3cba43..0000000000 --- a/examples/dynamo/torch_export_llama2.py +++ /dev/null @@ -1,102 +0,0 @@ -""" -.. _torch_export_llama2: - -Compiling Llama2 using the dynamo backend -========================================================== - -This script illustrates Torch-TensorRT workflow with dynamo backend on popular Llama2 model. -""" - -# %% -# Imports and Model Definition -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -import torch -import torch_tensorrt -from transformers import AutoModelForCausalLM, AutoTokenizer -from utils import export_llm, generate - -# %% -# Define the parameters and initialize the model -MAX_TOKENS = 32 -DEVICE = torch.device("cuda:0") - -# Define the Llama2 model from hugging face -# kv_cache is not supported in Torch-TRT currently. -# CPU is used here so that GPU memory is reserved for TRT compilation. -llama_path = "meta-llama/Llama-2-7b-chat-hf" -with torch.no_grad(): - model = ( - AutoModelForCausalLM.from_pretrained( - llama_path, use_cache=False, attn_implementation="eager" - ) - .eval() - .half() - ) - -tokenizer = AutoTokenizer.from_pretrained(llama_path) - -# %% -# Tokenize a sample input prompt and get pytorch model outputs -prompt = "What is dynamic programming?" -model_inputs = tokenizer(prompt, return_tensors="pt") -input_ids = model_inputs.input_ids - -# Auto-regressive generation loop for greedy decoding using PyTorch model -# We use a custom generate function which is very similar to the huggingface one. -pyt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id) - -# %% -# Compilation with `Torch-TensorRT` using dynamo backend and generate TensorRT outputs -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -# Export the llama2 model into an ExportedProgram which is input of TRT compilation -# To compile the model in FP16, we do the following -# 1) Cast the model to FP16 via model.half() -# 2) Enable use_explicit_typing=True. Certain layers are explicitly casted to FP32 within the pytorch model and this flag respects this behavior during TRT compilation -# 3) Enable use_fp32_acc=True. This ensures all the matmuls are accumulated in FP32 precision (similar to PyTorch) -llama2_ep = export_llm(model, input_ids, max_seq_len=64) -trt_model = torch_tensorrt.dynamo.compile( - llama2_ep, - inputs=[input_ids], - enabled_precisions={torch.float32}, - truncate_double=True, - device=DEVICE, - disable_tf32=True, - use_explicit_typing=True, - use_fp32_acc=True, -) - -# Auto-regressive generation loop for greedy decoding using TensorRT model -# We use a custom generate function which is very similar to the huggingface one. -# Move inputs to GPU -input_ids = input_ids.to(DEVICE) -trt_gen_tokens = generate(trt_model, input_ids, MAX_TOKENS, tokenizer.eos_token_id) - -# %% -# Decode the output sentences of PyTorch and TensorRT -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -print("=============================") -print( - "Pytorch model generated text: ", - tokenizer.batch_decode( - pyt_gen_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False - )[0], -) -print("=============================") -print( - "TensorRT model generated text: ", - tokenizer.batch_decode( - trt_gen_tokens, - skip_special_tokens=True, - clean_up_tokenization_spaces=False, - )[0], -) - - -# Prompt : What is dynamic programming? - -# ============================= -# Pytorch model generated text: Dynamic programming is an algorithmic technique used to solve complex problems by breaking them down into smaller subproblems, solving each subproblem only once, and - -# ============================= -# TensorRT model generated text: Dynamic programming is an algorithmic technique used to solve complex problems by breaking them down into smaller subproblems, solving each subproblem only once, and diff --git a/examples/dynamo/utils.py b/examples/dynamo/utils.py deleted file mode 100644 index 25ad99c12d..0000000000 --- a/examples/dynamo/utils.py +++ /dev/null @@ -1,63 +0,0 @@ -import torch -from transformers import StoppingCriteriaList -from transformers.generation.stopping_criteria import ( - EosTokenCriteria, - MaxLengthCriteria, -) - - -def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): - """ - Exports the LLM model into an ExportedProgram with dynamic shapes. - In the case of guard failures due to some PyTorch kernel implements, we also - try to re-export the graph by expressing them as runtime assert nodes - """ - with torch.no_grad(): - # max=1024 has contraint violation error. https://github.com/pytorch/pytorch/issues/125604 - seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len) - try: - print("Trying to export the model using torch.export.export()..") - # strict=False only enables aotautograd tracing and excludes dynamo. - ep = torch.export.export( - model, (inputs,), dynamic_shapes=({1: seq_len},), strict=False - ) - except: - print( - "Trying torch.export._trace._export to trace the graph since torch.export.export() failed" - ) - # This API is used to express the constraint violation guards as asserts in the graph. - ep = torch.export._trace._export( - model, - (inputs,), - dynamic_shapes=({1: seq_len},), - strict=False, - allow_complex_guards_as_runtime_asserts=True, - ) - - return ep - - -def generate(model, input_seq, max_tokens, eos_token_id): - """ - Greedy decoding of the model. This generates up to max_tokens. - """ - # Max length of output seq = current input_seq length + max_tokens allowed to generate - max_output_seq_length = input_seq.shape[1] + max_tokens - stopping_criteria = StoppingCriteriaList( - [ - MaxLengthCriteria(max_length=max_output_seq_length), - EosTokenCriteria(eos_token_id=eos_token_id), - ] - ) - - while True: - outputs = model(input_seq) - logits = outputs.logits - next_token_logits = logits[:, -1, :] - next_tokens = torch.argmax(next_token_logits, dim=-1) - input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1) - # TODO: Handle batch in this check - if stopping_criteria(input_seq, logits).item(): - break - - return input_seq diff --git a/examples/dynamo/weight_streaming_example.py b/examples/dynamo/weight_streaming_example.py index e1076a9e75..601292ba95 100644 --- a/examples/dynamo/weight_streaming_example.py +++ b/examples/dynamo/weight_streaming_example.py @@ -32,7 +32,43 @@ import torch import torch_tensorrt from transformers import AutoModelForCausalLM -from utils import export_llm + + +def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): + """ + Exports the LLM model into an ExportedProgram with dynamic shapes. + In the case of guard failures due to some PyTorch kernel implements, we also + try to re-export the graph by expressing them as runtime assert nodes + """ + with torch.no_grad(): + # max=1024 has contraint violation error. https://github.com/pytorch/pytorch/issues/125604 + seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len) + position_ids = torch.arange(inputs.shape[1]).unsqueeze(0).to(inputs.device) + try: + print("Trying to export the model using torch.export.export()..") + # strict=False only enables aotautograd tracing and excludes dynamo. + ep = torch.export.export( + model, + args=(inputs,), + kwargs={"position_ids": position_ids}, + dynamic_shapes=({1: seq_len}, {1: seq_len}), + strict=False, + ) + except: + print( + "Trying torch.export._trace._export to trace the graph since torch.export.export() failed" + ) + # This API is used to express the constraint violation guards as asserts in the graph. + ep = torch.export._trace._export( + model, + args=(inputs,), + kwargs={"position_ids": position_ids}, + dynamic_shapes=({1: seq_len}, {1: seq_len}), + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + + return ep def time_generate(model, inputs, output_seq_length, iterations=10): diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 830faf3373..22ecbce9fc 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -790,6 +790,28 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: "Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments." ) + # Store the original input spec for later use + original_in_spec = getattr(gm, "_in_spec", None) + original_out_spec = getattr(gm, "_out_spec", None) + + # Function to preserve and restore module specs + def preserve_module_specs( + in_spec: Any, out_spec: Any, target_module: torch.fx.GraphModule + ) -> None: + """ + Applies input and output specs to the target module. + + Args: + in_spec: The input spec to apply + out_spec: The output spec to apply + target_module: The module to apply specs to + """ + # Apply specs to target module + if in_spec is not None: + target_module._in_spec = in_spec + if out_spec is not None: + target_module._out_spec = out_spec + # Partition module into components that can be TRT-accelerated fast_partitioner_failed = False # If specified, try using the fast partitioner and fall back to the global one on failure @@ -837,6 +859,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: continue submodule_node_dict[node.name] = node + preserve_module_specs(original_in_spec, original_out_spec, partitioned_module) # Store TRT replicas of Torch subgraphs trt_modules = {} # Iterate over all components that can be accelerated diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index e542f1d417..f243d091a4 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1935,6 +1935,7 @@ def aten_ops_minimum( ) +@dynamo_tensorrt_converter(operator.sub, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar, supports_dynamic_shapes=True) def aten_ops_sub( diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 6415ce11c3..84711f154e 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -743,7 +743,14 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: # Representation of input shapes to a given model # Shapes are concatenated as so: # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5) - new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in inputs) + tensor_inputs = [] + for t in inputs: + if not isinstance(t, torch.Tensor): + return True + tensor_inputs.append(t) + new_shape_key = "".join( + str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs + ) # If the new shape key differs from the existing one, # invalidate the old shape key and remove the CUDAGraph diff --git a/tools/llm/README.md b/tools/llm/README.md new file mode 100644 index 0000000000..3fd55bc060 --- /dev/null +++ b/tools/llm/README.md @@ -0,0 +1,66 @@ +# Optimizing LLMs in Torch-TensorRT + +This directory provides utilities and scripts for compiling, optimizing, and benchmarking Large Language Models (LLMs) using Torch-TensorRT, with a focus on efficient inference on NVIDIA GPUs. The main entry point is `run_llm.py`, which demonstrates how to export, compile, and run LLMs with various caching strategies and precision modes. Note that this is an **experimental release** and APIs may change in future versions. + +### Key Features + +- **Model Support:** Works with popular LLMs such as Llama-3, Qwen2.5, etc. +- **Precision Modes:** Supports FP16, BF16, and FP32. +- **KV Cache:** Supports static and dynamic KV cache for efficient autoregressive decoding. +- **Benchmarking:** Measures and compares throughput and latency for PyTorch and TensorRT backends. +- **Custom Attention:** Registers and converts custom scaled dot-product attention (SDPA) for compatibility with TensorRT. + + +### Supported Models + +We have officially verified support for the following models: + +| Model Series | HF Model Card | Precision | KV Cache Supported ? | +|--------------|---------------|-----------|-------------------| +| GPT-2 | gpt2
gpt2-medium | FP16, FP32 | Yes | +| LLaMA 2 | meta-llama/Llama-2-7b-chat-hf | FP16, FP32 | Yes | +| LLaMA 3.1 | meta-llama/Llama-3.1-8B-Instruct | FP16, FP32 | Yes | +| LLaMA 3.2 | meta-llama/Llama-3.2-1B-Instruct
meta-llama/Llama-3.2-3B-Instruct | FP16, FP32 | Yes | +| Qwen 2.5 | Qwen/Qwen2.5-0.5B-Instruct
Qwen/Qwen2.5-1.5B-Instruct
Qwen/Qwen2.5-4B-Instruct
Qwen/Qwen2.5-7B-Instruct | FP16, FP32 | Yes | + + +### Usage + +The main entry point is : `run_llm.py` + +```bash +python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --precision FP16 --num_tokens 128 --cache static_v2 --benchmark +``` + +#### Key Arguments + +- `--model`: Name or path of the HuggingFace LLM. +- `--tokenizer`: (Optional) Tokenizer name; defaults to model. +- `--prompt`: Input prompt for generation. +- `--precision`: Precision mode (`FP16`, `FP32`). +- `--num_tokens`: Number of output tokens to generate. +- `--cache`: KV cache type (`static_v1`, `static_v2`, or empty for no KV caching). +- `--benchmark`: Enable benchmarking mode. +- `--enable_pytorch_run`: Also run and compare PyTorch baseline. + +### Caching Strategies + +- **Static Cache v1/v2:** Adds static KV cache tensors as model inputs/outputs for efficient reuse. +- **No Cache:** Standard autoregressive decoding. + +Please read our tutorial on how static cache is implemented. + +## Extension + +This codebase can be extended to +- Add new models by specifying their HuggingFace name. +- Implement new cache strategies by adding FX graph passes. +- Customize SDPA conversion for new attention mechanisms. + +## Limitations +- We do not currently support sliding window attention (used in Gemma3 and Qwen 3 models) yet. + +## Requirements + +- Torch-TensorRT 2.8.0 +- Transformers v4.52.3 \ No newline at end of file diff --git a/tools/llm/cache_utils.py b/tools/llm/cache_utils.py new file mode 100644 index 0000000000..7089d9a220 --- /dev/null +++ b/tools/llm/cache_utils.py @@ -0,0 +1,202 @@ +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import tensorrt +import torch +import torch_tensorrt +from torch._export.utils import _detect_fake_mode_from_gm +from torch._ops import OpOverloadPacket +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.fx import Graph, GraphModule, Node +from torch.fx.node import Target +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.utils._pytree import _LEAF_SPEC + + +@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter( + torch.ops.higher_order.cond, enabled=True, supports_dynamic_shapes=True +) +def cond_converter( + ctx: torch_tensorrt.dynamo.conversion.ConversionContext, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: str, +) -> Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: + """ + Converter for torch.ops.higher_order.cond operation to TensorRT. + + This function handles the conversion of PyTorch's conditional operation to TensorRT. + The conditional operation selects between two tensors based on a boolean predicate. + + Args: + ctx (torch_tensorrt.dynamo.conversion.ConversionCtx): The conversion context + target (Target): The target operation to convert + args (Tuple[Argument, ...]): The arguments to the operation + kwargs (Dict[str, Argument]): The keyword arguments to the operation + name (str): The name to give to the TensorRT layer + + Returns: + Union[tensorrt.ITensor, Sequence[tensorrt.ITensor]]: The converted TensorRT tensor(s) + """ + if_layer = ctx.net.add_if_conditional() + condition, true_branch, false_branch = args[0], args[1], args[2] + if_layer.set_condition(condition) + output_layer = if_layer.add_output(true_branch, false_branch) + output = output_layer.get_output(0) + + return output + + +def get_kv_nodes(gm): + """ + Get the key and value nodes from the graph. + """ + kv_nodes = [] + for node in gm.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch._C._nn.scaled_dot_product_attention + ): + q_node, k_node, v_node = node.args[:3] + kv_nodes.append((k_node, v_node)) + return kv_nodes + + +def get_random_tensor_from_node(node: Node) -> torch.Tensor: + """ + Creates a random tensor based on the shape information in a node's metadata. + For symbolic dimensions, extracts the maximum value from the shape environment. + + Args: + node: A torch.fx.Node object with metadata containing tensor information + + Returns: + A random tensor with shape matching the node's metadata, or None if no valid + tensor information is found + """ + if "val" not in node.meta: + raise ValueError( + f"No tensor information found in node metadata for node: {node}" + ) + + fake_tensor = node.meta["val"] + shape = [] + + # Iterate through each dimension and handle symbolic dimensions + for dim in fake_tensor.shape: + if isinstance(dim, torch.SymInt): + # Extract the maximum value from the shape environment + max_val = dim.node.hint + shape.append(max_val) + else: + shape.append(dim) + + # Create a random tensor with the determined shape + dtype = fake_tensor.dtype + device = fake_tensor.device + random_tensor = torch.rand(shape, dtype=dtype, device=device) + + return random_tensor + + +def create_random_output_tensors(nodes: List[Node]) -> List[torch.Tensor]: + """ + Creates random tensors based on the shape information in node metadata. + For symbolic dimensions, extracts the maximum value from the shape environment. + + Args: + nodes: List of torch.fx.Node objects with metadata + + Returns: + List of random tensors with shapes matching the nodes' metadata + """ + random_tensors = [] + + for node in nodes: + if isinstance(node, Node): + node_tensor = get_random_tensor_from_node(node) + elif isinstance(node, tuple): + node_tensor_list = [] + for n in node: + random_tensor = get_random_tensor_from_node(n) + node_tensor_list.append(random_tensor) + node_tensor = tuple(node_tensor_list) + + random_tensors.append(node_tensor) + + return random_tensors + + +def add_graph_input( + gm: GraphModule, name: str, val: Optional[torch.Tensor] = None, dynamic_shape=None +) -> Node: + """Add a graph input to the given GraphModule and return the newly created node. + + NOTE: function does NOT do any graph canonicalization. This is left to the user! + + Args: + gm (GraphModule): The GraphModule to add the input to. + name (str): The name of the input. + val (torch.Tensor): An example tensor to use for the input. + dynamic_shape: The dynamic shape of the input tensor [NOT SUPPORTED YET] + """ + # check that no dynamic shape is provided... + if dynamic_shape: + raise NotImplementedError("Dynamic shape not supported for adding graph inputs") + + # extract graph and input spec + graph: Graph = gm.graph + + in_spec = graph._codegen.pytree_info.in_spec + in_spec_for_args = in_spec.children_specs[0] + orig_args = graph._codegen.pytree_info.orig_args + assert in_spec_for_args.type is tuple + + # insert input node after currently last input node + node_last_input = graph.find_nodes(op="placeholder", sort=True)[-1] + with graph.inserting_after(node_last_input): + in_node = graph.placeholder(name) + in_spec_for_args.children_specs.append(_LEAF_SPEC) + orig_args.append(f"arg_{name}") + + # update pytree info recursively with __post_init__ starting at leaves + def call_post_init(spec): + for child_spec in spec.children_specs: + call_post_init(child_spec) + spec.__post_init__() + + call_post_init(in_spec) + + # set fake tensor information if all required information is available + fake_mode: Optional[FakeTensorMode] = _detect_fake_mode_from_gm(gm) + if fake_mode and val is not None and isinstance(val, torch.Tensor): + if isinstance(val, FakeTensor): + fake_tensor = val + else: + fake_tensor: FakeTensor = fake_mode.from_tensor(val, static_shapes=True) + in_node.meta["val"] = fake_tensor + in_node.meta["tensor_meta"] = _extract_tensor_metadata(fake_tensor) + + # return new node... + return in_node + + +def is_op(node: Node, ops: Union[OpOverloadPacket, Iterable[OpOverloadPacket]]) -> bool: + """Check if the node is a call to one of the ops.""" + if node.op != "call_function": + return False + # check if it's a single op that's provided + if isinstance(ops, OpOverloadPacket): + ops = [ops] + + # check if it's the op itself instead of an overload + if any(node.target == op for op in ops): + return True + + return False + + +def get_all_input_output_nodes(graph: Graph) -> Tuple[List[Node], List[Node]]: + input_nodes: List[Node] = graph.find_nodes(op="placeholder") + output_nodes: List[Node] = graph.find_nodes(op="output") + return (input_nodes, output_nodes) diff --git a/examples/dynamo/register_sdpa.py b/tools/llm/register_sdpa.py similarity index 86% rename from examples/dynamo/register_sdpa.py rename to tools/llm/register_sdpa.py index 7436f31939..c3c76e0f2d 100644 --- a/examples/dynamo/register_sdpa.py +++ b/tools/llm/register_sdpa.py @@ -19,11 +19,13 @@ # Remove decompositions for aten.scaled_dot_product_attention, aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention # This is because we want to have SDPA as a standalone operator in the graph and invoke the custom converter for it. -TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default) +TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default, None) TORCH_TRT_DECOMPOSITIONS.pop( - torch.ops.aten._scaled_dot_product_efficient_attention.default + torch.ops.aten._scaled_dot_product_efficient_attention.default, None +) +TORCH_TRT_DECOMPOSITIONS.pop( + torch.ops.aten._scaled_dot_product_flash_attention.default, None ) -TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten._scaled_dot_product_flash_attention.default) REPLACEABLE_ATEN_OPS = { torch.ops.aten._scaled_dot_product_efficient_attention.default, @@ -59,6 +61,7 @@ def replace_variants_of_sdpa( elif len(node.args) == 5: query, key, value, attn_mask, is_causal = node.args dropout_p = 0.0 + else: raise ValueError( f"Unexpected number of arguments for {node.target} in the graph" @@ -71,6 +74,8 @@ def replace_variants_of_sdpa( query, key, value, dropout_p, is_causal, return_debug_mask = ( node.args ) + if len(node.args) == 5: + query, key, value, dropout_p, is_causal = node.args elif len(node.args) == 3: query, key, value = node.args dropout_p = 0.0 @@ -79,20 +84,21 @@ def replace_variants_of_sdpa( raise ValueError( f"Unexpected number of arguments for {node.target} in the graph" ) - if attn_mask is not None: - logger.warning( - f"This current version of SDPA converter does not support attn_mask for {node.target} in the graph. Ignoring it and using is_causal=True configuration." - ) - - modified_input_args = (query, key, value, None, dropout_p, is_causal) + logger.warning( + f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations." + ) + modified_input_args = (query, key, value, None, dropout_p, True) # Create a new node with torch.nn.functional.scaled_dot_product_attention # The input args is (query, key, value, is_causal). kwargs has scale with gm.graph.inserting_after(node): new_node = gm.graph.call_function( torch.nn.functional.scaled_dot_product_attention, args=modified_input_args, - kwargs={"scale": node.kwargs.get("scale", None)}, + kwargs={ + "scale": node.kwargs.get("scale", None), + "use_fp32_acc": settings.use_fp32_acc, + }, ) # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead. @@ -113,7 +119,7 @@ def replace_variants_of_sdpa( # Clean up the graph clean_up_graph_after_modifications(gm) - logger.info( + logger.debug( "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention" ) return gm diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py new file mode 100644 index 0000000000..7a98d2a2c0 --- /dev/null +++ b/tools/llm/run_llm.py @@ -0,0 +1,343 @@ +""" +.. _torch_export_gpt2: + +Compiling GPT2 using the dynamo backend +========================================================== + +This script illustrates Torch-TensorRT workflow with dynamo backend on popular GPT2 model. +""" + +import argparse +import copy +import os +import timeit +from contextlib import nullcontext + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +import torch +import torch_tensorrt +from register_sdpa import * +from transformers import AutoModelForCausalLM, AutoTokenizer +from utils import ( + export_llm, + generate, + generate_with_dynamic_cache, + generate_with_static_cache, + recordStats, + time_generate, +) + +DEVICE = torch.device("cuda:0") + + +def get_model(args): + with torch.no_grad(): + # Supported list of models: + # - meta-llama/Llama-3.2-1B-Instruct + # - meta-llama/Llama-3.2-3B-Instruct + # - meta-llama/Llama-3.1-8B-Instruct + # - Qwen/Qwen2.5-1.5B-Instruct + model = ( + AutoModelForCausalLM.from_pretrained( + args.model, + use_cache=False, + attn_implementation="sdpa", + ) + .eval() + .cuda() + ) + if args.precision == "FP16": + model = model.to(torch.float16) + elif args.precision == "BF16": + model = model.to(torch.bfloat16) + else: + model = model.to(torch.float32) + + return model + + +def compile_torchtrt(model, input_ids, args): + max_seq_len = input_ids.shape[1] + args.num_tokens + ep = export_llm(model, input_ids, max_seq_len=max_seq_len) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[input_ids, position_ids], + enabled_precisions=enabled_precisions, + # truncate_double=True, + use_explicit_typing=use_explicit_typing, + use_fp32_acc=use_fp32_acc, + device=DEVICE, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + offload_module_to_cpu=True, + min_block_size=args.min_block_size, + ) + + return trt_model + + +def print_outputs(backend_name, gen_tokens, tokenizer): + print(f"========= {backend_name} =========") + print( + f"{backend_name} model generated text: ", + tokenizer.decode(gen_tokens[0], skip_special_tokens=True), + ) + print("===================================") + + +def measure_perf(trt_model, input_signature, backend_name): + # Measure average time for 10 iterations + import timeit + + import numpy as np + + total_time = 0 + iterations = 10 + + print("Running warmup iteration...") + # Warmup run + _ = trt_model(*input_signature) + torch.cuda.synchronize() + + print(f"Measuring performance over {iterations} iterations...") + for i in range(iterations): + start_time = timeit.default_timer() + _ = trt_model(*input_signature) + torch.cuda.synchronize() + end_time = timeit.default_timer() + iter_time = end_time - start_time + total_time += iter_time + # print(f"Iteration {i+1}: {iter_time:.4f} seconds") + + avg_time = total_time / iterations + print( + f"Backend: {backend_name} Average time per iteration: {avg_time*1000:.4f} milliseconds" + ) + print( + f"Backend: {backend_name} Average throughput: {1.0/avg_time:.2f} iterations/second" + ) + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run inference on a model with random input values" + ) + arg_parser.add_argument( + "--model", + type=str, + default="meta-llama/Llama-3.2-1B-Instruct", + help="Name of LLM model", + ) + arg_parser.add_argument( + "--tokenizer", + type=str, + default="", + help="Name of LLM model tokenizer", + ) + arg_parser.add_argument( + "--prompt", type=str, default="What is parallel programming ?", help="Prompt" + ) + arg_parser.add_argument( + "--precision", + type=str, + default="FP16", + help="Precision to use in the model. Options: FP16, BF16, FP32", + ) + arg_parser.add_argument( + "--iterations", type=int, default=5, help="no. of iterations to run" + ) + arg_parser.add_argument( + "--min_block_size", type=int, default=1, help="no. of iterations to run" + ) + arg_parser.add_argument( + "--num_tokens", + type=int, + default=128, + help="no. of output tokens to be generated", + ) + arg_parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size used for benchmarking" + ) + arg_parser.add_argument( + "--isl", + type=int, + default=2048, + help="Input sequence length used for benchmarking", + ) + arg_parser.add_argument( + "--enable_pytorch_run", + action="store_true", + help="Enable pytorch run (default: False)", + ) + arg_parser.add_argument( + "--cache", + type=str, + default="", + help="Type of KV cache to use. Options: static_v1, static_v2, dynamic", + ) + arg_parser.add_argument( + "--cudagraph", action="store_true", help="Enable cudagraphs (default: False)" + ) + arg_parser.add_argument( + "--debug", action="store_true", help="Enable debug (default: False)" + ) + arg_parser.add_argument( + "--benchmark", action="store_true", help="Enable benchmark (default: False)" + ) + + args = arg_parser.parse_args() + with torch.inference_mode(): + model = get_model(args) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model) + + # Prepare input for benchmarking or evaluation + if args.benchmark: + input_ids = torch.randint( + 1, 10000, (args.batch_size, args.isl), dtype=torch.int64 + ).to(model.device) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) + else: + model_inputs = tokenizer(args.prompt, return_tensors="pt") + input_ids = model_inputs["input_ids"].to(DEVICE) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) + + MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.num_tokens + # Pyt + pyt_gen_tokens = None + pyt_timings = None + pyt_stats = None + + if args.enable_pytorch_run: + pyt_gen_tokens = generate( + model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id + ) + if args.benchmark: + pyt_timings = time_generate( + generate, + model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + pyt_stats = recordStats( + "PyTorch", + pyt_timings, + args.precision, + batch_size=args.batch_size, + compile_time_s=None, + ) + + if args.cache == "static_v1": + # This import is required to register static v1 KV cache transformations as lowering passes + import static_cache_v1 + if args.cache == "static_v2": + # This import is required to register static v2 KV cache transformations as lowering passes + import static_cache_v2 + elif args.cache == "dynamic": + import dynamic_cache + + # Compile the model with Torch-TensorRT + trt_model = compile_torchtrt(model, input_ids, args) + + if args.cache == "static_v1" or args.cache == "static_v2": + if args.cudagraph: + # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. + # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) + torch_tensorrt.runtime.set_cudagraphs_mode(True) + + trt_gen_tokens = generate_with_static_cache( + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + ) + + if args.benchmark: + trt_timings = time_generate( + generate_with_static_cache, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + elif args.cache == "dynamic": + trt_gen_tokens = generate_with_dynamic_cache( + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + ) + if args.benchmark: + trt_timings = time_generate( + generate_with_dynamic_cache, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + else: + trt_gen_tokens = generate( + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + ) + if args.benchmark: + trt_timings = time_generate( + generate, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + + if args.benchmark: + trt_stats = recordStats( + "TensorRT", + trt_timings, + args.precision, + batch_size=args.batch_size, + compile_time_s=None, + ) + + if not args.benchmark: + if args.enable_pytorch_run: + print_outputs("PyTorch", pyt_gen_tokens, tokenizer) + + print_outputs("TensorRT", trt_gen_tokens, tokenizer) + + if args.enable_pytorch_run: + print( + f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}" + ) + + if args.benchmark: + if args.enable_pytorch_run: + print("=========PyTorch PERFORMANCE============ \n") + print(pyt_stats) + print("===================== \n") + print("=========TensorRT PERFORMANCE============ \n") + print(trt_stats) diff --git a/examples/dynamo/sdpa_converter.py b/tools/llm/sdpa_converter.py similarity index 51% rename from examples/dynamo/sdpa_converter.py rename to tools/llm/sdpa_converter.py index 903324dff5..47083c7b48 100644 --- a/examples/dynamo/sdpa_converter.py +++ b/tools/llm/sdpa_converter.py @@ -62,25 +62,15 @@ def scaled_dot_product_attention( ) -> TRTTensor: # TODO: Handle attn_mask and is_causal arguments in the future query, key, value, attn_mask, dropout_p, is_causal = args - logger.info( - "Ignoring attn_mask and is_causal arguments provided by the original graph. " - "This converter expects is_causal to be an input to the graph. For prefill phase, is_causal=True " - "and for generate phase, is_causal=False since we pass only 1 input token at a time" - ) # TODO: remove this once we have a better way to handle the causal mask scale = kwargs.get("scale", None) source_ir = SourceIR.ATEN + is_causal = True # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html - mm = impl.matmul.matrix_multiply( - ctx, - target, - source_ir, - name + "_mm", - query, - key, - other_matrix_op=trt.MatrixOperation.TRANSPOSE, - ) + use_fp32_acc = kwargs.get("use_fp32_acc", False) + query_dtype = query.dtype + if scale is None: scale = query.shape[-1] if scale < 0: @@ -90,80 +80,106 @@ def scaled_dot_product_attention( else: # static shape sqrt_scaled = math.sqrt(scale) - scaled = impl.elementwise.div( + key = impl.elementwise.div( ctx, target, source_ir, name + "_scale", - mm, + key, sqrt_scaled, ) else: - scaled = impl.elementwise.mul( + key = impl.elementwise.mul( ctx, target, source_ir, name + "_scale", - mm, + key, scale, ) - # If is_causal is True, we need to generate a causal mask - if is_causal: - L, S = query.shape[-2], key.shape[-2] - if L >= 0 and S >= 0: - # static shape - attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype)) - temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) - attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) - attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") - else: - # if any of the L or S is dynamic shape - if L < 0: - L = impl.shape.shape( - ctx, target, source_ir, name + "_shape_0", query, 2 - ) - if S < 0: - S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) - - # generate the mask tensor - tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) - - temp_mask = impl.unary.logical_not( - ctx, target, source_ir, name + "_logical_not", tril_tensor - ) - temp_mask_casted = cast_trt_tensor( - ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir - ) - one_minus_temp_mask = impl.elementwise.sub( - ctx, - target, - source_ir, - name + "_one_minus_temp_mask", - 1.0, - temp_mask_casted, - ) - attn_bias = impl.unary.log( - ctx, target, source_ir, name + "_log", one_minus_temp_mask - ) - - scaled_add_attn_bias = impl.elementwise.add( - ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias + if use_fp32_acc and query_dtype == trt.float16: + query = cast_trt_tensor( + ctx, query, trt.float32, name + "_query_cast_to_fp32", target, source_ir + ) + key = cast_trt_tensor( + ctx, key, trt.float32, name + "_key_cast_to_fp32", target, source_ir ) + + mm = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_mm", + query, + key, + other_matrix_op=trt.MatrixOperation.TRANSPOSE, + ) + + if use_fp32_acc: + mm = cast_trt_tensor( + ctx, mm, query_dtype, name + "_mm_cast_to_fp16", target, source_ir + ) + + L, S = query.shape[-2], key.shape[-2] + if L >= 0 and S >= 0: + # static shape + attn_bias = np.zeros((L, S), dtype=dtype._from(query_dtype).to(np.dtype)) + temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) + attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) + attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") else: - scaled_add_attn_bias = scaled + # if any of the L or S is dynamic shape + if L < 0: + L = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", query, 2) + if S < 0: + S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) - # Create a if condition to check if is_causal is True - if isinstance(is_causal, TRTTensor): - if_layer = ctx.net.add_if_conditional() - condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, scaled - if_layer.set_condition(condition) - output_layer = if_layer.add_output(true_branch, false_branch) - scaled_add_attn_bias = output_layer.get_output(0) + # generate the mask tensor + tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) + + temp_mask = impl.unary.logical_not( + ctx, target, source_ir, name + "_logical_not", tril_tensor + ) + + # This need_mask determines if we want to use the causal mask or not + # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. + # So need_mask will be all False values in this case. + # TODO: Implement more general case where L != 1 and S != L + need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S) + temp_mask = impl.elementwise.logical_and( + ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask + ) + temp_mask_casted = cast_trt_tensor( + ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir + ) + + one_minus_temp_mask = impl.elementwise.sub( + ctx, + target, + source_ir, + name + "_one_minus_temp_mask", + 1.0, + temp_mask_casted, + ) + attn_bias = impl.unary.log( + ctx, target, source_ir, name + "_log", one_minus_temp_mask + ) + + scaled_add_attn_bias = impl.elementwise.add( + ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias + ) softmax = impl.normalization.softmax( ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False ) + if use_fp32_acc: + softmax = cast_trt_tensor( + ctx, softmax, trt.float32, name + "_softmax_cast_to_fp32", target, source_ir + ) + value = cast_trt_tensor( + ctx, value, trt.float32, name + "_value_cast_to_fp32", target, source_ir + ) out = impl.matmul.matrix_multiply( ctx, target, @@ -172,5 +188,9 @@ def scaled_dot_product_attention( softmax, value, ) + if use_fp32_acc: + out = cast_trt_tensor( + ctx, out, query_dtype, name + "_out_cast_to_fp16", target, source_ir + ) return out diff --git a/tools/llm/static_cache_v1.py b/tools/llm/static_cache_v1.py new file mode 100644 index 0000000000..a87495953d --- /dev/null +++ b/tools/llm/static_cache_v1.py @@ -0,0 +1,277 @@ +import logging +from typing import List, Tuple + +import torch +import torch.utils._pytree as pytree +from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes +from torch.fx import Node +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + _aten_lowering_pass, +) +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) +from torch_tensorrt.dynamo.utils import extract_var_range_info + +logger = logging.getLogger(__name__) + +SDPA_OP = torch._C._nn.scaled_dot_product_attention + + +def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]): + """ + Modifies the graph to add query, key, and value tensors as outputs. + + This function identifies all scaled dot-product attention (SDPA) operations + in the graph, creates copies of their query, key, and value inputs, and adds + these copies to the graph's outputs. This allows for accessing these tensors + externally, which is useful for operations like key-value caching. + + Args: + graph: The torch.fx.Graph to modify + + Returns: + None. The graph is modified in-place. + """ + output_node = next(node for node in gm.graph.nodes if node.op == "output") + + # Get the current output args (typically a tuple) + current_outputs = output_node.args[0] + + # If the current output is a tuple, extend it with our new outputs + if isinstance(current_outputs, tuple): + new_outputs = current_outputs + tuple(kv_cache_for_graph) + else: + # If there's only one output or it's not a tuple, create a new tuple + new_outputs = (current_outputs,) + tuple(kv_cache_for_graph) + + gm.graph.output(new_outputs) + gm.graph.erase_node(output_node) + + return new_outputs + + +def add_kv_cache_inputs(gm, fixed_kv: bool = True): + """ + Add key-value tensors, index parameters as inputs to the graph. + + Args: + gm: The GraphModule to modify + fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True. + + Returns: + A tuple containing: + - List of (k_input, v_input) node pairs for each SDPA operation + - start_idx input node for slicing operations + - end_idx input node for slicing operations + """ + + def get_static_tensor(tensor: torch.Tensor): + key_shape = [] + for dim in tensor.shape: + if isinstance(dim, torch.SymInt): + min_max_opt = extract_var_range_info(dim) + key_shape.append(min_max_opt["max"]) + else: + key_shape.append(dim) + + static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) + return static_tensor + + keys_values = get_kv_nodes(gm) + + kv_inputs = [] + for idx, key_value in enumerate(keys_values): + k_val = key_value[0].meta["val"] + v_val = key_value[1].meta["val"] + if fixed_kv: + k_val = get_static_tensor(k_val) + v_val = get_static_tensor(v_val) + + # Add new inputs using add_graph_input + k_input = add_graph_input(gm, key_value[0].name + "_k_input", k_val) + v_input = add_graph_input(gm, key_value[1].name + "_v_input", v_val) + kv_inputs.append((k_input, v_input)) + + # Add start_idx and end_idx as inputs + start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0)) + end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1)) + + # Get the max sequence length from the first key_cache node. The order of nodes is: input_ids, is_causal, key_cache1, value_cache1, key_cache2, value_cache2, .. + input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] + input_ids_meta = input_nodes[0].meta["val"] + seq_len = input_ids_meta.shape[1] + min_max_opt = extract_var_range_info(seq_len) + max_seq_len = min_max_opt["max"] + + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + shape_env = ShapeEnv() + # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive + start_idx_unbacked_symint = shape_env.create_unbacked_symint() + torch._check(start_idx_unbacked_symint >= 0) + torch._check(start_idx_unbacked_symint <= max_seq_len) + + end_idx_unbacked_symint = shape_env.create_unbacked_symint() + torch._check(end_idx_unbacked_symint >= 0) + torch._check(end_idx_unbacked_symint <= max_seq_len) + # Set the symbolic ints as the metadata for start_idx and end_idx inputs + start_idx_input.meta["val"] = start_idx_unbacked_symint + end_idx_input.meta["val"] = end_idx_unbacked_symint + + return kv_inputs, start_idx_input, end_idx_input + + +def insert_kv_slicing_before_sdpa( + gm, + incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], + start_idx_input: Node, + end_idx_input: Node, +): + """ + Insert slicing operations before each scaled_dot_product_attention operation. + """ + # Find all nodes with scaled_dot_product_attention + sdpa_nodes = [] + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == SDPA_OP: + sdpa_nodes.append(node) + kv_cache_for_graph = [] + for idx, sdpa_node in enumerate(sdpa_nodes): + assert ( + len(sdpa_node.args) == 6 + ), f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments" + q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args + incoming_key, incoming_value = incoming_keys_values[idx] + kv_cache_for_sdpa_node = [] + new_keys_values = [] + for key_or_value, current_key_or_value_node in zip( + [incoming_key, incoming_value], [k_node, v_node] + ): + # Create a slice node for key_cache[:,:,:start_idx,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim + with gm.graph.inserting_before(sdpa_node): + slice_1 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(key_or_value,), + kwargs={}, + ) + slice_2 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_1, 1), + kwargs={}, + ) + slice_3 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_2, 2, None, start_idx_input), + kwargs={}, + ) + slice_4 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_3, 3), + kwargs={}, + ) + # =============================================== # + # Create a slice node for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim + slice_5 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(key_or_value,), + kwargs={}, + ) + slice_6 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_5, 1), + kwargs={}, + ) + slice_7 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_6, 2, end_idx_input), + kwargs={}, + ) + slice_8 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_7, 3), + kwargs={}, + ) + # =============================================== # + # Concatenate the sliced tensors to build KV cache + cat = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=([slice_4, current_key_or_value_node, slice_8], 2), + kwargs={}, + ) + # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph + cat.meta.update(key_or_value.meta) + kv_cache_for_sdpa_node.append(cat) + # =============================================== # + # Get the current key and value by indexing the KV cache + slice_9 = gm.graph.create_node( + "call_function", torch.ops.aten.slice.Tensor, args=(cat,), kwargs={} + ) + slice_10 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_9, 1), + kwargs={}, + ) + slice_11 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_10, 2, None, end_idx_input), + kwargs={}, + ) + slice_12 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_11, 3), + kwargs={}, + ) + new_keys_values.append(slice_12) + + kv_cache_for_graph.extend(kv_cache_for_sdpa_node) + + sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + ( + attn_mask, + dropout_p, + True, + ) + + return gm, kv_cache_for_graph + + +@_aten_lowering_pass +def insert_static_cache_v1( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Insert KV cache ops in the graph""" + """Perform insertion of kv-caches and attention kernel.""" + # Add static key and value as inputs to the graph + kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True) + + # Build and update the KV cache using computed KV inputs for current token and + # incoming keys and values from previous tokens (which were added as inputs) + gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa( + gm, kv_inputs, start_idx_input, end_idx_input + ) + + # Call the function to add KV as outputs + logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph) + + gm = clean_up_graph_after_modifications(gm) + + new_output_tensors = create_random_output_tensors(logits_keys_values) + + new_out_spec = pytree.tree_flatten(new_output_tensors)[1] + gm._out_spec = new_out_spec + logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) + + return gm diff --git a/tools/llm/static_cache_v2.py b/tools/llm/static_cache_v2.py new file mode 100644 index 0000000000..ad386d39f2 --- /dev/null +++ b/tools/llm/static_cache_v2.py @@ -0,0 +1,290 @@ +import logging +from typing import List, Tuple + +import torch +import torch.utils._pytree as pytree +from cache_utils import add_graph_input, create_random_output_tensors, get_kv_nodes +from torch.fx import Node +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + _aten_lowering_pass, +) +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) +from torch_tensorrt.dynamo.utils import extract_var_range_info + +logger = logging.getLogger(__name__) + +SDPA_OP = torch._C._nn.scaled_dot_product_attention + + +def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]): + """ + Modifies the graph to add query, key, and value tensors as outputs. + + This function identifies all scaled dot-product attention (SDPA) operations + in the graph, creates copies of their query, key, and value inputs, and adds + these copies to the graph's outputs. This allows for accessing these tensors + externally, which is useful for operations like key-value caching. + + Args: + graph: The torch.fx.Graph to modify + + Returns: + None. The graph is modified in-place. + """ + output_node = next(node for node in gm.graph.nodes if node.op == "output") + + # Get the current output args (typically a tuple) + current_outputs = output_node.args[0] + + # If the current output is a tuple, extend it with our new outputs + if isinstance(current_outputs, tuple): + new_outputs = current_outputs + tuple(kv_cache_for_graph) + else: + # If there's only one output or it's not a tuple, create a new tuple + new_outputs = (current_outputs,) + tuple(kv_cache_for_graph) + + gm.graph.output(new_outputs) + gm.graph.erase_node(output_node) + + return new_outputs + + +def add_kv_cache_inputs(gm, fixed_kv: bool = True): + """ + Add key-value tensors, index parameters as inputs to the graph. + + Args: + gm: The GraphModule to modify + fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True. + + Returns: + A tuple containing: + - List of (k_input, v_input) node pairs for each SDPA operation + - start_idx input node for slicing operations + - end_idx input node for slicing operations + """ + + def get_static_tensor(tensor: torch.Tensor): + key_shape = [] + for dim in tensor.shape: + if isinstance(dim, torch.SymInt): + min_max_opt = extract_var_range_info(dim) + key_shape.append(min_max_opt["max"]) + else: + key_shape.append(dim) + + static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device) + return static_tensor + + keys_values = get_kv_nodes(gm) + + kv_inputs = [] + for idx, key_value in enumerate(keys_values): + k_val = key_value[0].meta["val"] + v_val = key_value[1].meta["val"] + if fixed_kv: + k_val = get_static_tensor(k_val) + v_val = get_static_tensor(v_val) + + # Add new inputs using add_graph_input + k_input = add_graph_input(gm, key_value[0].name + "_k_input", k_val) + v_input = add_graph_input(gm, key_value[1].name + "_v_input", v_val) + kv_inputs.append((k_input, v_input)) + + # Add start_idx and end_idx as inputs + start_idx_input = add_graph_input(gm, "start_idx", torch.tensor(0)) + end_idx_input = add_graph_input(gm, "end_idx", torch.tensor(1)) + + # Get the max sequence length from the first key_cache node. The order of input nodes is: input_ids, key_cache1, value_cache1, key_cache2, value_cache2, start_idx, end_idx + input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] + # Get the third last input which should be the last value cache node and store the max_seq_len + input_ids_meta = input_nodes[-3].meta["val"] + seq_len = input_ids_meta.shape[2] + + if isinstance(seq_len, torch.SymInt): + min_max_opt = extract_var_range_info(seq_len) + max_seq_len = min_max_opt["max"] + else: + max_seq_len = seq_len + + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + shape_env = ShapeEnv() + # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive + start_idx_unbacked_symint = shape_env.create_unbacked_symint() + torch._check(start_idx_unbacked_symint >= 0) + torch._check(start_idx_unbacked_symint <= max_seq_len) + + end_idx_unbacked_symint = shape_env.create_unbacked_symint() + torch._check(end_idx_unbacked_symint >= 0) + torch._check(end_idx_unbacked_symint <= max_seq_len) + # Set the symbolic ints as the metadata for start_idx and end_idx inputs + start_idx_input.meta["val"] = start_idx_unbacked_symint + end_idx_input.meta["val"] = end_idx_unbacked_symint + + return kv_inputs, start_idx_input, end_idx_input + + +def create_kv_cache_update_nodes( + gm, sdpa_node, current_kv_node, incoming_kv_node, start_idx_input, end_idx_input +): + """ + Create slicing and concatenation nodes for KV cache update. + + This function creates the necessary slicing and concatenation nodes to update the KV cache + during the generation process. It takes the SDPA node, the current KV cache node, and the + incoming KV cache node as input. + Returns: + for a particular SDPA node, a tuple containing: + - List of new current KV nodes + - List of updated incoming KV cache nodes + + """ + + # Create a slice node for key_cache[:,:,:start_idx,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim + with gm.graph.inserting_before(sdpa_node): + slice_1 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(incoming_kv_node,), + kwargs={}, + ) + slice_2 = gm.graph.create_node( + "call_function", torch.ops.aten.slice.Tensor, args=(slice_1, 1), kwargs={} + ) + slice_3 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_2, 2, None, start_idx_input), + kwargs={}, + ) + slice_4 = gm.graph.create_node( + "call_function", torch.ops.aten.slice.Tensor, args=(slice_3, 3), kwargs={} + ) + # Concat key_cache[:,:,:start_idx,:] with current key (k) + concat_keys_or_values = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=([slice_4, current_kv_node], 2), + kwargs={}, + ) + + # =============================================== # + # Create nodes for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim + slice_5 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(incoming_kv_node,), + kwargs={}, + ) + slice_6 = gm.graph.create_node( + "call_function", torch.ops.aten.slice.Tensor, args=(slice_5, 1), kwargs={} + ) + slice_7 = gm.graph.create_node( + "call_function", + torch.ops.aten.slice.Tensor, + args=(slice_6, 2, end_idx_input), + kwargs={}, + ) + slice_8 = gm.graph.create_node( + "call_function", torch.ops.aten.slice.Tensor, args=(slice_7, 3), kwargs={} + ) + # =============================================== # + # Concatenate the sliced tensors to build KV cache + new_incoming_keys_or_values = gm.graph.create_node( + "call_function", + torch.ops.aten.cat.default, + args=([concat_keys_or_values, slice_8], 2), + kwargs={}, + ) + # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph + new_incoming_keys_or_values.meta.update(incoming_kv_node.meta) + + return concat_keys_or_values, new_incoming_keys_or_values + + +def insert_kv_slicing_before_sdpa( + gm, + incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], + start_idx_input: Node, + end_idx_input: Node, +): + """ + Insert slicing and concatenation operations before each scaled_dot_product_attention operation as per the following KV cache update logic: + concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2) + concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2) + new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2) + new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2) + out = torch._C._nn.scaled_dot_product_attention(q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal) + """ + # Find all nodes with scaled_dot_product_attention + sdpa_nodes = [] + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == SDPA_OP: + sdpa_nodes.append(node) + kv_cache_for_graph = [] + for idx, sdpa_node in enumerate(sdpa_nodes): + assert ( + len(sdpa_node.args) == 6 + ), f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments" + q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args + incoming_key, incoming_value = incoming_keys_values[idx] + # For keys + new_current_key_node, new_incoming_key_cache_node = ( + create_kv_cache_update_nodes( + gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input + ) + ) + # For values + new_current_value_node, new_incoming_value_cache_node = ( + create_kv_cache_update_nodes( + gm, sdpa_node, v_node, incoming_value, start_idx_input, end_idx_input + ) + ) + + # Store the KV cache nodes for the current SDPA node + kv_cache_for_graph.extend( + [new_incoming_key_cache_node, new_incoming_value_cache_node] + ) + + # Update the SDPA node arguments with current key and value nodes + sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + ( + attn_mask, + dropout_p, + True, + ) + + # kv_cache_for_graph.extend([k_node, v_node]) + return gm, kv_cache_for_graph + + +@_aten_lowering_pass +def insert_static_cache_v2( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Insert KV cache ops in the graph""" + """Perform insertion of kv-caches and attention kernel.""" + # Add static key and value as inputs to the graph + kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True) + + # Build and update the KV cache using computed KV inputs for current token and + # incoming keys and values from previous tokens (which were added as inputs) + gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa( + gm, kv_inputs, start_idx_input, end_idx_input + ) + + # Call the function to add KV as outputs + logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph) + + gm = clean_up_graph_after_modifications(gm) + + new_output_tensors = create_random_output_tensors(logits_keys_values) + + new_out_spec = pytree.tree_flatten(new_output_tensors)[1] + gm._out_spec = new_out_spec + + logger.debug("After inserting KV cache into the graph: " + str(gm.graph)) + return gm diff --git a/tools/llm/test_llama_components.py b/tools/llm/test_llama_components.py new file mode 100644 index 0000000000..ef7e59cd72 --- /dev/null +++ b/tools/llm/test_llama_components.py @@ -0,0 +1,603 @@ +import torch + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + +import argparse +import os +import sys +from contextlib import nullcontext + +import torch.nn as nn +import torch_tensorrt +from torch.testing._internal.common_utils import TestCase, run_tests +from transformers import AutoModelForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer + +# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from register_sdpa import * + +ATOL = 1e-5 +RTOL = 1e-5 + + +# llama2_model_name = "meta-llama/Llama-2-7b-hf" +llama3_model_name = "meta-llama/Llama-3.2-1B-Instruct" +llama_model = ( + AutoModelForCausalLM.from_pretrained( + llama3_model_name, + use_cache=False, + attn_implementation="sdpa", + num_hidden_layers=1, + ) + .eval() + .cuda() +) +LLAMA_CONFIG = llama_model.config + + +def test_llama_attention(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + # model = LlamaAttentionBlock().eval().cuda().to(DTYPE) + model = llama_model.model.layers[0].self_attn.to(DTYPE) + # llama3 + hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda() + position_embeddings = ( + torch.randn((1, 6, 64), dtype=DTYPE).cuda(), + torch.randn((1, 6, 64), dtype=DTYPE).cuda(), + ) + + pyt_output = model(hidden_states, position_embeddings, None) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + from torch.export._trace import _export + + # ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes, strict=False) + ep = _export( + model, + args=(hidden_states, position_embeddings, None), + dynamic_shapes=dynamic_shapes, + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[hidden_states, position_embeddings, None], + enabled_precisions=enabled_precisions, + disable_tf32=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + debug=args.debug, + ) + trt_output = trt_model(hidden_states, position_embeddings, None) + if isinstance(pyt_output, tuple): + print( + f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}" + ) + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + else: + print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}") + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + + +def print_diff(tensor1, tensor2, prefix=""): + """ + Print the diff between two tensors + """ + print( + f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}" + ) + + +def test_llama_attention_with_static_cache(args): + class LlamaAttentionBlock(nn.Module): + def __init__(self): + super().__init__() + self.config = LLAMA_CONFIG + self.attn = LlamaAttention(config=self.config, layer_idx=0) + + def forward(self, hidden_states, position_embeddings): + attn_output, attn_weights = self.attn( + hidden_states, position_embeddings, None + ) + return attn_output + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + model = llama_model.model.layers[0].self_attn.to(DTYPE) + + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + hidden_states = torch.randn((1, ISL, 2048), dtype=DTYPE).cuda() + position_embeddings = ( + torch.randn((1, ISL, 64), dtype=DTYPE).cuda(), + torch.randn((1, ISL, 64), dtype=DTYPE).cuda(), + ) + key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + start_idx = 0 + end_idx = ISL + is_causal = True + + pyt_output = model(hidden_states, position_embeddings, None) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + ep = torch.export.export( + model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes + ) + import static_cache_v2 + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[ + hidden_states, + position_embeddings, + None, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ], + enabled_precisions=enabled_precisions, + disable_tf32=True, + debug=args.debug, + # offload_module_to_cpu=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + use_python_runtime=True, + ) + + # Test Prefill + trt_output, _, key_cache, value_cache = trt_model( + hidden_states, + position_embeddings, + None, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ) + print_diff(pyt_output[0], trt_output[0], "pyt_output[0] vs trt_output[0] [Prefill]") + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + hidden_states_curr = torch.randn((1, 1, 2048), dtype=DTYPE).cuda() + position_embeddings_curr = ( + torch.randn((1, 1, 64), dtype=DTYPE).cuda(), + torch.randn((1, 1, 64), dtype=DTYPE).cuda(), + ) + # Concatenate the current hidden_states with the previous ones + hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1) + position_embeddings_full = ( + torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), + torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1), + ) + + is_causal = False + out_no_cache, _ = model(hidden_states_full, position_embeddings_full, None) + out_trt, _, key_cache, value_cache = trt_model( + hidden_states_curr, + position_embeddings_curr, + None, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ) + out_pyt = out_no_cache[:, -1:, :] + print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") + + hidden_states = hidden_states_full + position_embeddings = position_embeddings_full + + +def test_llama_decoder(args): + + class LlamaDecoderLayerBlock(nn.Module): + def __init__(self, model): + super().__init__() + self.config = LLAMA_CONFIG + self.decoder = LlamaDecoderLayer(config=self.config, layer_idx=0) + self.model = model + + def forward(self, hidden_states, position_embeddings): + return self.model(hidden_states, position_embeddings=position_embeddings) + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = LlamaDecoderLayerBlock(llama_model.model.layers[0].to(DTYPE)) + # llama3 + hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda() + position_embeddings = ( + torch.randn((1, 6, 64), dtype=DTYPE).cuda(), + torch.randn((1, 6, 64), dtype=DTYPE).cuda(), + ) + + pyt_output = model(hidden_states, position_embeddings) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len})) + ep = torch.export.export( + model, (hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[hidden_states, position_embeddings], + enabled_precisions=enabled_precisions, + debug=args.debug, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + ) + trt_output = trt_model(hidden_states, position_embeddings) + + print( + f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}" + ) + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + + +def test_llama_decoder_with_static_cache(args): + + class LlamaDecoderLayerBlock(nn.Module): + def __init__(self, model): + super().__init__() + self.config = LLAMA_CONFIG + self.decoder = LlamaDecoderLayer(config=self.config, layer_idx=0) + self.model = model + + def forward(self, hidden_states, position_embeddings): + return self.model(hidden_states, position_embeddings=position_embeddings) + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + model = LlamaDecoderLayerBlock(llama_model.model.layers[0].to(DTYPE)) + + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + hidden_states = torch.randn((1, ISL, 2048), dtype=DTYPE).cuda() + position_embeddings = ( + torch.randn((1, ISL, 64), dtype=DTYPE).cuda(), + torch.randn((1, ISL, 64), dtype=DTYPE).cuda(), + ) + key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + start_idx = 0 + end_idx = ISL + is_causal = True + + pyt_output = model(hidden_states, position_embeddings) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len})) + ep = torch.export.export( + model, args=(hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes + ) + import static_cache_v2 + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + arg_inputs=[ + hidden_states, + position_embeddings, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ], + enabled_precisions=enabled_precisions, + disable_tf32=True, + debug=args.debug, + # offload_module_to_cpu=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + use_python_runtime=True, + ) + + # Test Prefill + trt_output, key_cache, value_cache = trt_model( + hidden_states, + position_embeddings, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ) + print_diff(pyt_output[0], trt_output, "pyt_output vs trt_output [Prefill]") + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + hidden_states_curr = torch.randn((1, 1, 2048), dtype=DTYPE).cuda() + position_embeddings_curr = ( + torch.randn((1, 1, 64), dtype=DTYPE).cuda(), + torch.randn((1, 1, 64), dtype=DTYPE).cuda(), + ) + # Concatenate the current hidden_states with the previous ones + hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1) + position_embeddings_full = ( + torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), + torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1), + ) + + is_causal = False + out_no_cache = model(hidden_states_full, position_embeddings_full) + + out_trt, key_cache, value_cache = trt_model( + hidden_states_curr, + position_embeddings_curr, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ) + out_pyt = out_no_cache[0][:, -1:, :] + print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") + hidden_states = hidden_states_full + position_embeddings = position_embeddings_full + + +def test_llama_model(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = llama_model.model.to(DTYPE) + + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + input_ids = torch.randint(1, 20, (1, ISL), dtype=torch.int64).cuda() + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).cuda() + + pyt_output = model(input_ids, position_ids) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, {1: seq_len}) + kwarg_inputs = {"position_ids": position_ids} + from torch.export._trace import _export + + ep = _export( + model, + args=(input_ids,), + kwargs=kwarg_inputs, + dynamic_shapes=dynamic_shapes, + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + arg_inputs=[], + kwarg_inputs=kwarg_inputs, + enabled_precisions=enabled_precisions, + disable_tf32=True, + debug=args.debug, + offload_module_to_cpu=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + use_python_runtime=True, + ) + + trt_output = trt_model(input_ids, position_ids) + + print( + f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}" + ) + # print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[1] - trt_output[1]))}") + breakpoint() + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + + +def test_llama_model_with_static_cache(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + model = llama_model.model.to(DTYPE) + + # Inputs + ISL = 2048 + NUM_TOKENS = 128 + OSL = ISL + NUM_TOKENS + input_ids = torch.randint(1, 20, (1, ISL), dtype=torch.int64).cuda() + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).cuda() + key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE) + start_idx = 0 + end_idx = ISL + is_causal = True + + pyt_output = model(input_ids) + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, {1: seq_len}) + kwarg_inputs = {"input_ids": input_ids, "position_ids": position_ids} + ep = torch.export.export( + model, args=(), kwargs=kwarg_inputs, dynamic_shapes=dynamic_shapes + ) + + import static_cache_v2 + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + arg_inputs=[], + kwarg_inputs=kwarg_inputs, + enabled_precisions=enabled_precisions, + disable_tf32=True, + debug=args.debug, + # offload_module_to_cpu=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + use_python_runtime=True, + ) + + # Test Prefill + trt_output, key_cache, value_cache = trt_model( + input_ids, position_ids, key_cache, value_cache, start_idx, end_idx, is_causal + ) + pyt_output = pyt_output.last_hidden_state + print_diff(pyt_output, trt_output, "pyt_output vs trt_output [Prefill]") + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + input_ids_curr = torch.randint(1, 20, (1, 1), dtype=torch.int64).cuda() + position_ids_curr = torch.tensor([[start_idx]], dtype=torch.int64).cuda() + + # Concatenate the current hidden_states with the previous ones + input_ids_full = torch.cat((input_ids, input_ids_curr), dim=1) + position_ids_full = torch.cat((position_ids, position_ids_curr), dim=1) + is_causal = False + kwarg_inputs = {"input_ids": input_ids_full, "position_ids": position_ids_full} + out_no_cache = model(**kwarg_inputs) + + out_trt, key_cache, value_cache = trt_model( + input_ids_curr, + position_ids_curr, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ) + out_pyt = out_no_cache.last_hidden_state[:, -1:, :] + print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}") + input_ids = input_ids_full + position_ids = position_ids_full + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run test cases for llama attention and decoder" + ) + arg_parser.add_argument( + "--debug", action="store_true", help="Enable debug (default: False)" + ) + arg_parser.add_argument( + "--precision", type=str, default="FP16", help="Precision (default: FP16)" + ) + args = arg_parser.parse_args() + with torch.inference_mode(): + # test_llama_attention(args) + # test_llama_decoder(args) + test_llama_model(args) + # test_llama_attention_with_static_cache(args) + # test_llama_decoder_with_static_cache(args) + # test_llama_model_with_static_cache(args) diff --git a/tools/llm/test_qwen2.5_components.py b/tools/llm/test_qwen2.5_components.py new file mode 100644 index 0000000000..60482bf22d --- /dev/null +++ b/tools/llm/test_qwen2.5_components.py @@ -0,0 +1,193 @@ +import torch + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + +import argparse +import os +import sys +from contextlib import nullcontext + +import torch.nn as nn +import torch_tensorrt +from torch.testing._internal.common_utils import TestCase, run_tests +from transformers import AutoModelForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig + +# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from register_sdpa import * + +ATOL = 1e-5 +RTOL = 1e-5 + + +qwen2_5_model_name = "Qwen/Qwen2.5-1.5B-Instruct" +qwen2_5_model = ( + AutoModelForCausalLM.from_pretrained( + qwen2_5_model_name, + use_cache=False, + attn_implementation="sdpa", + num_hidden_layers=1, + ) + .eval() + .cuda() +) +QWEN_CONFIG = qwen2_5_model.config + + +def print_diff(tensor1, tensor2, prefix=""): + """ + Print the diff between two tensors + """ + print( + f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}" + ) + + +def test_qwen_apply_rotary_pos_emb(args): + class QwenApplyRotaryPosEmb(nn.Module): + def __init__(self): + super().__init__() + + def rotate_half(self, x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (self.rotate_half(q) * sin) + k_embed = (k * cos) + (self.rotate_half(k) * sin) + return q_embed, k_embed + + def forward(self, q, k, cos, sin, unsqueeze_dim=1): + return self.apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim) + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = QwenApplyRotaryPosEmb().eval().cuda().to(DTYPE) + # Shapes for Qwen 2.5 + q = torch.randn((1, 12, 5, 128), dtype=DTYPE).cuda() + k = torch.randn((1, 12, 5, 128), dtype=DTYPE).cuda() + cos = torch.randn((1, 5, 128), dtype=DTYPE).cuda() + sin = torch.randn((1, 5, 128), dtype=DTYPE).cuda() + + pyt_output = model(q, k, cos, sin) + + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({2: seq_len}, {2: seq_len}, {1: seq_len}, {1: seq_len}) + ep = torch.export.export(model, (q, k, cos, sin), dynamic_shapes=dynamic_shapes) + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[q, k, cos, sin], + enabled_precisions=enabled_precisions, + disable_tf32=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + debug=args.debug, + ) + trt_output = trt_model(q, k, cos, sin) + + if isinstance(pyt_output, tuple): + print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt") + # print_diff(pyt_output[1], trt_output[1], "Diff b/w pyt and trt") + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + else: + print_diff(pyt_output, trt_output, "Diff b/w pyt and trt") + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + + +def test_qwen_attention(args): + + DTYPE = torch.float32 + if args.precision == "FP16": + DTYPE = torch.float16 + elif args.precision == "BF16": + DTYPE = torch.bfloat16 + + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + model = qwen2_5_model.model.layers[0].self_attn.to(DTYPE) + # qwen2.5 + hidden_states = torch.randn((1, 5, 1536), dtype=DTYPE).cuda() + position_embeddings = ( + torch.randn((1, 5, 128), dtype=DTYPE).cuda(), + torch.randn((1, 5, 128), dtype=DTYPE).cuda(), + ) + + pyt_output = model(hidden_states, position_embeddings, None) + + seq_len = torch.export.Dim("seq_len", min=2, max=2176) + dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None) + ep = torch.export.export( + model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes + ) + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[hidden_states, position_embeddings, None], + enabled_precisions=enabled_precisions, + disable_tf32=True, + use_fp32_acc=use_fp32_acc, + use_explicit_typing=use_explicit_typing, + debug=args.debug, + ) + trt_output = trt_model(hidden_states, position_embeddings, None) + + if isinstance(pyt_output, tuple): + print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt") + assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL) + else: + print_diff(pyt_output, trt_output, "Diff b/w pyt and trt") + assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL) + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run test cases for llama attention and decoder" + ) + arg_parser.add_argument( + "--debug", action="store_true", help="Enable debug (default: False)" + ) + arg_parser.add_argument( + "--precision", + type=str, + default="FP16", + help="Precision to use in the model. Options: FP16, BF16, FP32", + ) + args = arg_parser.parse_args() + with torch.inference_mode(): + # test_qwen_apply_rotary_pos_emb(args) + test_qwen_attention(args) diff --git a/tools/llm/test_static_cache.py b/tools/llm/test_static_cache.py new file mode 100644 index 0000000000..603f84d3a6 --- /dev/null +++ b/tools/llm/test_static_cache.py @@ -0,0 +1,478 @@ +import argparse +import os +import sys +from contextlib import nullcontext + +import torch +import torch.nn as nn +import torch_tensorrt +from torch.export import export +from torch_tensorrt.dynamo.lowering import ( + get_decompositions, + post_lowering, + pre_export_lowering, +) +from transformers import AutoModelForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer + +# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +import register_sdpa + +ATOL = 1e-5 +RTOL = 1e-5 +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + + +class DynamicCacheModel(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v, k1, v1, flag): + def true_fn(q, k, v, k1, v1): + k_new = torch.cat((k, k1), dim=2) + v_new = torch.cat((v, v1), dim=2) + return torch._C._nn.scaled_dot_product_attention(q, k_new, v_new) + + def false_fn(q, k, v, k1, v1): + return torch._C._nn.scaled_dot_product_attention(q, k, v) + + out = torch.cond(flag, true_fn, false_fn, (q, k, v, k1, v1)) + + return 2 * out + + +class ModelNoCache(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v): + return torch._C._nn.scaled_dot_product_attention( + q, k, v, dropout_p=0.0, is_causal=True + ) + + +class StaticCacheModel(nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True + ): + new_key_cache = torch.cat( + (key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2 + ) + new_value_cache = torch.cat( + (value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2 + ) + attn_output = torch._C._nn.scaled_dot_product_attention( + q, + new_key_cache[:, :, :end_idx, :], + new_value_cache[:, :, :end_idx, :], + dropout_p=0.0, + is_causal=is_causal, + ) + + return attn_output, new_key_cache, new_value_cache + + def forward( + self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True + ): + concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2) + concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2) + new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2) + new_value_cache = torch.cat( + (concat_values, value_cache[:, :, end_idx:, :]), dim=2 + ) + attn_output = torch._C._nn.scaled_dot_product_attention( + q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal + ) + + return attn_output, new_key_cache, new_value_cache + + +def eager_sdpa( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, +) -> torch.Tensor: + """ + Eager implementation of SDPA + """ + import math + + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).cuda() + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias = attn_mask + attn_bias + + if enable_gqa: + key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) + value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return attn_weight @ value + + +def print_diff(tensor1, tensor2, prefix=""): + """ + Print the diff between two tensors + """ + print( + f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}" + ) + + +def test_no_cache_model_with_torch_tensorrt(args): + """ + Test the no cache model + """ + with torch.inference_mode(): + model_no_cache = ModelNoCache().eval().cuda() + # q = torch.randn(1, 32, 6, 64).cuda() + # k = torch.randn(1, 32, 6, 64).cuda() + # v = torch.randn(1, 32, 6, 64).cuda() + q = torch.load("query.pt") + k = torch.load("key.pt") + v = torch.load("value.pt") + out_no_cache = model_no_cache(q, k, v) + out_eager = eager_sdpa(q, k, v, is_causal=True) + q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176) + # Export the model + exported_program = torch.export.export( + model_no_cache, + args=(q, k, v), + dynamic_shapes=({2: q_seq_len}, {2: q_seq_len}, {2: q_seq_len}), + strict=False, + ) + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + exported_program, + inputs=[q, k, v], + enabled_precisions={torch.float32}, + disable_tf32=True, + debug=args.debug, + min_block_size=1, + ) + out_trt = trt_model(q, k, v) + + print_diff(out_no_cache, out_eager, "out_no_cache vs out_eager") + print_diff(out_no_cache, out_trt, "out_no_cache vs out_trt") + print_diff(out_eager, out_trt, "out_eager vs out_trt") + breakpoint() + + +def test_static_cache_model(args): + """ + Test the static cache model + """ + with torch.inference_mode(): + model_no_cache = ModelNoCache().eval().cuda() + model_static_cache = StaticCacheModel().eval().cuda() + q = torch.randn(1, 32, 2048, 64).cuda() + k = torch.randn(1, 32, 2048, 64).cuda() + v = torch.randn(1, 32, 2048, 64).cuda() + key_cache = torch.zeros(1, 32, 2176, 64).cuda() + value_cache = torch.zeros(1, 32, 2176, 64).cuda() + + # Test Prefill + start_idx = 0 + end_idx = 2048 + out_no_cache = model_no_cache(q, k, v) + out_static_cache, new_key_cache, new_value_cache = model_static_cache( + q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True + ) + assert torch.allclose(out_no_cache, out_static_cache, atol=ATOL, rtol=RTOL) + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + q_curr = torch.randn(1, 32, 1, 64).cuda() + k_curr = torch.randn(1, 32, 1, 64).cuda() + v_curr = torch.randn(1, 32, 1, 64).cuda() + + # Concatenate the current query, key, and value with the previous ones + q_full = torch.cat((q, q_curr), dim=2) + k_full = torch.cat((k, k_curr), dim=2) + v_full = torch.cat((v, v_curr), dim=2) + + out_no_cache = model_no_cache(q_full, k_full, v_full) + out_static_cache, new_key_cache, new_value_cache = model_static_cache( + q_curr, + k_curr, + v_curr, + new_key_cache, + new_value_cache, + start_idx, + end_idx, + is_causal=False, + ) + + assert torch.allclose( + out_no_cache[:, :, -1:, :], out_static_cache, atol=ATOL, rtol=RTOL + ) + q = q_full + k = k_full + v = v_full + print("============== test_static_cache passed ==============") + + +def transform_gm_with_kv_cache(exported_program: torch.export.ExportedProgram, args): + """ + Transform the graph module by adding key and value cache to the graph + """ + gm = exported_program.module() + # Post lower the model + settings = torch_tensorrt.dynamo.conversion.CompilationSettings( + enabled_precisions={torch.float32}, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + min_block_size=1, + ) + exported_program = pre_export_lowering(exported_program, settings) + exported_program = exported_program.run_decompositions(get_decompositions(False)) + + gm = exported_program.module() + gm = post_lowering(gm, settings) + + return gm + + +def test_static_cache_lowering(args): + """ + Test static cache lowering pass applied to the model with no cache and run the graph module + and compare the output with the model with no cache + """ + import static_cache2 + + model_no_cache = ModelNoCache().eval().cuda() + q = torch.randn(1, 32, 2, 64).cuda() + k = torch.randn(1, 32, 2048, 64).cuda() + v = torch.randn(1, 32, 2048, 64).cuda() + key_cache = torch.zeros(1, 32, 2176, 64).cuda() + value_cache = torch.zeros(1, 32, 2176, 64).cuda() + + # Export the model + q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176) + kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176) + exported_program = export( + model_no_cache, + args=(q, k, v), + dynamic_shapes=({2: q_seq_len}, {2: kv_seq_len}, {2: kv_seq_len}), + strict=False, + ) + + gm = transform_gm_with_kv_cache(exported_program, args) + + # Test Prefill + start_idx = 0 + end_idx = 2048 + is_causal = True + q = torch.randn(1, 32, 2048, 64).cuda() + out_no_cache = model_no_cache(q, k, v) + out_pyt_cache, key_cache, value_cache = gm( + q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal + ) + assert torch.allclose(out_no_cache, out_pyt_cache, atol=ATOL, rtol=RTOL) + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + is_causal = False + q_curr = torch.randn(1, 32, 1, 64).cuda() + k_curr = torch.randn(1, 32, 1, 64).cuda() + v_curr = torch.randn(1, 32, 1, 64).cuda() + # Concatenate the current query, key, and value with the previous ones + q_full = torch.cat((q, q_curr), dim=2) + k_full = torch.cat((k, k_curr), dim=2) + v_full = torch.cat((v, v_curr), dim=2) + + out_no_cache = model_no_cache(q_full, k_full, v_full) + out_pyt_static_cache, key_cache, value_cache = gm( + q_curr, + k_curr, + v_curr, + key_cache, + value_cache, + start_idx, + end_idx, + is_causal, + ) + assert torch.allclose( + out_no_cache[:, :, -1:, :], out_pyt_static_cache, atol=ATOL, rtol=RTOL + ) + q = q_full + k = k_full + v = v_full + + print("============== test_static_cache_lowering passed ==============") + + +def test_static_cache_export(args): + """ + Test the static cache model export + """ + model_static_cache = StaticCacheModel().eval().cuda() + q = torch.randn(1, 32, 2048, 64).cuda() + k = torch.randn(1, 32, 2048, 64).cuda() + v = torch.randn(1, 32, 2048, 64).cuda() + key_cache = torch.zeros(1, 32, 2176, 64).cuda() + value_cache = torch.zeros(1, 32, 2176, 64).cuda() + # Test Prefill + start_idx = 0 + end_idx = 2048 + is_causal = True + # Export the model + seq_len = torch.export.Dim("seq_len", min=2, max=2048) + seq_len_dyn_dim = {2: seq_len} + exported_program = export( + model_static_cache, + args=(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal), + dynamic_shapes=( + seq_len_dyn_dim, + seq_len_dyn_dim, + seq_len_dyn_dim, + None, + None, + torch.export.Dim.DYNAMIC, + torch.export.Dim.DYNAMIC, + None, + ), + strict=False, + ) + + +def test_static_cache_with_torch_tensorrt(args): + """ + Test the static cache model with torch_tensorrt + """ + import static_cache_v2 + + model_no_cache = ModelNoCache().eval().cuda() + q = torch.randn(1, 32, 2, 64).cuda() + k = torch.randn(1, 32, 2048, 64).cuda() + v = torch.randn(1, 32, 2048, 64).cuda() + key_cache = torch.zeros(1, 32, 2176, 64).cuda() + value_cache = torch.zeros(1, 32, 2176, 64).cuda() + + # Export the model + q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176) + kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176) + exported_program = export( + model_no_cache, + args=(q, k, v), + dynamic_shapes=({2: q_seq_len}, {2: kv_seq_len}, {2: kv_seq_len}), + strict=False, + ) + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + exported_program, + inputs=[q, k, v], + enabled_precisions={torch.float32}, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + min_block_size=1, + ) + + start_idx = 0 + end_idx = 2048 + is_causal = True + q = torch.randn(1, 32, 2048, 64).cuda() + # out_eager = eager_sdpa(q, k, v, is_causal=is_causal) + out_no_cache = model_no_cache(q, k, v) + out_trt, trt_key_cache, trt_value_cache = trt_model( + q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal + ) + + assert torch.allclose( + out_no_cache, out_trt, atol=ATOL, rtol=RTOL + ), "Prefill TRT logits don't match" + assert torch.allclose( + trt_key_cache[:, :, :end_idx, :], k, atol=ATOL, rtol=RTOL + ), "Prefill TRT key cache don't match" + assert torch.allclose( + trt_value_cache[:, :, :end_idx, :], v, atol=ATOL, rtol=RTOL + ), "Prefill TRT value cache don't match" + + # Test Generate + for start_idx in range(2048, 2176): + end_idx = start_idx + 1 + q_curr = torch.randn(1, 32, 1, 64).cuda() + k_curr = torch.randn(1, 32, 1, 64).cuda() + v_curr = torch.randn(1, 32, 1, 64).cuda() + # Concatenate the current query, key, and value with the previous ones + q_full = torch.cat((q, q_curr), dim=2) + k_full = torch.cat((k, k_curr), dim=2) + v_full = torch.cat((v, v_curr), dim=2) + is_causal = True + out_no_cache = model_no_cache(q_full, k_full, v_full) + out_trt, trt_key_cache, trt_value_cache = trt_model( + q_curr, + k_curr, + v_curr, + trt_key_cache, + trt_value_cache, + start_idx, + end_idx, + is_causal, + ) + # breakpoint() + # print_diff(out_no_cache[:, :, -1:, :], out_trt, f"out_no_cache[:, :, -1:, :] vs out_trt for idx {start_idx}") + # print_diff(trt_key_cache[:, :, :end_idx, :], k_full, f"trt_key_cache[:, :, :end_idx, :] vs k_full for idx {start_idx}") + # print_diff(trt_value_cache[:, :, :end_idx, :], v_full, f"trt_value_cache[:, :, :end_idx, :] vs v_full for idx {start_idx}") + assert torch.allclose( + out_no_cache[:, :, -1:, :], out_trt, atol=ATOL, rtol=RTOL + ), f"Generate TRT logits don't match for idx {start_idx}" + assert torch.allclose( + trt_key_cache[:, :, :end_idx, :], k_full, atol=ATOL, rtol=RTOL + ), f"Generate TRT key cache don't match for idx {start_idx}" + assert torch.allclose( + trt_value_cache[:, :, :end_idx, :], v_full, atol=ATOL, rtol=RTOL + ), f"Generate TRT value cache don't match for idx {start_idx}" + q = q_full + k = k_full + v = v_full + + print("============== test_static_cache_with_torch_tensorrt passed ==============") + + +def main(): + arg_parser = argparse.ArgumentParser( + description="Run test cases for llama attention and decoder" + ) + arg_parser.add_argument( + "--debug", action="store_true", help="Enable debug (default: False)" + ) + args = arg_parser.parse_args() + with torch.inference_mode(): + # test_no_cache_model_with_torch_tensorrt(args) + # test_static_cache_model(args) + # test_static_cache_lowering(args) + test_static_cache_with_torch_tensorrt(args) + + +if __name__ == "__main__": + main() diff --git a/tools/llm/utils.py b/tools/llm/utils.py new file mode 100644 index 0000000000..5ccb9d0e55 --- /dev/null +++ b/tools/llm/utils.py @@ -0,0 +1,244 @@ +import copy +import timeit + +import numpy as np +import torch +from transformers import StoppingCriteriaList +from transformers.generation.stopping_criteria import ( + EosTokenCriteria, + MaxLengthCriteria, +) + + +def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): + """ + Exports the LLM model into an ExportedProgram with dynamic shapes. + In the case of guard failures due to some PyTorch kernel implements, we also + try to re-export the graph by expressing them as runtime assert nodes + """ + with torch.no_grad(): + # max=1024 has contraint violation error. https://github.com/pytorch/pytorch/issues/125604 + seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len) + position_ids = torch.arange(inputs.shape[1]).unsqueeze(0).to(inputs.device) + try: + print("Trying to export the model using torch.export.export()..") + # strict=False only enables aotautograd tracing and excludes dynamo. + ep = torch.export.export( + model, + args=(inputs,), + kwargs={"position_ids": position_ids}, + dynamic_shapes=({1: seq_len}, {1: seq_len}), + strict=False, + ) + except: + print( + "Trying torch.export._trace._export to trace the graph since torch.export.export() failed" + ) + # This API is used to express the constraint violation guards as asserts in the graph. + ep = torch.export._trace._export( + model, + args=(inputs,), + kwargs={"position_ids": position_ids}, + dynamic_shapes=({1: seq_len}, {1: seq_len}), + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + + return ep + + +def get_zeroed_static_cache_inputs(model: torch.fx.GraphModule): + """ + Extracts and returns zeroed static KV cache tensors from a torch.fx.GraphModule. This should only be used for static cache_v1 and static cache_v2. + + This function identifies placeholder nodes in the graph that represent KV cache tensors, + and creates zeroed tensors with the same shape, dtype, and device as the original placeholders. + + Args: + model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders + + Returns: + tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph + """ + # placeholder nodes are expected to be in the following order: + # input_ids, kv_cache_key, kv_cache_value, start_idx, end_idx + placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"] + # The first two inputs are input_ids, position_ids. The last two inputs are start_idx, end_idx. In between are the KV cache tensors. + kv_cache_inputs = placeholder_nodes[2:-2] + zeroed_kv_cache_inputs = [] + for input in kv_cache_inputs: + zeroed_kv_cache_inputs.append( + torch.zeros( + input.meta["val"].shape, + dtype=input.meta["val"].dtype, + device=torch.device("cuda:0"), + ) + ) + + return tuple(zeroed_kv_cache_inputs) + + +def get_zeroed_dynamic_cache_inputs(model: torch.fx.GraphModule): + """ + Extracts and returns zeroed KV cache tensors from a torch.fx.GraphModule. This should only be used for dynamic cache. + + This function identifies placeholder nodes in the graph that represent KV cache tensors, + and creates zeroed tensors with the same shape, dtype, and device as the original placeholders. + + Args: + model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders + + Returns: + tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph + """ + # placeholder nodes are expected to be in the following order: + # input_ids, kv_cache_key, kv_cache_value, start_idx, end_idx + placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"] + # The first two inputs are input_ids, position_ids. The last input is is_generate. In between are the KV cache tensors. + kv_cache_inputs = placeholder_nodes[2:-1] + zeroed_kv_cache_inputs = [] + for input in kv_cache_inputs: + zeroed_kv_cache_inputs.append( + torch.zeros( + input.meta["val"].shape, + dtype=input.meta["val"].dtype, + device=torch.device("cuda:0"), + ) + ) + + return tuple(zeroed_kv_cache_inputs) + + +def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=True): + """ + Greedy decoding of the model. This generates up to max_tokens. + """ + stopping_criteria = StoppingCriteriaList( + [ + MaxLengthCriteria(max_length=max_output_seq_length), + EosTokenCriteria(eos_token_id=eos_token_id), + ] + ) + isl = input_seq.shape[1] + osl = max_output_seq_length - isl + + num_tokens_generated = 0 + while num_tokens_generated < osl: + position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda() + outputs = model(input_seq, position_ids=position_ids) + logits = outputs.logits + next_token_logits = logits[:, -1, :] + next_tokens = torch.argmax(next_token_logits, dim=-1) + input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1) + num_tokens_generated += 1 + # TODO: Handle batch in this check + if not benchmark and stopping_criteria(input_seq, logits).item(): + break + + return input_seq + + +def generate_with_static_cache(model, input_seq, max_output_seq_length, eos_token_id): + """ + Greedy decoding of the model with static KV cache. + """ + start_idx = 0 + end_idx = input_seq.shape[1] + position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda() + output_seq = input_seq.clone() + # TODO: Confirm this: When end_idx = max_output_seq_length-1, number of tokens generated = OSL + num_tokens_generated = 0 + kv_cache = get_zeroed_static_cache_inputs(model) + while end_idx < max_output_seq_length: + position_ids = ( + torch.tensor([[start_idx]], dtype=torch.int64).cuda() + if input_seq.shape[1] == 1 + else position_ids + ) + input_signature = (input_seq, position_ids, *kv_cache, start_idx, end_idx) + logits_keys_values = model(*input_signature) + num_tokens_generated += 1 + logits = logits_keys_values[0] + kv_cache = logits_keys_values[1:] + next_token_logits = logits[:, -1, :] + next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True) + output_seq = torch.cat([output_seq, next_tokens], dim=-1) + input_seq = next_tokens + start_idx = end_idx + end_idx = start_idx + 1 + return output_seq + + +def generate_with_dynamic_cache(model, input_seq, max_output_seq_length, eos_token_id): + """ + Greedy decoding of the model with dynamic KV cache. + """ + position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda() + output_seq = input_seq.clone() + num_output_tokens = max_output_seq_length - input_seq.shape[1] + num_tokens_generated = 0 + kv_cache = get_zeroed_dynamic_cache_inputs(model) + last_position_id = position_ids[-1, -1].item() + breakpoint() + while num_tokens_generated < num_output_tokens: + is_generate = False if input_seq.shape[1] > 1 else True + position_ids = ( + torch.tensor([[last_position_id + 1]], dtype=torch.int64).cuda() + if input_seq.shape[1] == 1 + else position_ids + ) + input_signature = (input_seq, position_ids, *kv_cache, is_generate) + logits_keys_values = model(*input_signature) + num_tokens_generated += 1 + logits = logits_keys_values[0] + kv_cache = logits_keys_values[1:] + next_token_logits = logits[:, -1, :] + next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True) + output_seq = torch.cat([output_seq, next_tokens], dim=-1) + input_seq = next_tokens + last_position_id += 1 + return output_seq + + +def time_generate( + generate_fn, model, inputs, output_seq_length, eos_token_id, iterations=10 +): + """ + Measure the time for generating a sentence over certain number of iterations + """ + timings = [] + for _ in range(iterations): + start_time = timeit.default_timer() + _ = generate_fn(model, inputs, output_seq_length, eos_token_id) + torch.cuda.synchronize() + end_time = timeit.default_timer() + timings.append(end_time - start_time) + + return timings + + +def recordStats(backend, timings, precision, batch_size=1, compile_time_s=None): + """ + Records different timing stats and adds it to the result + """ + times = np.array(timings) + speeds = batch_size / times + time_mean = np.mean(times).item() + time_med = np.median(times).item() + time_99th = np.percentile(times, 99).item() + time_std = np.std(times, ddof=0).item() + speed_mean = np.mean(speeds).item() + speed_med = np.median(speeds).item() + + stats = { + "Backend": backend, + "Precision": precision, + "Batch size": batch_size, + "Median(FPS)": speed_med, + "Mean(FPS)": speed_mean, + "Median-Latency(ms)": time_med * 1000, + "Mean-Latency(ms)": time_mean * 1000, + "Latency-StdDev(ms)": time_std * 1000, + "Compile Time(s)": compile_time_s, + } + return stats