@@ -123,7 +123,11 @@ def get_static_tensor(tensor: torch.Tensor):
     start_idx_input.meta["val"] = start_idx_unbacked_symint
     end_idx_input.meta["val"] = end_idx_unbacked_symint

-    return kv_inputs, start_idx_input, end_idx_input
+    # Add is_causal as input
+    is_causal_input = add_graph_input(gm, "is_causal", True)
+    is_causal_input.meta["val"] = torch.tensor(True)
+
+    return kv_inputs, start_idx_input, end_idx_input, is_causal_input

 def create_kv_cache_update_nodes(gm, sdpa_node, current_kv_node, incoming_kv_node, start_idx_input, end_idx_input):
     """
@@ -212,7 +216,7 @@ def create_kv_cache_update_nodes(gm, sdpa_node, current_kv_node, incoming_kv_nod

     return concat_keys_or_values, new_incoming_keys_or_values

-def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node):
+def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]], start_idx_input: Node, end_idx_input: Node, is_causal_input: Node):
     """
     Insert slicing and concatenation operations before each scaled_dot_product_attention operation as per the following KV cache update logic:
     concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
@@ -239,7 +243,7 @@ def insert_kv_slicing_before_sdpa(gm, incoming_keys_values: List[Tuple[torch.Ten
         kv_cache_for_graph.extend([new_incoming_key_cache_node, new_incoming_value_cache_node])

         # Update the SDPA node arguments with current key and value nodes
-        sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + sdpa_node.args[3:]
+        sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (None, is_causal_input)  # + sdpa_node.args[3:]

         kv_cache_for_graph.extend([k_node, v_node])
     return gm, kv_cache_for_graph
@@ -252,11 +256,11 @@ def insert_kv_cache(
     """Insert KV cache ops in the graph"""
     """Perform insertion of kv-caches and attention kernel."""
     # Add static key and value as inputs to the graph
-    kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True)
+    kv_inputs, start_idx_input, end_idx_input, is_causal_input = add_kv_cache_inputs(gm, fixed_kv=True)

     # Build and update the KV cache using computed KV inputs for current token and
     # incoming keys and values from previous tokens (which were added as inputs)
-    gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input)
+    gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(gm, kv_inputs, start_idx_input, end_idx_input, is_causal_input)

     # Call the function to add KV as outputs
     logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph)
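For context, here is a minimal eager-mode sketch of the slice-and-concat cache update that this pass inlines before each SDPA node, with is_causal treated as a runtime input rather than a baked-in constant. It follows the update rule quoted in the docstring above; the function and variable names are illustrative, not part of the patch.

# Hypothetical eager-mode equivalent of the graph rewrite; names are
# illustrative and not taken from the patched module.
import torch
import torch.nn.functional as F

def sdpa_with_static_kv_cache(q, k, v, key_cache, value_cache,
                              start_idx, end_idx, is_causal):
    # Splice the current step's keys/values into the preallocated cache
    # window [start_idx:end_idx], per the docstring's update logic.
    concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
    concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
    new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
    new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2)

    # Attend over the populated prefix of the cache. is_causal is now an
    # argument, so the same graph can run prefill (True) and decode (False).
    attn_out = F.scaled_dot_product_attention(
        q,
        new_key_cache[:, :, :end_idx, :],
        new_value_cache[:, :, :end_idx, :],
        attn_mask=None,
        is_causal=is_causal,
    )
    return attn_out, new_key_cache, new_value_cache

# Example: a 4-token prefill into a 32-slot cache (shapes are arbitrary).
B, H, D, S_max = 1, 8, 64, 32
q = torch.randn(B, H, 4, D)
k = torch.randn(B, H, 4, D)
v = torch.randn(B, H, 4, D)
key_cache = torch.zeros(B, H, S_max, D)
value_cache = torch.zeros(B, H, S_max, D)
out, key_cache, value_cache = sdpa_with_static_kv_cache(
    q, k, v, key_cache, value_cache, start_idx=0, end_idx=4, is_causal=True
)

This mirrors why the patch threads is_causal_input through add_kv_cache_inputs and insert_kv_slicing_before_sdpa: causal masking applies during prefill, while single-token decode steps can disable it without rebuilding the graph.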