@@ -504,6 +504,7 @@ def handle_call_function(self, node: torch.fx.Node):
             assert len(node.kwargs) == 0
             meta_val = node.meta["val"]
             ex_node = Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_sym_op_inputs(node.target, node.args),
                 outputs=[
@@ -517,6 +518,7 @@ def handle_call_function(self, node: torch.fx.Node):
             assert len(node.kwargs) == 0
             meta_val = node.meta["val"]
             ex_node = Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_sym_op_inputs(node.target, node.args),
                 outputs=[
@@ -528,6 +530,7 @@ def handle_call_function(self, node: torch.fx.Node):
             )
         elif isinstance(node.target, torch._ops.OpOverload):
             ex_node = Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_inputs(node.target, node.args, node.kwargs),
                 outputs=self.serialize_outputs(node),
@@ -536,6 +539,7 @@ def handle_call_function(self, node: torch.fx.Node):
             )
         elif isinstance(node.target, torch._ops.HigherOrderOperator):
             ex_node = Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_hoo_inputs(node.args, node.kwargs),
                 outputs=self.serialize_hoo_outputs(node),
@@ -1658,7 +1662,7 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:

     def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
         if target in _SYM_BOOL_OPS or target in _SYM_INT_OPS:
-            name = serialized_node.outputs[0].value.as_name
+            name = serialized_node.name
            args = self.deserialize_sym_op_inputs(serialized_node.inputs)

            fx_node = self.graph.create_node("call_function", target, args, {}, name)
@@ -1671,12 +1675,7 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
             # have names that are consistent with serialized.
             #
             # HOPs don't have schema yet, just check the output lengths and as_tensor attribute
-            name = (
-                serialized_node.outputs[0].as_tensor.name
-                if len(serialized_node.outputs) == 1
-                and hasattr(serialized_node.outputs[0], "as_tensor")
-                else None
-            )
+            name = serialized_node.name
             fx_node = self.graph.create_node(
                 "call_function", target, args, kwargs, name
             )
@@ -1687,16 +1686,12 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
             # For convenience: if this node returns a single tensor, name the
             # newly-created node after it. This ensures that these tensor values
             # have names that are consistent with serialized.
-            name = (
-                serialized_node.outputs[0].as_tensor.name
-                if _is_single_tensor_return(target)
-                else None  # FX will generate a name for us.
-            )
+            name = serialized_node.name
             args, kwargs = self.deserialize_inputs(target, serialized_node)
             fx_node = self.graph.create_node(
                 "call_function", target, args, kwargs, name
             )
             self.deserialize_outputs(serialized_node, fx_node)
         else:
             raise SerializeError(
                 f"Unsupported target type for node {serialized_node}: {target}"
0 commit comments