diff --git a/backends/apple/mps/operators/shape_ops.py b/backends/apple/mps/operators/shape_ops.py index 76c559018be..d75e42d1c45 100644 --- a/backends/apple/mps/operators/shape_ops.py +++ b/backends/apple/mps/operators/shape_ops.py @@ -243,10 +243,12 @@ def define_node( split_sizes = eval_shape(cast(torch.SymInt, node.args[1])) dim = cast(int, node.args[2]) input_shape = get_shape(get_input_node(node, 0)) + if dim < 0: + dim += len(input_shape) if dim < 0 or dim >= len(input_shape): raise RuntimeError( - f"split_copy: dim {dim} out of range for input tensor with {len(input_shape)} dimensions" + f"split_copy: dim {cast(int, node.args[2])} out of range for input tensor with {len(input_shape)} dimensions" ) mps_node = MPSNode(