
Commit 0dc3a7e
chore: updates
1 parent f539b55

2 files changed: +10 -6 lines changed

examples/dynamo/lower_sdpa.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ def replace_variants_of_sdpa(
 
         if attn_mask is not None:
             logger.warning(f"We do not support attn_mask for {node.target} in the graph. Ignoring it and using is_causal=True configuration.")
-            breakpoint()
+
         modified_input_args = (query, key, value, None, dropout_p, is_causal)
 
         # Create a new node with torch.nn.functional.scaled_dot_product_attention
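
Note for context: modified_input_args lines up with the positional signature of torch.nn.functional.scaled_dot_product_attention, and the mask slot is set to None because PyTorch rejects an explicit attn_mask combined with is_causal=True. A minimal eager-mode sketch of the call the rewritten node targets (tensor shapes are illustrative and not taken from the example):

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, num_heads, seq_len, head_dim)
query = torch.randn(1, 8, 16, 64)
key = torch.randn(1, 8, 16, 64)
value = torch.randn(1, 8, 16, 64)
dropout_p = 0.0
is_causal = True

# Mirrors modified_input_args = (query, key, value, None, dropout_p, is_causal);
# the attn_mask slot stays None because a mask and is_causal=True are mutually exclusive.
out = F.scaled_dot_product_attention(query, key, value, None, dropout_p, is_causal)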

examples/dynamo/static_cache2.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,11 @@ def get_static_tensor(tensor: torch.Tensor):
     start_idx_input.meta["val"] = start_idx_unbacked_symint
     end_idx_input.meta["val"] = end_idx_unbacked_symint
 
-    return kv_inputs, start_idx_input, end_idx_input
+    # Add is_causal as input
+    is_causal_input = add_graph_input(gm, "is_causal", True)
+    is_causal_input.meta["val"] = torch.tensor(True)
+
+    return kv_inputs, start_idx_input, end_idx_input, is_causal_input
 
 def create_kv_cache_update_nodes(gm, sdpa_node, current_kv_node, incoming_kv_node, start_idx_input, end_idx_input):
     """
@@ -212,7 +216,7 @@ def create_kv_cache_update_nodes(gm, sdpa_node, current_kv_node, incoming_kv_nod
 
     return concat_keys_or_values, new_incoming_keys_or_values
 
-def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node):
+def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node, is_causal_input: Node):
     """
     Insert slicing and concatenation operations before each scaled_dot_product_attention operation as per the following KV cache update logic:
     concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
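
The docstring's KV cache update logic, written out as plain tensor ops for reference. Only the concat_keys line is quoted from the docstring; the value-side and cache re-padding lines extrapolate the same pattern and assume a [batch, num_heads, seq_len, head_dim] layout:

import torch

def kv_cache_update_sketch(key_cache, value_cache, k, v, start_idx, end_idx):
    # Keep cached entries before start_idx and append the new tokens' keys/values.
    concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
    concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
    # Re-attach the tail of the fixed-size cache so the overall shape stays static.
    new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
    new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2)
    return concat_keys, concat_values, new_key_cache, new_value_cache
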
@@ -239,7 +243,7 @@ def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Ten
         kv_cache_for_graph.extend([new_incoming_key_cache_node, new_incoming_value_cache_node])
 
         # Update the SDPA node arguments with current key and value nodes
-        sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + sdpa_node.args[3:]
+        sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (None, is_causal_input) # + sdpa_node.args[3:]
 
     kv_cache_for_graph.extend([k_node, v_node])
     return gm, kv_cache_for_graph
@@ -252,11 +256,11 @@ def insert_kv_cache(
     """Insert KV cache ops in the graph"""
     """Perform insertion of kv-caches and attention kernel."""
     # Add static key and value as inputs to the graph
-    kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True)
+    kv_inputs, start_idx_input, end_idx_input, is_causal_input = add_kv_cache_inputs(gm, fixed_kv=True)
 
     # Build and update the KV cache using computed KV inputs for current token and
     # incoming keys and values from previous tokens (which were added as inputs)
-    gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input)
+    gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input, is_causal_input)
 
     # Call the function to add KV as outputs
     logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph)
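
Passing is_causal through insert_kv_cache turns the masking mode into a runtime input instead of a baked-in constant. One plausible reason (not stated in the diff) is that prefill and decode want different settings: prefill attends causally over the whole prompt, while a single decode-step query should see every cached position. A small eager illustration of that difference, with illustrative shapes only:

import torch
import torch.nn.functional as F

# Prefill: all prompt tokens at once, causal masking on.
q_prefill = torch.randn(1, 8, 16, 64)
kv_prefill = torch.randn(1, 8, 16, 64)
prefill_out = F.scaled_dot_product_attention(q_prefill, kv_prefill, kv_prefill, None, 0.0, True)

# Decode: one new query token against the cached keys/values, causal masking off,
# since the single query should attend to all previously cached positions.
q_decode = torch.randn(1, 8, 1, 64)
kv_cached = torch.randn(1, 8, 17, 64)
decode_out = F.scaled_dot_product_attention(q_decode, kv_cached, kv_cached, None, 0.0, False)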
