
Commit 0c9d1c1

junjiang-lab authored and copybara-github committed
Update Gemma3 decoder to support dynamic shapes.
PiperOrigin-RevId: 759762038
1 parent: d07a7ef

File tree

1 file changed (+9, -5 lines)

  • ai_edge_torch/generative/examples/gemma3


ai_edge_torch/generative/examples/gemma3/decoder.py

Lines changed: 9 additions & 5 deletions
@@ -149,8 +149,12 @@ def get_local_global_attention_mask(
         cache_len=attention_mask.shape[-1],
         sliding_window_size=sliding_window_size,
     )
-    # Combine masks using logical AND (min in this case).
-    combined_mask = torch.min(attention_mask, sliding_mask)
+    # Expand sliding_mask to match attention_mask's dimensions
+    # (e.g., [B, 1, seq_len, cache_len]).
+    # Assuming the head dimension is dim 1 for attention_mask.
+    expanded_sliding_mask = sliding_mask.unsqueeze(1)
+    # Combine masks using logical AND (min ensures -inf propagates).
+    combined_mask = torch.min(attention_mask, expanded_sliding_mask)
     return combined_mask
   return attention_mask
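
For context on the unsqueeze(1), here is a minimal broadcasting sketch; the shapes are illustrative assumptions taken from the diff comments (head dimension at dim 1 of attention_mask), not pulled from the decoder itself:

import torch

# Hypothetical shapes mirroring the comments in the diff above.
B, seq_len, cache_len = 2, 4, 8
attention_mask = torch.zeros(B, 1, seq_len, cache_len)             # 0 = attend
sliding_mask = torch.full((B, seq_len, cache_len), float("-inf"))  # -inf = masked

# Without the extra dim, broadcasting [B, seq_len, cache_len] against
# [B, 1, seq_len, cache_len] silently yields [B, B, seq_len, cache_len]
# whenever B > 1, mixing batch entries into the head dimension.
expanded = sliding_mask.unsqueeze(1)            # [B, 1, seq_len, cache_len]
combined = torch.min(attention_mask, expanded)  # -inf wins where either masks

print(combined.shape)  # torch.Size([2, 1, 4, 8])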

@@ -161,9 +165,9 @@ def create_sliding_mask(
     sliding_window_size: int,
 ) -> torch.Tensor:
   """Creates mask for sliding window attention (PyTorch)."""
-  cache_positions = torch.tensor(
-      [i for i in range(cache_len)], dtype=torch.int32
-  )
+  # Use torch.arange to create a tensor with a range of integers in a
+  # Dynamo-friendly way.
+  cache_positions = torch.arange(cache_len, dtype=torch.int32)
   cache_positions = cache_positions.view(1, 1, -1)  # [1, 1, cache_len]
   segment_pos_expanded = segment_pos.clone().unsqueeze(-1)  # [B, seq_len, 1]
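
A minimal sketch of why torch.arange matters for dynamic shapes: under torch.export, cache_len can arrive as a symbolic size that torch.arange traces through, whereas building a Python list with range(cache_len) would specialize the graph on one concrete length. The wrapper module, input shapes, and Dim bounds below are hypothetical, for illustration only:

import torch

class Positions(torch.nn.Module):  # hypothetical wrapper, not from the decoder
  def forward(self, kv_cache: torch.Tensor) -> torch.Tensor:
    cache_len = kv_cache.shape[-1]  # symbolic when exported with a dynamic dim
    # torch.arange accepts the symbolic length; a list comprehension over
    # range(cache_len) would force a fixed Python int at trace time.
    return torch.arange(cache_len, dtype=torch.int32).view(1, 1, -1)

ep = torch.export.export(
    Positions(),
    (torch.zeros(1, 1, 16),),
    dynamic_shapes={"kv_cache": {2: torch.export.Dim("cache_len", min=2, max=1024)}},
)
print(ep.module()(torch.zeros(1, 1, 32)).shape)  # torch.Size([1, 1, 32])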

0 commit comments