
Commit 038e039

make operator name consistent before and after serde
Differential Revision: [D78380855](https://our.internmc.facebook.com/intern/diff/D78380855/)
ghstack-source-id: 296452087
Pull Request resolved: #12531
1 parent 1666a4b commit 038e039


4 files changed (+70 -14 lines)


exir/serde/export_serialize.py

Lines changed: 25 additions & 12 deletions
@@ -504,6 +504,7 @@ def handle_call_function(self, node: torch.fx.Node):
             assert len(node.kwargs) == 0
             meta_val = node.meta["val"]
             ex_node = Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_sym_op_inputs(node.target, node.args),
                 outputs=[
@@ -517,6 +518,7 @@ def handle_call_function(self, node: torch.fx.Node):
             assert len(node.kwargs) == 0
             meta_val = node.meta["val"]
             ex_node = Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_sym_op_inputs(node.target, node.args),
                 outputs=[
@@ -528,6 +530,7 @@ def handle_call_function(self, node: torch.fx.Node):
             )
         elif isinstance(node.target, torch._ops.OpOverload):
             ex_node = Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_inputs(node.target, node.args, node.kwargs),
                 outputs=self.serialize_outputs(node),
@@ -536,6 +539,7 @@ def handle_call_function(self, node: torch.fx.Node):
             )
         elif isinstance(node.target, torch._ops.HigherOrderOperator):
             ex_node = Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_hoo_inputs(node.args, node.kwargs),
                 outputs=self.serialize_hoo_outputs(node),
@@ -1658,7 +1662,7 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:

     def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
         if target in _SYM_BOOL_OPS or target in _SYM_INT_OPS:
-            name = serialized_node.outputs[0].value.as_name
+            name = serialized_node.name
             args = self.deserialize_sym_op_inputs(serialized_node.inputs)

             fx_node = self.graph.create_node("call_function", target, args, {}, name)
@@ -1671,12 +1675,7 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
             # have names that are consistent with serialized.
             #
             # HOPs don't have schema yet, just check the output lengths and as_tensor attribute
-            name = (
-                serialized_node.outputs[0].as_tensor.name
-                if len(serialized_node.outputs) == 1
-                and hasattr(serialized_node.outputs[0], "as_tensor")
-                else None
-            )
+            name = serialized_node.name
             fx_node = self.graph.create_node(
                 "call_function", target, args, kwargs, name
             )
@@ -1687,16 +1686,30 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
             # For convenience: if this node returns a single tensor, name the
             # newly-created node after it. This ensures that these tensor values
             # have names that are consistent with serialized.
-            name = (
-                serialized_node.outputs[0].as_tensor.name
-                if _is_single_tensor_return(target)
-                else None  # FX will generate a name for us.
-            )
+
+            print(target)
+            print(target.__name__)
+            print(target.name)
+
+            name = serialized_node.name
+
+            print(name)
+
+            if name == "split_tensor":
+                print(serialized_node)
+                print(serialized_node.inputs)
+                print(serialized_node.outputs)
+
             args, kwargs = self.deserialize_inputs(target, serialized_node)
             fx_node = self.graph.create_node(
                 "call_function", target, args, kwargs, name
             )
             self.deserialize_outputs(serialized_node, fx_node)
+
+            if name == "split_tensor":
+                print(fx_node)
+                print(fx_node.args)
+                print(fx_node.kwargs)
         else:
             raise SerializeError(
                 f"Unsupported target type for node {serialized_node}: {target}"

exir/serde/schema.py

Lines changed: 1 addition & 0 deletions
@@ -195,6 +195,7 @@ class NamedArgument:

 @dataclass
 class Node:
+    name: str
     target: str
     inputs: List[NamedArgument]
     outputs: List[Argument]
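With the new field, the serialized record looks roughly like the sketch below, reconstructed from the hunk above; fields the hunk does not show (for example metadata) are omitted, and the argument types are referenced by name only:

from dataclasses import dataclass
from typing import List

@dataclass
class Node:
    name: str                      # FX node name, added by this commit
    target: str                    # serialized operator name
    inputs: List["NamedArgument"]  # existing schema types defined elsewhere in schema.py
    outputs: List["Argument"]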

exir/serde/serialize.py

Lines changed: 3 additions & 0 deletions
@@ -89,6 +89,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None:

         if node.target is memory.alloc:
             ex_node = schema.Node(
+                name=node.name,
                 target="memory.alloc",
                 inputs=self.serialize_alloc_inputs(node.args),
                 outputs=self.serialize_arbitrary_outputs(node),
@@ -99,6 +100,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None:
         elif isinstance(node.target, EdgeOpOverload):
             assert node.target._op is not None
             ex_node = schema.Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 # pyre-ignore Undefined attribute [16]: Item `typing.Callable` of
                 # `typing.Union[typing.Callable[..., typing.Any], str]` has no attribute `_op`.
@@ -111,6 +113,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None:
             return
         elif node.target is delegate.executorch_call_delegate:
             ex_node = schema.Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_call_delegate_inputs(node.args),
                 outputs=self.serialize_arbitrary_outputs(node),

exir/tests/test_serde.py

Lines changed: 41 additions & 2 deletions
@@ -42,6 +42,7 @@ def check_ep(
         ep1: TorchExportedProgram,
         ep2: TorchExportedProgram,
         inputs: Tuple[exir.Value, ...],
+        compare_closeness: bool = False,
     ) -> None:
         """
         Checks if two graphs are equivalent
@@ -55,15 +56,53 @@ def check_ep(
         for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs, strict=True):
             self.assertTrue(torch.allclose(orig, loaded))

+        # print node names in ep1 and ep2 seperately
+        print("---------------------")
+        print("ep1")
+        print(len(ep1.graph.nodes))
+        for node in ep1.graph.nodes:
+            print(node.name)
+        print("**************")
+        print("ep2")
+        print(len(ep2.graph.nodes))
+        for node in ep2.graph.nodes:
+            print(node.name)
+        print("____________________")
+
+        if compare_closeness:
+            self.assertEqual(len(ep1.graph.nodes), len(ep2.graph.nodes))
+            for node_a, node_b in zip(ep1.graph.nodes, ep2.graph.nodes):
+                self.assertEqual(node_a.target, node_b.target)
+                self.assertEqual(node_a.name, node_b.name)
+                self.assertEqual(node_a.type, node_b.type)
+                self.assertEqual(node_a.op, node_b.op)
+                if node_a.op != "call_function":
+                    continue
+
+                self.assertEqual(
+                    node_a.meta.get("debug_handle"), node_b.meta.get("debug_handle")
+                )
+                from_node_a = node_a.meta.get("from_node")
+                from_node_b = node_b.meta.get("from_node")
+
+                if from_node_a is None:
+                    self.assertIsNone(from_node_b)
+                else:
+                    self.assertIsNotNone(from_node_b)
+                    for node_source_a, node_source_b in zip(from_node_a, from_node_b):
+                        self.assertEqual(
+                            node_source_a.to_dict(), node_source_b.to_dict()
+                        )
+
     # pyre-ignore
     def check_serde(self, m, inputs, check_executorch=True) -> None:
         aten = export(m, inputs, strict=True)
         aten_new = deserialize(serialize(aten))
-        self.check_ep(aten, aten_new, inputs)
+        self.check_ep(aten, aten_new, inputs, compare_closeness=True)

         edge = to_edge(aten)
         edge_new = deserialize(serialize(edge.exported_program()))
-        self.check_ep(edge.exported_program(), edge_new, inputs)
+        self.check_ep(edge.exported_program(), edge_new, inputs, compare_closeness=True)

         buffer = io.BytesIO()
         exir.save(edge.exported_program(), buffer)
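Taken together, the change means a serde round trip should reproduce node names exactly, which is what the new compare_closeness path asserts. A minimal sketch of that check outside the test harness; the import path for serialize/deserialize and the toy module are assumptions, not taken from the diff:

import torch
from torch.export import export

from executorch.exir.serde.serialize import deserialize, serialize  # assumed import path


class Mul(torch.nn.Module):  # hypothetical example module
    def forward(self, x, y):
        return x * y


ep = export(Mul(), (torch.randn(2), torch.randn(2)), strict=True)
ep_loaded = deserialize(serialize(ep))

# With `name` carried on the serialized Node, the two graphs should line up node by node.
for a, b in zip(ep.graph.nodes, ep_loaded.graph.nodes):
    assert a.op == b.op and a.name == b.name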
