Skip to content

Commit c02dfd6

Browse files
Arm backend: Replace asserts with error handling in upsample operators (#10577)
Update `op_upsample_bilinear2d.py` and `op_upsample_nearest2d.py` classes to replace `assert` checks with `ValueError` exceptions for improved error handling and code robustness. - Change static shape checks from assertions to conditional raises of `ValueError` for clear runtime error communication. - Replaced `assert` checks for int16 range validation of `scale_n_yx`, `scale_d_yx`, and `border_yx` with explicit value range checks that raise `ValueError` when bounds are breached. Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 12ed924 commit c02dfd6

File tree

2 files changed

+16
-12
lines changed

2 files changed

+16
-12
lines changed

backends/arm/operators/op_upsample_bilinear2d.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,8 @@ def define_node(
3434
inputs: List[TosaArg],
3535
output: TosaArg,
3636
) -> None:
37-
assert (
38-
inputs[0].shape is not None and output.shape is not None
39-
), "Only static shapes are supported"
37+
if inputs[0].shape is None or output.shape is None:
38+
raise ValueError("Only static shapes are supported")
4039

4140
input_dtype = inputs[0].dtype
4241

@@ -55,9 +54,12 @@ def define_node(
5554
def in_int16_range(x):
5655
return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)
5756

58-
assert in_int16_range(scale_n_yx)
59-
assert in_int16_range(scale_d_yx)
60-
assert in_int16_range(border_yx)
57+
if not in_int16_range(scale_n_yx):
58+
raise ValueError("scale_n_yx is out of the int16 range")
59+
if not in_int16_range(scale_d_yx):
60+
raise ValueError("scale_d_yx is out of the int16 range")
61+
if not in_int16_range(border_yx):
62+
raise ValueError("border_yx is out of the int16 range")
6163

6264
attr = ts.TosaSerializerAttribute()
6365
attr.ResizeAttribute(

backends/arm/operators/op_upsample_nearest2d.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,8 @@ def define_node(
3636
) -> None:
3737
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
3838

39-
assert (
40-
inputs[0].shape is not None and output.shape is not None
41-
), "Only static shapes are supported"
39+
if inputs[0].shape is None or output.shape is None:
40+
raise ValueError("Only static shapes are supported")
4241

4342
# tosa_shape output is NHWC, take HW
4443
input_size_yx = torch.tensor(
@@ -55,9 +54,12 @@ def define_node(
5554
def in_int16_range(x):
5655
return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)
5756

58-
assert in_int16_range(scale_n_yx)
59-
assert in_int16_range(scale_d_yx)
60-
assert in_int16_range(border_yx)
57+
if not in_int16_range(scale_n_yx):
58+
raise ValueError("scale_n_yx is out of the int16 range")
59+
if not in_int16_range(scale_d_yx):
60+
raise ValueError("scale_d_yx is out of the int16 range")
61+
if not in_int16_range(border_yx):
62+
raise ValueError("border_yx is out of the int16 range")
6163

6264
attr = ts.TosaSerializerAttribute()
6365
attr.ResizeAttribute(

0 commit comments

Comments
 (0)