Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion backends/apple/mps/operators/shape_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,12 @@ def define_node(
split_sizes = eval_shape(cast(torch.SymInt, node.args[1]))
dim = cast(int, node.args[2])
input_shape = get_shape(get_input_node(node, 0))
if dim < 0:
dim += len(input_shape)

if dim < 0 or dim >= len(input_shape):
Copy link

Copilot AI Nov 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The validation condition on line 249 will always fail for negative dimensions after normalization. If the original dim was less than -len(input_shape), the normalized value will still be negative, but this is the intended behavior to catch out-of-bounds dimensions. However, the current implementation normalizes first, then validates, which means valid negative dimensions are handled correctly but the error condition dim < 0 is now checking the normalized value. The validation should only check dim >= len(input_shape) after normalization, or validate the original dimension range before normalization.

Suggested change
if dim < 0 or dim >= len(input_shape):
if dim >= len(input_shape):

Copilot uses AI. Check for mistakes.
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The negative dimension could be outside the correct range even after normalization (e.g. if there are two dimensions and the input dimension is -3). This should probably still be an error.

raise RuntimeError(
f"split_copy: dim {dim} out of range for input tensor with {len(input_shape)} dimensions"
f"split_copy: dim {cast(int, node.args[2])} out of range for input tensor with {len(input_shape)} dimensions"
)

mps_node = MPSNode(
Expand Down