Skip to content

Commit cc5b3ed

Browse files
Arm backend: test_pipeline improvements (#8644)
- Add OpNotSupportedPipeline for checking that ops are not delegated properly. - Make use_to_edge_transform_and_lower default to true since this is the recommended API. - Rename TestPassPipeline -> PassPipeline to avoid warnings in pytest log, and make exir_ops optional as its not used then. - Allow to add the first non unique stage to a pipeline w/o suffix (E.g. run_method_and_compare_outputs will rarely be used twice even though it is theoretically possible, so we don't want to refer to it as run_method_and_compare_outputs.0 if not necessary). - Add custom_path option to all pipelines for easily dumping artifacts. - Typing and documentation fixes. Signed-off-by: Adrian Lundell <[email protected]>
1 parent b6bd89d commit cc5b3ed

9 files changed

+175
-56
lines changed

backends/arm/test/misc/test_partition_decomposed_quantized_ops.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ def test_softplus_tosa_BI(test_data: input_t1):
6060
pipeline.pop_stage("check_not.exir")
6161
# check that all ops in exir_op except add are rejected
6262
pipeline.add_stage_after(
63-
"partition", pipeline.tester.check, exir_op[1:], suffix="exir_post_partition"
63+
"to_edge_transform_and_lower",
64+
pipeline.tester.check,
65+
exir_op[1:],
66+
suffix="exir_post_partition",
6467
)
6568
pipeline.run()

backends/arm/test/ops/test_conv2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def test_conv2d_tosa_BI(test_module):
370370
pipeline = TosaPipelineBI[input_t](
371371
test_module, test_module.get_inputs(), aten_op, exir_op
372372
)
373-
pipeline.change_args("run_method_and_compare_outputs.0", qtol=1)
373+
pipeline.change_args("run_method_and_compare_outputs", qtol=1)
374374
pipeline.run()
375375

376376

backends/arm/test/passes/test_cast_int64_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch
99
from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
1010

11-
from executorch.backends.arm.test.tester.test_pipeline import TestPassPipeline
11+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
1212

1313
input_t = Tuple[torch.Tensor] # Input x
1414

@@ -28,7 +28,7 @@ def test_int64_model_tosa_BI():
2828
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
2929
"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
3030
}
31-
pipeline = TestPassPipeline[input_t](
31+
pipeline = PassPipeline[input_t](
3232
module,
3333
module.get_inputs(),
3434
tosa_version="TOSA-0.80+BI",

backends/arm/test/passes/test_fold_qdq_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1010
FoldAndAnnotateQParamsPass,
1111
)
12-
from executorch.backends.arm.test.tester.test_pipeline import TestPassPipeline
12+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
1313

1414

1515
input_t = Tuple[torch.Tensor, torch.Tensor] # Input x, y
@@ -32,7 +32,7 @@ def test_fold_qdq_pass_tosa_BI():
3232
is removed from the representation.
3333
"""
3434
module = SimpleQuantizeModel()
35-
pipeline = TestPassPipeline[input_t](
35+
pipeline = PassPipeline[input_t](
3636
module,
3737
module.get_inputs(),
3838
tosa_version="TOSA-0.80+BI",

backends/arm/test/passes/test_fuse_batchnorm_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch
99
from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass
1010
from executorch.backends.arm.test import common
11-
from executorch.backends.arm.test.tester.test_pipeline import TestPassPipeline
11+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
1212

1313
input_t = Tuple[torch.Tensor] # Input x
1414

@@ -138,7 +138,7 @@ def forward(self, x):
138138
@common.parametrize("module", modules)
139139
def test_fuse_batchnorm_tosa_MI(module):
140140
"""Test various cases where the batchnorm should and shouldn't be fused."""
141-
pipeline = TestPassPipeline[input_t](
141+
pipeline = PassPipeline[input_t](
142142
module,
143143
module.get_inputs(),
144144
tosa_version="TOSA-0.80+MI",

backends/arm/test/passes/test_insert_table_ops_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
FoldAndAnnotateQParamsPass,
1212
)
1313
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
14-
from executorch.backends.arm.test.tester.test_pipeline import TestPassPipeline
14+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
1515

1616
input_t = Tuple[torch.Tensor] # Input x
1717

@@ -27,7 +27,7 @@ def get_inputs(self) -> input_t:
2727

2828
def test_insert_table_tosa_BI():
2929
module = Sigmoid()
30-
pipeline = TestPassPipeline[input_t](
30+
pipeline = PassPipeline[input_t](
3131
module,
3232
module.get_inputs(),
3333
tosa_version="TOSA-0.80+BI",

backends/arm/test/passes/test_meandim_to_averagepool2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
ConvertMeanDimToAveragePoolPass,
1212
)
1313
from executorch.backends.arm.test import common
14-
from executorch.backends.arm.test.tester.test_pipeline import TestPassPipeline
14+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
1515

1616

1717
input_t = Tuple[torch.Tensor, torch.Tensor] # Input x
@@ -65,7 +65,7 @@ def test_meandim_to_avgpool_tosa_BI(module):
6565
Tests the MeanDimToAveragePool2dPass which converts mean.dim to average_pool2d
6666
for the special case where dim is [-1, -2] and keepdim is True.
6767
"""
68-
pipeline = TestPassPipeline[input_t](
68+
pipeline = PassPipeline[input_t](
6969
module,
7070
module.get_inputs(),
7171
tosa_version="TOSA-0.80+BI",

backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
UnsqueezeBeforeRepeatPass,
1111
)
1212
from executorch.backends.arm.test import common
13-
from executorch.backends.arm.test.tester.test_pipeline import TestPassPipeline
13+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
1414

1515
input_t = Tuple[
1616
torch.Tensor, Dict[str, int], list[str]
@@ -47,7 +47,7 @@ def test_unsqueeze_before_repeat_tosa_MI(test_data):
4747
"""
4848
module = Repeat()
4949
data, ops_after_pass, ops_not_after_pass = test_data
50-
pipeline = TestPassPipeline(
50+
pipeline = PassPipeline(
5151
module,
5252
data,
5353
tosa_version="TOSA-0.80+MI",

0 commit comments

Comments
 (0)