diff --git a/backends/arm/test/misc/test_dump_tosa.py b/backends/arm/test/misc/test_dump_tosa.py new file mode 100644 index 00000000000..f45a3aa2213 --- /dev/null +++ b/backends/arm/test/misc/test_dump_tosa.py @@ -0,0 +1,69 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import tempfile + +from typing import Tuple + +import torch +from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder + +from executorch.backends.arm.test import common +from executorch.backends.arm.tosa_partitioner import TOSAPartitioner +from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.exir import to_edge +from executorch.exir.capture._config import ExecutorchBackendConfig + +input_t1 = Tuple[torch.Tensor] + + +class Linear(torch.nn.Module): + inputs = { + "randn": (torch.randn(2, 8),), + } + + def __init__(self): + super().__init__() + in_features = 8 + out_features = 16 + self.weight = torch.nn.Parameter(torch.randn(out_features, in_features)) + self.bias = torch.nn.Parameter(torch.randn(out_features)) + + def forward(self, x): + y = torch.matmul(x, self.weight.t()) + return torch.add(y, self.bias) + + +def _file_non_empty(path: str) -> bool: + return os.path.exists(path) and os.path.getsize(path) > 0 + + +@common.parametrize("test_data", Linear.inputs) +def test_MI_dump_delegate_data(test_data: input_t1): + + m = Linear().eval() + ep = torch.export.export(m, test_data, strict=True) + + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP") + partitioner = TOSAPartitioner( + ArmCompileSpecBuilder().tosa_compile_spec(tosa_spec=tosa_spec).build() + ) + edge = to_edge(ep).to_backend(partitioner) + + config = ExecutorchBackendConfig(extract_delegate_segments=True) + exec_pm = edge.to_executorch(config) + + tmp_dir = tempfile.mkdtemp() + out_file = os.path.join(tmp_dir, "delegate_MI.tosa") + prefix, ext = os.path.splitext(out_file) + # Ensure extension is provided, if not then use default + if not ext: + ext = ".tosa" + exec_pm.dump_delegate_data(path=prefix, extension=ext) + print(f"delegate file ⇒ {prefix + ext}") + + assert os.path.exists(prefix + ext), f"File {prefix + ext} not created" + assert os.path.getsize(prefix + ext) > 0, f"File {prefix + ext} is empty" diff --git a/exir/program/_program.py b/exir/program/_program.py index 8ef02f233ac..f42e9760975 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1726,6 +1726,75 @@ def dump_executorch_program( else: print_program(self._emitter_output.program, out=out) + def dump_delegate_data( # noqa: C901 + self, + path: str, + extension: str, + delegate_id: Optional[str] = None, + ) -> None: + """ + Dumps the delegate blob out of backend_delegate_data to . + Must have been created with extract_delegate_segments=True. + If no delegate_id is given and exactly one exists, it will be used. + """ + if not extension.startswith("."): + extension = "." + extension + out_dir = os.path.dirname(path) + if out_dir and not os.path.isdir(out_dir): + os.makedirs(out_dir, exist_ok=True) + + eo = self._emitter_output + blobs = eo.program.backend_delegate_data + if not blobs: + raise RuntimeError("No delegate data was produced for this model") + + mapping = getattr(self, "delegate_map", None) + if not blobs or not mapping: + raise RuntimeError("No delegate segments available in this program.") + + # Create a list of (method, index, name) + entries: List[Tuple[str, int, str]] = [] + for method_name, id_map in mapping.items(): + for idx, info in id_map.items(): + name = info.get("name", str(idx)) + entries.append((method_name, idx, name)) + + # Choose which delegate to dump + if delegate_id is None: + if len(entries) == 1: + method, idx, delegate_id = entries[0] + else: + names = [n for (_m, _i, n) in entries] + raise ValueError(f"Multiple delegete IDs found: {names}.") + + # Find the selected delegate + matches = [(m, i, n) for (m, i, n) in entries if n == delegate_id] + if not matches: + raise ValueError(f"Delegate data for id {delegate_id} not found.") + if len(matches) > 1: + methods = [m for (m, _, _) in matches] + raise ValueError( + f"Delegate ID {delegate_id} ambiguous across methods {methods}." + ) + method, idx, _ = matches[0] + + if len(entries) > 1: + filename = f"{path}_{method}{extension}" + else: + filename = f"{path}{extension}" + + blob = blobs[idx] + if hasattr(blob, "data"): + data = blob.data + elif isinstance(blob, (bytes, bytearray)): + data = blob + else: + # cord data - convert to bytes + data = bytes(blob) + + with open(filename, "wb") as f: + f.write(data) + @property def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]: return self._emitter_output.debug_handle_map diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index d5de78909ce..6b23a7c048f 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -7,6 +8,8 @@ # pyre-unsafe import copy +import os +import tempfile import unittest from typing import Any, Dict @@ -652,6 +655,48 @@ def get_num_nondecomposed_ops(self, ep, partitioner): num_non_decomposed_aten_ops += 1 return num_non_decomposed_aten_ops + def test_dump_delegate_data(self): + manager = to_edge(get_exported_programs(), get_config_methods()).to_executorch() + + delegate_segments = getattr(manager._emitter_output, "delegate_segments", None) + if not delegate_segments: + self.skipTest("No delegate segments present in this configuration.") + + method, segments = next(iter(delegate_segments.items())) + delegate_id = next(iter(segments.keys())) + + with tempfile.TemporaryDirectory() as tmpdir: + file_path = os.path.join(tmpdir, "delegate_dump") + ext = ".bin" + manager.dump_delegate_data(file_path, ext, delegate_id) + output_file = file_path + ext + self.assertTrue( + os.path.isfile(output_file), f"Expected {output_file} to exist." + ) + with open(output_file, "rb") as f: + blob = f.read() + self.assertGreater(len(blob), 0, "Delegate dump file should not be empty.") + + if len(delegate_segments) == 1 and len(segments) == 1: + with tempfile.TemporaryDirectory() as tmpdir: + file_path = os.path.join(tmpdir, "delegate_dump2") + manager.dump_delegate_data(file_path, ext) + self.assertTrue(os.path.isfile(file_path + ext)) + + def test_dump_delegate_data_invalid(self): + manager = to_edge(get_exported_programs(), get_config_methods()).to_executorch() + + with tempfile.TemporaryDirectory() as tmpdir: + file_path = os.path.join(tmpdir, "invalid_delegate") + ext = ".bin" + + with self.assertRaises(RuntimeError): + manager.dump_delegate_data( + file_path, + ext, + delegate_id="not_a_real_delegate", + ) + def _test_model_with_non_decomp_partitioner(self, model: torch.nn.Module): # This is the pre-dispatch export that we will be switching to primarily # in the near future. The input to to_edge_transform_and_lower needs to