@@ -707,6 +707,117 @@ def get_permutation(self, permute_node: torch.fx.Node) -> list[int]:
707
707
return cast (list [int ], permute_node .kwargs ["dim" ])
708
708
709
709
710
@register_cadence_pass(CadencePassAttribute(opt_level=2))
class RemoveSqueezeUnsqueezeAroundElementwiseOps(ExportPass):
    """
    Looks for subgraphs of the form:
        squeeze -> [op] -> unsqueeze
    and removes the squeeze and unsqueeze nodes by reshaping the intermediate
    ops. Only handles a simple chain of ops as intermediate for now.

    The pass works on view ops instead of squeeze and unsqueeze directly, thus
    it should be run after the squeeze/unsqueeze->view lowering.
    """

    # Ops allowed between the squeeze and the unsqueeze. The quantize /
    # dequantize ops are rank-agnostic and need no update; slice_copy needs
    # its `dim` argument remapped when the rank changes (handled in call()).
    intermediate_ops: set[EdgeOpOverload] = {
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        exir_ops.edge.cadence.quantize_per_tensor.default,
        exir_ops.edge.cadence.dequantize_per_tensor.default,
        # Ops that require special handling:
        exir_ops.edge.aten.slice_copy.Tensor,
    }

    def find_unsqueeze_dim(self, view_node: Node) -> Optional[int]:
        """
        Return the unsqueeze dim if the given view_copy op unsqueezes the input
        tensor, if not return None.
        """
        input_node = cast(Node, get_arg(view_node, 0, "input"))
        input_shape = input_node.meta["val"].shape
        output_shape = view_node.meta["val"].shape
        # An unsqueeze adds exactly one dimension.
        if len(output_shape) != len(input_shape) + 1:
            return None
        # Find the position where a size-1 dim was inserted.
        for dim in range(len(output_shape)):
            if output_shape == input_shape[:dim] + (1,) + input_shape[dim:]:
                return dim
        return None

    def find_ancestor_squeeze(self, node: Node, squeeze_dim: int) -> Optional[Node]:
        """
        Traverse up from the given node until finding a squeeze node with the
        given squeeze_dim. If no such node is found, return None.
        """
        while True:
            # Only handle simple chains for now: every node on the chain must
            # feed exactly one user.
            if len(node.users) != 1:
                return None
            if node.target in self.intermediate_ops:
                # Keep walking up through allowed intermediate ops.
                node = cast(Node, get_arg(node, 0, "input"))
            elif node.target == exir_ops.edge.aten.view_copy.default:
                input_node = cast(Node, get_arg(node, 0, "input"))
                input_shape = input_node.meta["val"].shape
                output_shape = node.meta["val"].shape
                # Check if the node is a squeeze op removing exactly the
                # size-1 dim at `squeeze_dim`.
                if (
                    len(input_shape) != len(output_shape) + 1
                    or input_shape
                    != output_shape[:squeeze_dim] + (1,) + output_shape[squeeze_dim:]
                ):
                    return None
                return node
            else:
                # Any other op ends the search.
                return None

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Remove matched squeeze/unsqueeze view pairs and remap the `dim`
        argument of any slice_copy ops on the chain between them.
        """
        changed = False

        # Traverse the graph looking for unsqueeze-like view ops.
        for node in graph_module.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.view_copy.default
        ):
            unsqueeze_dim = self.find_unsqueeze_dim(node)
            if unsqueeze_dim is None:
                continue

            input_node = cast(Node, get_arg(node, 0, "input"))
            squeeze_node = self.find_ancestor_squeeze(input_node, unsqueeze_dim)
            if squeeze_node is None:
                continue

            # Chain is found. Remove the view ops and update the intermediate
            # ops traversing the chain.
            assert len(squeeze_node.users) == 1
            node = next(iter(squeeze_node.users))

            # Skip first view_copy: route its users to the pre-squeeze tensor.
            squeeze_node.replace_all_uses_with(
                cast(Node, get_arg(squeeze_node, 0, "input"))
            )

            # Go down the chain and update the intermediate ops if needed.
            while node.target != exir_ops.edge.aten.view_copy.default:
                if node.target == exir_ops.edge.aten.slice_copy.Tensor:
                    slice_dim = cast(int, get_arg(node, 1, "dim", default=0))
                    # Normalize a negative dim against the current (squeezed)
                    # rank before remapping.
                    if slice_dim < 0:
                        slice_dim += len(node.meta["val"].shape)
                    # Dims at or after the restored squeeze dim shift right by
                    # one once the op operates on the higher-rank tensor.
                    if slice_dim >= unsqueeze_dim:
                        slice_dim += 1
                    # Always write back the normalized, non-negative dim. The
                    # tensor rank grows by one after the views are removed, so
                    # a negative dim left in place would resolve to the wrong
                    # axis (it is interpreted relative to the new, larger
                    # rank).
                    set_arg(node, 1, "dim", slice_dim)
                assert len(node.users) == 1
                node = next(iter(node.users))

            # Skip final view_copy: route its users to its (now higher-rank)
            # input.
            node.replace_all_uses_with(cast(Node, get_arg(node, 0, "input")))

            changed = True

        if changed:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()

        return PassResult(graph_module, changed)
821
@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
711
822
class RemoveBranchedQuantDequant (ExportPass ):
712
823
"""
0 commit comments