-
Notifications
You must be signed in to change notification settings - Fork 615
Arm backend: Add dump_delegate_data function #12334
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
wwwind
wants to merge
3
commits into
pytorch:main
Choose a base branch
from
wwwind:dump_delegate_data
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+183
−0
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Copyright 2025 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import os | ||
import tempfile | ||
|
||
from typing import Tuple | ||
|
||
import torch | ||
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder | ||
|
||
from executorch.backends.arm.test import common | ||
from executorch.backends.arm.tosa_partitioner import TOSAPartitioner | ||
from executorch.backends.arm.tosa_specification import TosaSpecification | ||
from executorch.exir import to_edge | ||
from executorch.exir.capture._config import ExecutorchBackendConfig | ||
|
||
input_t1 = Tuple[torch.Tensor] | ||
|
||
|
||
class Linear(torch.nn.Module): | ||
inputs = { | ||
"randn": (torch.randn(2, 8),), | ||
} | ||
|
||
def __init__(self): | ||
super().__init__() | ||
in_features = 8 | ||
out_features = 16 | ||
self.weight = torch.nn.Parameter(torch.randn(out_features, in_features)) | ||
self.bias = torch.nn.Parameter(torch.randn(out_features)) | ||
|
||
def forward(self, x): | ||
y = torch.matmul(x, self.weight.t()) | ||
return torch.add(y, self.bias) | ||
|
||
|
||
def _file_non_empty(path: str) -> bool: | ||
return os.path.exists(path) and os.path.getsize(path) > 0 | ||
|
||
|
||
@common.parametrize("test_data", Linear.inputs) | ||
def test_MI_dump_delegate_data(test_data: input_t1): | ||
|
||
m = Linear().eval() | ||
ep = torch.export.export(m, test_data, strict=True) | ||
|
||
tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP") | ||
partitioner = TOSAPartitioner( | ||
ArmCompileSpecBuilder().tosa_compile_spec(tosa_spec=tosa_spec).build() | ||
) | ||
edge = to_edge(ep).to_backend(partitioner) | ||
|
||
config = ExecutorchBackendConfig(extract_delegate_segments=True) | ||
exec_pm = edge.to_executorch(config) | ||
|
||
tmp_dir = tempfile.mkdtemp() | ||
out_file = os.path.join(tmp_dir, "delegate_MI.tosa") | ||
prefix, ext = os.path.splitext(out_file) | ||
# Ensure extension is provided, if not then use default | ||
if not ext: | ||
ext = ".tosa" | ||
exec_pm.dump_delegate_data(path=prefix, extension=ext) | ||
print(f"delegate file ⇒ {prefix + ext}") | ||
|
||
assert os.path.exists(prefix + ext), f"File {prefix + ext} not created" | ||
assert os.path.getsize(prefix + ext) > 0, f"File {prefix + ext} is empty" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1726,6 +1726,75 @@ def dump_executorch_program( | |
else: | ||
print_program(self._emitter_output.program, out=out) | ||
|
||
def dump_delegate_data( # noqa: C901 | ||
self, | ||
path: str, | ||
extension: str, | ||
delegate_id: Optional[str] = None, | ||
) -> None: | ||
""" | ||
Dumps the delegate blob out of backend_delegate_data to <path><extension>. | ||
Must have been created with extract_delegate_segments=True. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you not find the blob if its embedded in the flatbuffer section? |
||
If no delegate_id is given and exactly one exists, it will be used. | ||
""" | ||
if not extension.startswith("."): | ||
extension = "." + extension | ||
out_dir = os.path.dirname(path) | ||
if out_dir and not os.path.isdir(out_dir): | ||
os.makedirs(out_dir, exist_ok=True) | ||
|
||
eo = self._emitter_output | ||
blobs = eo.program.backend_delegate_data | ||
if not blobs: | ||
raise RuntimeError("No delegate data was produced for this model") | ||
|
||
mapping = getattr(self, "delegate_map", None) | ||
if not blobs or not mapping: | ||
raise RuntimeError("No delegate segments available in this program.") | ||
|
||
# Create a list of (method, index, name) | ||
entries: List[Tuple[str, int, str]] = [] | ||
for method_name, id_map in mapping.items(): | ||
for idx, info in id_map.items(): | ||
name = info.get("name", str(idx)) | ||
entries.append((method_name, idx, name)) | ||
|
||
# Choose which delegate to dump | ||
if delegate_id is None: | ||
if len(entries) == 1: | ||
method, idx, delegate_id = entries[0] | ||
else: | ||
names = [n for (_m, _i, n) in entries] | ||
raise ValueError(f"Multiple delegete IDs found: {names}.") | ||
|
||
# Find the selected delegate | ||
matches = [(m, i, n) for (m, i, n) in entries if n == delegate_id] | ||
if not matches: | ||
raise ValueError(f"Delegate data for id {delegate_id} not found.") | ||
if len(matches) > 1: | ||
methods = [m for (m, _, _) in matches] | ||
raise ValueError( | ||
f"Delegate ID {delegate_id} ambiguous across methods {methods}." | ||
) | ||
method, idx, _ = matches[0] | ||
|
||
if len(entries) > 1: | ||
filename = f"{path}_{method}{extension}" | ||
else: | ||
filename = f"{path}{extension}" | ||
|
||
blob = blobs[idx] | ||
if hasattr(blob, "data"): | ||
data = blob.data | ||
elif isinstance(blob, (bytes, bytearray)): | ||
data = blob | ||
else: | ||
# cord data - convert to bytes | ||
data = bytes(blob) | ||
|
||
with open(filename, "wb") as f: | ||
f.write(data) | ||
|
||
@property | ||
def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]: | ||
return self._emitter_output.debug_handle_map | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you take a TextIO instead? Like Dump_et_program above
edit: Or just any stream really since the data isnt text