Skip to content

make operator name consistent before and after serde #12531

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: gh/gasoonjia/24/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions exir/serde/export_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ def handle_call_function(self, node: torch.fx.Node):
assert len(node.kwargs) == 0
meta_val = node.meta["val"]
ex_node = Node(
name=node.name,
target=self.serialize_operator(node.target),
inputs=self.serialize_sym_op_inputs(node.target, node.args),
outputs=[
Expand All @@ -517,6 +518,7 @@ def handle_call_function(self, node: torch.fx.Node):
assert len(node.kwargs) == 0
meta_val = node.meta["val"]
ex_node = Node(
name=node.name,
target=self.serialize_operator(node.target),
inputs=self.serialize_sym_op_inputs(node.target, node.args),
outputs=[
Expand All @@ -528,6 +530,7 @@ def handle_call_function(self, node: torch.fx.Node):
)
elif isinstance(node.target, torch._ops.OpOverload):
ex_node = Node(
name=node.name,
target=self.serialize_operator(node.target),
inputs=self.serialize_inputs(node.target, node.args, node.kwargs),
outputs=self.serialize_outputs(node),
Expand All @@ -536,6 +539,7 @@ def handle_call_function(self, node: torch.fx.Node):
)
elif isinstance(node.target, torch._ops.HigherOrderOperator):
ex_node = Node(
name=node.name,
target=self.serialize_operator(node.target),
inputs=self.serialize_hoo_inputs(node.args, node.kwargs),
outputs=self.serialize_hoo_outputs(node),
Expand Down Expand Up @@ -1658,7 +1662,7 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:

def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
if target in _SYM_BOOL_OPS or target in _SYM_INT_OPS:
name = serialized_node.outputs[0].value.as_name
name = serialized_node.name
args = self.deserialize_sym_op_inputs(serialized_node.inputs)

fx_node = self.graph.create_node("call_function", target, args, {}, name)
Expand All @@ -1671,12 +1675,7 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
# have names that are consistent with serialized.
#
# HOPs don't have schema yet, just check the output lengths and as_tensor attribute
name = (
serialized_node.outputs[0].as_tensor.name
if len(serialized_node.outputs) == 1
and hasattr(serialized_node.outputs[0], "as_tensor")
else None
)
name = serialized_node.name
fx_node = self.graph.create_node(
"call_function", target, args, kwargs, name
)
Expand All @@ -1687,11 +1686,9 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
# For convenience: if this node returns a single tensor, name the
# newly-created node after it. This ensures that these tensor values
# have names that are consistent with serialized.
name = (
serialized_node.outputs[0].as_tensor.name
if _is_single_tensor_return(target)
else None # FX will generate a name for us.
)

name = serialized_node.name

args, kwargs = self.deserialize_inputs(target, serialized_node)
fx_node = self.graph.create_node(
"call_function", target, args, kwargs, name
Expand Down
1 change: 1 addition & 0 deletions exir/serde/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ class NamedArgument:

@dataclass
class Node:
name: str
target: str
inputs: List[NamedArgument]
outputs: List[Argument]
Expand Down
3 changes: 3 additions & 0 deletions exir/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None:

if node.target is memory.alloc:
ex_node = schema.Node(
name=node.name,
target="memory.alloc",
inputs=self.serialize_alloc_inputs(node.args),
outputs=self.serialize_arbitrary_outputs(node),
Expand All @@ -99,6 +100,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None:
elif isinstance(node.target, EdgeOpOverload):
assert node.target._op is not None
ex_node = schema.Node(
name=node.name,
target=self.serialize_operator(node.target),
# pyre-ignore Undefined attribute [16]: Item `typing.Callable` of
# `typing.Union[typing.Callable[..., typing.Any], str]` has no attribute `_op`.
Expand All @@ -111,6 +113,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None:
return
elif node.target is delegate.executorch_call_delegate:
ex_node = schema.Node(
name=node.name,
target=self.serialize_operator(node.target),
inputs=self.serialize_call_delegate_inputs(node.args),
outputs=self.serialize_arbitrary_outputs(node),
Expand Down
30 changes: 28 additions & 2 deletions exir/tests/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def check_ep(
ep1: TorchExportedProgram,
ep2: TorchExportedProgram,
inputs: Tuple[exir.Value, ...],
compare_closeness: bool = False,
) -> None:
"""
Checks if two graphs are equivalent
Expand All @@ -55,15 +56,40 @@ def check_ep(
for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs, strict=True):
self.assertTrue(torch.allclose(orig, loaded))

if compare_closeness:
self.assertEqual(len(ep1.graph.nodes), len(ep2.graph.nodes))
for node_a, node_b in zip(ep1.graph.nodes, ep2.graph.nodes):
self.assertEqual(node_a.target, node_b.target)
self.assertEqual(node_a.name, node_b.name)
self.assertEqual(node_a.type, node_b.type)
self.assertEqual(node_a.op, node_b.op)
if node_a.op != "call_function":
continue

self.assertEqual(
node_a.meta.get("debug_handle"), node_b.meta.get("debug_handle")
)
from_node_a = node_a.meta.get("from_node")
from_node_b = node_b.meta.get("from_node")

if from_node_a is None:
self.assertIsNone(from_node_b)
else:
self.assertIsNotNone(from_node_b)
for node_source_a, node_source_b in zip(from_node_a, from_node_b):
self.assertEqual(
node_source_a.to_dict(), node_source_b.to_dict()
)

# pyre-ignore
def check_serde(self, m, inputs, check_executorch=True) -> None:
aten = export(m, inputs, strict=True)
aten_new = deserialize(serialize(aten))
self.check_ep(aten, aten_new, inputs)
self.check_ep(aten, aten_new, inputs, compare_closeness=True)

edge = to_edge(aten)
edge_new = deserialize(serialize(edge.exported_program()))
self.check_ep(edge.exported_program(), edge_new, inputs)
self.check_ep(edge.exported_program(), edge_new, inputs, compare_closeness=True)

buffer = io.BytesIO()
exir.save(edge.exported_program(), buffer)
Expand Down
Loading