@@ -149,8 +149,12 @@ def get_local_global_attention_mask(
         cache_len=attention_mask.shape[-1],
         sliding_window_size=sliding_window_size,
     )
-    # Combine masks using logical AND (min in this case).
-    combined_mask = torch.min(attention_mask, sliding_mask)
+    # Expand sliding_mask to match attention_mask's dimensions
+    # (e.g., [B, 1, seq_len, cache_len]).
+    # Assuming the head dimension is dim 1 for attention_mask.
+    expanded_sliding_mask = sliding_mask.unsqueeze(1)
+    # Combine masks using logical AND (min ensures -inf propagates).
+    combined_mask = torch.min(attention_mask, expanded_sliding_mask)
     return combined_mask
   return attention_mask
 
@@ -161,9 +165,9 @@ def create_sliding_mask(
     sliding_window_size: int,
 ) -> torch.Tensor:
   """Creates mask for sliding window attention (PyTorch)."""
-  cache_positions = torch.tensor(
-      [i for i in range(cache_len)], dtype=torch.int32
-  )
+  # Use torch.arange to create a tensor with a range of integers in a
+  # Dynamo-friendly way.
+  cache_positions = torch.arange(cache_len, dtype=torch.int32)
   cache_positions = cache_positions.view(1, 1, -1)  # [1, 1, cache_len]
   segment_pos_expanded = segment_pos.clone().unsqueeze(-1)  # [B, seq_len, 1]
 
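For reference, a minimal runnable sketch of the combination step this diff introduces, assuming an additive mask convention (0 for allowed, -inf for blocked), an `attention_mask` of shape [B, 1, seq_len, cache_len], and a simple window condition; the exact shapes and window rule in the repo's helpers may differ.

```python
import torch

# Illustrative sizes only; not taken from the repo.
B, seq_len, cache_len, window = 2, 4, 8, 3

# Additive attention mask that already has a head dimension:
# 0.0 where attention is allowed, -inf where it is blocked
# ([B, 1, seq_len, cache_len]); here we block the unused tail of the cache.
attention_mask = torch.zeros(B, 1, seq_len, cache_len)
attention_mask[..., seq_len:] = float("-inf")

# Sliding-window mask without a head dimension ([B, seq_len, cache_len]).
# The shape handling mirrors create_sliding_mask; the window condition
# itself is an assumption for this sketch.
cache_positions = torch.arange(cache_len, dtype=torch.int32).view(1, 1, -1)
segment_pos = torch.arange(seq_len, dtype=torch.int32).unsqueeze(0).expand(B, -1)
segment_pos_expanded = segment_pos.clone().unsqueeze(-1)  # [B, seq_len, 1]
inside_window = (cache_positions > segment_pos_expanded - window) & (
    cache_positions <= segment_pos_expanded
)
sliding_mask = torch.full((B, seq_len, cache_len), float("-inf")).masked_fill(
    inside_window, 0.0
)

# unsqueeze(1) inserts the head dimension so both masks broadcast to
# [B, 1, seq_len, cache_len]; torch.min keeps -inf wherever either mask
# forbids attention, i.e. a logical AND of the two additive masks.
combined_mask = torch.min(attention_mask, sliding_mask.unsqueeze(1))
print(combined_mask.shape)  # torch.Size([2, 1, 4, 8])
```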