@@ -595,33 +595,78 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
         self.assertEqual(counter, 1)
 
     def test_compile_fix_broken_ops(self) -> None:
-        # When pass an input of more than 4 dimensions to Linear
-        # aten._unsafe_view is used under the hood
-        x = torch.randn([2, 3, 4, 5])
-        model: torch.nn.Linear = torch.nn.Linear(5, 5)
-
-        class Foo(torch.nn.Module):
-            def __init__(self):
+        class ExportableLoop(nn.Module):
+            def __init__(self, hidden_size, out_channels):
                 super().__init__()
-                self.model = model
-
-            def forward(self, inp: torch.Tensor) -> torch.Tensor:
-                return self.model(inp)
-
-        f = Foo()
+                self.hidden_size = hidden_size
+                self.B = nn.Parameter(torch.randn(hidden_size, 1))  # (H, in_channels)
+                self.C = nn.Parameter(
+                    torch.randn(out_channels, hidden_size)
+                )  # (C_out, H)
+                A = torch.randn(2, hidden_size)
+                self.A_real = nn.Parameter(A[0].clone())
+                self.A_imag = nn.Parameter(A[1].clone())
+
+            def update_state(self, h, x_t):
+                # h: [B, 2, H], x_t: [B, H]
+                hr, hi = h[:, 0, :], h[:, 1, :]  # [B, H]
+                hrn = hr * self.A_real - hi * self.A_imag + x_t  # [B, H]
+                hin = hi * self.A_real + hr * self.A_imag  # [B, H]
+                hn = torch.stack([hrn, hin], dim=1)  # [B, 2, H]
+                return hn, hrn
+
+            def forward(self, u):
+                # u: [B, 1, T]
+                x = torch.matmul(self.B, u)  # (B, H, T)
+                B, H, T = x.shape
+
+                h = torch.zeros(B, 2, H, device=x.device, dtype=x.dtype)  # [B, 2, H]
+                h_accum = torch.zeros(
+                    B, H, T, device=x.device, dtype=x.dtype
+                )  # [B, H, T]
+                i = torch.tensor(0, device=x.device, dtype=torch.int64)
+                one = torch.tensor(1, device=x.device, dtype=torch.int64)
+
+                def cond(i, h, h_accum):
+                    return i < T
+
+                def body(i, h, h_accum):
+                    x_t = x.index_select(-1, i.unsqueeze(0)).squeeze(
+                        -1
+                    )  # ✅ safe for export
+                    h, hr = self.update_state(h, x_t)  # h: [B, 2, H], hr: [B, H]
+                    h_accum = h_accum.index_copy(
+                        -1, i.unsqueeze(0), hr.unsqueeze(-1)
+                    )  # [B, H, T]
+                    i_next = i + one
+                    return i_next, h, h_accum
+
+                _, h, h_accum = torch._higher_order_ops.while_loop(
+                    cond, body, (i, h, h_accum)
+                )
+                y = torch.matmul(self.C, h_accum).transpose(0, 1)  # (B, C_out, T)
+                return y
 
-        # ReplaceBrokenOpsWithFunctionalOpsPass is used in to_edge()
+        # Instantiate and export
+        model = ExportableLoop(hidden_size=128, out_channels=10)
+        inp = torch.randn(1, 1, 32)  # (B, in_channels=1, T=32)
+        ep = export(model, (inp,))
         prog = to_edge(
-            export(f, (x,), strict=True),
+            ep,
             compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
         )
         gm = prog.exported_program().graph_module
         count_after = 0
         for node in gm.graph.nodes:
-            if node.target == torch.ops.aten._unsafe_view.default:
+            if (
+                node.target == torch.ops.aten.squeeze.dims
+                or node.target == torch.ops.aten.select.int
+            ):
                 count_after += 1
         self.assertEqual(count_after, 0)
-        self.assertTrue(torch.allclose(prog.exported_program().module()(x), f(x)))
+        self.assertTrue(
+            torch.allclose(prog.exported_program().module()(inp), model(inp))
+        )
 
     def test_convert_symb_ops(self) -> None:
         class Foo(torch.nn.Module):
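For reference, a minimal sketch of the torch._higher_order_ops.while_loop contract the new test relies on (not part of this commit; assumes a PyTorch build that exposes the op). cond and body receive the carried values positionally; cond returns a scalar boolean tensor; body returns a tuple whose structure, dtypes, and shapes match the carried inputs:

import torch

def cond(i, acc):
    # Keep looping while this scalar bool tensor is True.
    return i < 5

def body(i, acc):
    # Return a tuple matching the carried inputs (same structure and dtypes).
    return i + 1, acc + i

i0 = torch.tensor(0, dtype=torch.int64)
acc0 = torch.tensor(0, dtype=torch.int64)
_, total = torch._higher_order_ops.while_loop(cond, body, (i0, acc0))
print(total)  # tensor(10) == 0 + 1 + 2 + 3 + 4

In the test above, the loop body sticks to index_select and the out-of-place index_copy rather than Python indexing and in-place writes, keeping the body functional for export; the assertions then check that no aten.squeeze.dims or aten.select.int nodes survive to_edge().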