diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index 1acea3e086a..9522748a959 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -11,7 +11,6 @@ from .convert_bmm_to_matmul import ConvertBmmToMatmul from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d from .convert_square_to_pow import ConvertSquareToPow -from .convert_upsample_bicubic2d import ConvertUpsampleBicubicWithBilinear from .decompose_any import DecomposeAny from .decompose_cdist import DecomposeCDist from .decompose_einsum import DecomposeEinsum @@ -48,7 +47,6 @@ ConvertBmmToMatmul, ConvertConv1dToConv2d, ConvertSquareToPow, - ConvertUpsampleBicubicWithBilinear, DecomposeAny, DecomposeCDist, DecomposeEinsum, diff --git a/backends/qualcomm/_passes/convert_upsample_bicubic2d.py b/backends/qualcomm/_passes/convert_upsample_bicubic2d.py deleted file mode 100644 index 367e9155c77..00000000000 --- a/backends/qualcomm/_passes/convert_upsample_bicubic2d.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass - - -class ConvertUpsampleBicubicWithBilinear(ExportPass): - """ - Qnn does not support bicubic interpolation, so we need to convert it to bilinear. - This pass will convert bicubic interpolation to bilinear interpolation. - """ - - bicubic_op_targets = { - exir_ops.edge.aten.upsample_bicubic2d.vec, - } - upsample_bilinear_op = exir_ops.edge.aten.upsample_bilinear2d.default - - def __init__(self): - super(ConvertUpsampleBicubicWithBilinear, self).__init__() - - def call_operator(self, op, args, kwargs, meta): - if op not in self.bicubic_op_targets: - return super().call_operator(op, args, kwargs, meta) - return super().call_operator(self.upsample_bilinear_op, args[:-1], kwargs, meta) diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index 19c5417f8f8..dad7f042037 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -38,6 +38,8 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.native_group_norm.default, exir_ops.edge.aten.pixel_shuffle.default, exir_ops.edge.aten.pixel_unshuffle.default, + exir_ops.edge.aten.upsample_bicubic2d.default, + exir_ops.edge.aten.upsample_bicubic2d.vec, exir_ops.edge.aten.upsample_bilinear2d.default, exir_ops.edge.aten.upsample_bilinear2d.vec, exir_ops.edge.aten.upsample_nearest2d.default, diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index d324f6144a5..c95abc0f510 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -16,7 +16,6 @@ ConvertBmmToMatmul, ConvertConv1dToConv2d, ConvertSquareToPow, - ConvertUpsampleBicubicWithBilinear, DecomposeAny, DecomposeCDist, DecomposeEinsum, @@ -82,7 +81,6 @@ def get_capture_program_passes(): (AnnotateUnbind, True), (ConvertBmmToMatmul, True), (ConvertConv1dToConv2d, True), - (ConvertUpsampleBicubicWithBilinear, False), (DecomposeAny, True), (ExpandBroadcastTensorShape, False), (FixedLinearKeepDim, True), diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index 70bb705be73..fab51a47105 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -66,7 +66,6 @@ def get_passes_dependency_for_capture_program(): AnnotateUnbind, ConvertBmmToMatmul, ConvertConv1dToConv2d, - ConvertUpsampleBicubicWithBilinear, DecomposeAny, DecomposeLinalgVectorNorm, ExpandBroadcastTensorShape, @@ -86,19 +85,17 @@ def get_passes_dependency_for_capture_program(): AnnotateQuantAttrs: [ RecomposePixelUnshuffle, ConvertBmmToMatmul, - ConvertUpsampleBicubicWithBilinear, RemoveRedundancy, ], AnnotateStack: [RemoveRedundancy], AnnotateUnbind: [RemoveRedundancy], ConvertBmmToMatmul: [RecomposePixelUnshuffle], - ConvertUpsampleBicubicWithBilinear: [RemoveRedundancy], DecomposeAny: [RemoveRedundancy], DecomposeLinalgVectorNorm: [RemoveRedundancy], ExpandBroadcastTensorShape: [FoldQDQ], FixedLinearKeepDim: [FoldQDQ], FoldQDQ: [AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind], - I64toI32: [ConvertUpsampleBicubicWithBilinear, RemoveRedundancy], + I64toI32: [RemoveRedundancy], LayoutTransform: [ AnnotateQuantAttrs, ConvertConv1dToConv2d, diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index 27faa036dd5..fff2a3b4a53 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -66,6 +66,7 @@ op_relu, op_repeat, op_reshape, + op_resize, op_rms_norm, op_rsqrt, op_scalar_tensor, @@ -155,6 +156,7 @@ op_relu, op_repeat, op_reshape, + op_resize, op_rms_norm, op_rsqrt, op_scalar_tensor, diff --git a/backends/qualcomm/builders/op_resize.py b/backends/qualcomm/builders/op_resize.py new file mode 100644 index 00000000000..d9861a6c5bb --- /dev/null +++ b/backends/qualcomm/builders/op_resize.py @@ -0,0 +1,84 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import cast, Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import numpy as np +import torch + +from executorch.backends.qualcomm.utils.constants import QCOM_DATA + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpResize, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class Resize(NodeVisitor): + # Because QNN support ResizeBilinear and ResizeNearestNeighbor, only bicubic need to be handled in resize op + target = ["aten.upsample_bicubic2d.vec"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + input_node = self.get_node(node.args[0]) + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + align_corners = cast(bool, node.args[2]) + transformation_mode = np.uint32(2) if align_corners else np.uint32(1) + # This builder supports only bicubic resize. + interpolation_mode = np.uint32(2) + cubic_coeff = np.float32(-0.75) + + output_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + resize_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpResize.op_name, + ) + resize_op.AddInputTensors([input_tensor_wrapper]) + resize_op.AddOutputTensors([output_tensor_wrapper]) + + resize_op.AddScalarParam( + OpResize.param_exclude_outside, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + {QCOM_DATA: False}, + ) + resize_op.AddScalarParam( + OpResize.param_transformation_mode, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + {QCOM_DATA: transformation_mode}, + ) + + resize_op.AddScalarParam( + OpResize.param_interpolation_mode, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + {QCOM_DATA: interpolation_mode}, + ) + resize_op.AddScalarParam( + OpResize.param_cubic_coeff, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + {QCOM_DATA: cubic_coeff}, + ) + + return resize_op diff --git a/backends/qualcomm/builders/op_upsample_bilinear2d.py b/backends/qualcomm/builders/op_upsample_bilinear2d.py index 10dfe375fe0..ab8ab9b6452 100644 --- a/backends/qualcomm/builders/op_upsample_bilinear2d.py +++ b/backends/qualcomm/builders/op_upsample_bilinear2d.py @@ -45,23 +45,23 @@ def define_node( nodes_to_wrappers, ) - reisze_bilinear_op = PyQnnWrapper.PyQnnOpWrapper( + resize_bilinear_op = PyQnnWrapper.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpResizeBilinear.op_name, ) - reisze_bilinear_op.AddInputTensors([input_tensor_wrapper]) - reisze_bilinear_op.AddOutputTensors([output_tensor_wrapper]) + resize_bilinear_op.AddInputTensors([input_tensor_wrapper]) + resize_bilinear_op.AddOutputTensors([output_tensor_wrapper]) - reisze_bilinear_op.AddScalarParam( + resize_bilinear_op.AddScalarParam( OpResizeBilinear.param_align_corners, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: node.args[2]}, ) - reisze_bilinear_op.AddScalarParam( + resize_bilinear_op.AddScalarParam( OpResizeBilinear.param_half_pixel_centers, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: not node.args[2]}, ) - return reisze_bilinear_op + return resize_bilinear_op diff --git a/backends/qualcomm/builders/op_upsample_nearest2d.py b/backends/qualcomm/builders/op_upsample_nearest2d.py index 4e9c4741ca2..a434880e290 100644 --- a/backends/qualcomm/builders/op_upsample_nearest2d.py +++ b/backends/qualcomm/builders/op_upsample_nearest2d.py @@ -45,23 +45,23 @@ def define_node( nodes_to_wrappers, ) - reisze_nearest_op = PyQnnWrapper.PyQnnOpWrapper( + resize_nearest_op = PyQnnWrapper.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpResizeNearestNeighbor.op_name, ) - reisze_nearest_op.AddInputTensors([input_tensor_wrapper]) - reisze_nearest_op.AddOutputTensors([output_tensor_wrapper]) + resize_nearest_op.AddInputTensors([input_tensor_wrapper]) + resize_nearest_op.AddOutputTensors([output_tensor_wrapper]) # align_corners is guaranteed to be false - reisze_nearest_op.AddScalarParam( + resize_nearest_op.AddScalarParam( OpResizeNearestNeighbor.param_align_corners, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: False}, ) - reisze_nearest_op.AddScalarParam( + resize_nearest_op.AddScalarParam( OpResizeNearestNeighbor.param_half_pixel_centers, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, {QCOM_DATA: True}, ) - return reisze_nearest_op + return resize_nearest_op diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index c13a126f76d..7b545e5ab2d 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -408,6 +408,16 @@ class OpReshape: op_name: str = "Reshape" +@dataclass(init=False, frozen=True) +class OpResize: + op_name: str = "Resize" + param_exclude_outside: str = "exclude_outside" + param_transformation_mode: str = "transformation_mode" + param_interpolation_mode: str = "interpolation_mode" + param_nearest_mode: str = "nearest_mode" + param_cubic_coeff: str = "cubic_coeff" + + @dataclass(init=False, frozen=True) class OpResizeBilinear: op_name: str = "ResizeBilinear" diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py index 6326f4d1210..b427c59ce07 100644 --- a/backends/qualcomm/partition/common_defs.py +++ b/backends/qualcomm/partition/common_defs.py @@ -13,7 +13,6 @@ exir_ops.edge.aten.clone.default, exir_ops.edge.aten.slice_scatter.default, exir_ops.edge.aten.copy.default, - exir_ops.edge.aten.upsample_bicubic2d.vec, exir_ops.edge.quantized_decomposed.embedding_4bit.dtype, ] diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index 5195cf39f33..adfb907408a 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -536,6 +536,13 @@ def annotate_upsample_bilinear2d( annotate_single_in_single_out(node, quantization_config) +@register_annotator([torch.ops.aten.upsample_bicubic2d.vec]) +def annotate_upsample_upsample_bicubic2d( + node: Node, quantization_config: QuantizationConfig +) -> None: + annotate_single_in_single_out(node, quantization_config) + + @register_annotator([torch.ops.aten.upsample_nearest2d.vec]) def annotate_upsample_nearest2d( node: Node, quantization_config: QuantizationConfig diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 025c0bee171..869029a3867 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -1311,6 +1311,23 @@ def forward(self, x): return x6 +class ResizeBicubic(torch.nn.Module): + def __init__(self, size, scale_factor, align_corners): + super().__init__() + self.align_corners = align_corners + self.scale_factor = scale_factor + self.size = size + + def forward(self, x): + return torch.nn.functional.interpolate( + x, + size=self.size, + scale_factor=self.scale_factor, + mode="bicubic", + align_corners=self.align_corners, + ) + + class ResizeBilinear2D(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 081dda7187b..a24bd9302a7 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -612,6 +612,18 @@ def test_qnn_backend_instance_norm_2d(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_interpolate_bicubic(self): + modules = [ + ResizeBicubic([2, 2], None, False), # noqa: F405 + ResizeBicubic(None, [2, 2], False), # noqa: F405 + ResizeBicubic([2, 2], None, True), # noqa: F405 + ResizeBicubic(None, [2, 2], True), # noqa: F405 + ] + sample_input = (torch.randn(1, 4, 2, 2),) + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_interpolate_bilinear_2d(self): module = ResizeBilinear2D() # noqa: F405 sample_input = (torch.randn(2, 3, 4, 5),) @@ -1820,6 +1832,19 @@ def test_qnn_backend_instance_norm_2d(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_interpolate_bicubic(self): + modules = [ + ResizeBicubic([2, 2], None, False), # noqa: F405 + ResizeBicubic(None, [2, 2], False), # noqa: F405 + ResizeBicubic([2, 2], None, True), # noqa: F405 + ResizeBicubic(None, [2, 2], True), # noqa: F405 + ] + sample_input = (torch.randn(1, 4, 2, 2),) + for i, module in enumerate(modules): + with self.subTest(i=i): + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_interpolate_bilinear_2d(self): module = ResizeBilinear2D() # noqa: F405 sample_input = (torch.randn(2, 3, 4, 5),) @@ -3884,6 +3909,40 @@ def test_dino_v2(self): self.assertGreaterEqual(msg["top_1"], 70) self.assertGreaterEqual(msg["top_5"], 85) + def test_dit(self): + if not self.required_envs(): + self.skipTest("missing required envs") + + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/dit.py", + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + ] + if self.host: + cmds.extend(["--host", self.host]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["top_1"], 80) + self.assertGreaterEqual(msg["top_5"], 95) + def test_efficientnet(self): if not self.required_envs([self.image_dataset]): self.skipTest("missing required envs") diff --git a/examples/qualcomm/oss_scripts/dino_v2.py b/examples/qualcomm/oss_scripts/dino_v2.py index 18b5ade8b35..db0981248e9 100644 --- a/examples/qualcomm/oss_scripts/dino_v2.py +++ b/examples/qualcomm/oss_scripts/dino_v2.py @@ -10,12 +10,10 @@ import numpy as np import torch -from executorch.backends.qualcomm._passes import ConvertUpsampleBicubicWithBilinear from executorch.backends.qualcomm._passes.qnn_pass_manager import ( get_capture_program_passes, ) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype -from executorch.backends.qualcomm.utils.constants import QCOM_PASS_ACTIVATE_KEY from executorch.examples.qualcomm.utils import ( build_executorch_binary, @@ -62,7 +60,6 @@ def main(args): pte_filename = "dino_v2" instance = get_instance() passes_job = get_capture_program_passes() - passes_job[ConvertUpsampleBicubicWithBilinear][QCOM_PASS_ACTIVATE_KEY] = True build_executorch_binary( instance, sample_input, diff --git a/examples/qualcomm/oss_scripts/dit.py b/examples/qualcomm/oss_scripts/dit.py new file mode 100644 index 00000000000..1dc4cebee75 --- /dev/null +++ b/examples/qualcomm/oss_scripts/dit.py @@ -0,0 +1,161 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +import logging +import os +from multiprocessing.connection import Client + +import numpy as np +import torch + +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype + +from executorch.examples.qualcomm.utils import ( + build_executorch_binary, + make_output_dir, + make_quantizer, + parse_skip_delegation_node, + setup_common_args_and_variables, + SimpleADB, + topk_accuracy, +) + +from torchao.quantization.pt2e import HistogramObserver +from transformers import AutoImageProcessor, AutoModelForImageClassification + + +def get_rvlcdip_dataset(data_size): + from datasets import load_dataset + + dataset = load_dataset("nielsr/rvl_cdip_10_examples_per_class", split="train") + processor = AutoImageProcessor.from_pretrained( + "microsoft/dit-base-finetuned-rvlcdip" + ) + + # prepare input data + inputs, targets, input_list = [], [], "" + for index, data in enumerate(dataset): + if index >= data_size: + break + feature, target = ( + processor(images=data["image"].convert("RGB"), return_tensors="pt"), + data["label"], + ) + inputs.append((feature["pixel_values"],)) + targets.append(torch.tensor(target)) + input_list += f"input_{index}_0.raw\n" + + return inputs, targets, input_list + + +def main(args): + skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + if not args.compile_only and args.device is None: + raise RuntimeError( + "device serial is required if not compile only. " + "Please specify a device serial by -s/--device argument." + ) + data_num = 160 + if args.ci: + inputs = [(torch.rand(1, 3, 224, 224),)] + logging.warning( + "This option is for CI to verify the export flow. It uses random input and will result in poor accuracy." + ) + else: + inputs, targets, input_list = get_rvlcdip_dataset(data_num) + + module = ( + AutoModelForImageClassification.from_pretrained( + "microsoft/dit-base-finetuned-rvlcdip" + ) + .eval() + .to("cpu") + ) + + pte_filename = "dit_qnn_q8" + # Use HistogramObserver to get better performance + quantizer = make_quantizer( + quant_dtype=QuantDtype.use_8a8w, act_observer=HistogramObserver + ) + + build_executorch_binary( + module.eval(), + inputs[0], + args.model, + f"{args.artifact}/{pte_filename}", + inputs, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + quant_dtype=QuantDtype.use_8a8w, + custom_quantizer=quantizer, + shared_buffer=args.shared_buffer, + ) + + if args.compile_only: + return + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=f"{args.build_folder}", + pte_path=f"{args.artifact}/{pte_filename}.pte", + workspace=f"/data/local/tmp/executorch/{pte_filename}", + device_id=args.device, + host_id=args.host, + soc_model=args.model, + shared_buffer=args.shared_buffer, + ) + adb.push(inputs=inputs, input_list=input_list) + adb.execute() + + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + + adb.pull(output_path=args.artifact) + + # top-k analysis + predictions = [] + for i in range(data_num): + predictions.append( + np.fromfile( + os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32 + ) + ) + k_val = [1, 5] + topk = [topk_accuracy(predictions, targets, k).item() for k in k_val] + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({f"top_{k}": topk[i] for i, k in enumerate(k_val)})) + else: + for i, k in enumerate(k_val): + print(f"top_{k}->{topk[i]}%") + + +if __name__ == "__main__": + parser = setup_common_args_and_variables() + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. " "Default ./dit", + default="./dit", + type=str, + ) + + args = parser.parse_args() + try: + main(args) + except Exception as e: + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"Error": str(e)})) + else: + raise Exception(e) diff --git a/examples/qualcomm/oss_scripts/efficientSAM/efficientSAM.py b/examples/qualcomm/oss_scripts/efficientSAM/efficientSAM.py index ea65917dcd9..8b7c1dc3dd3 100644 --- a/examples/qualcomm/oss_scripts/efficientSAM/efficientSAM.py +++ b/examples/qualcomm/oss_scripts/efficientSAM/efficientSAM.py @@ -13,10 +13,7 @@ import numpy as np import torch -from executorch.backends.qualcomm._passes import ( - ConvertUpsampleBicubicWithBilinear, - ExpandBroadcastTensorShape, -) +from executorch.backends.qualcomm._passes import ExpandBroadcastTensorShape from executorch.backends.qualcomm._passes.qnn_pass_manager import ( get_capture_program_passes, ) @@ -246,7 +243,6 @@ def main(args): # lower to QNN passes_job = get_capture_program_passes() - passes_job[ConvertUpsampleBicubicWithBilinear][QCOM_PASS_ACTIVATE_KEY] = True passes_job[ExpandBroadcastTensorShape][QCOM_PASS_ACTIVATE_KEY] = True build_executorch_binary( model,