@@ -9559,13 +9559,13 @@ def forward(self, x):
95599559 [None , 1 , 3 ], # channels
95609560 [16 , 32 ], # n_fft
95619561 [5 , 9 ], # num_frames
9562- [None , 4 , 5 ], # hop_length
9562+ [None , 5 ], # hop_length
95639563 [None , 10 , 8 ], # win_length
95649564 [None , torch .hann_window ], # window
95659565 [False , True ], # center
95669566 [False , True ], # normalized
95679567 [None , False , True ], # onesided
9568- [None , 30 , 40 ], # length
9568+ [None , "shorter" , "larger" ], # length
95699569 [False , True ], # return_complex
95709570 )
95719571 )
@@ -9576,9 +9576,19 @@ def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_len
95769576 if hop_length is None and win_length is not None :
95779577 pytest .skip ("If win_length is set then we must set hop_length and 0 < hop_length <= win_length" )
95789578
9579+ # Compute input_shape to generate test case
95799580 freq = n_fft // 2 + 1 if onesided else n_fft
95809581 input_shape = (channels , freq , num_frames ) if channels else (freq , num_frames )
95819582
9583+ # If not set,c ompute hop_length for capturing errors
9584+ if hop_length is None :
9585+ hop_length = n_fft // 4
9586+
9587+ if length == "shorter" :
9588+ length = n_fft // 2 + hop_length * (num_frames - 1 )
9589+ elif length == "larger" :
9590+ length = n_fft * 3 // 2 + hop_length * (num_frames - 1 )
9591+
95829592 class ISTFTModel (torch .nn .Module ):
95839593 def forward (self , x ):
95849594 applied_window = window (win_length ) if window and win_length else None
@@ -9598,7 +9608,7 @@ def forward(self, x):
95989608 else :
95999609 return torch .real (x )
96009610
9601- if win_length and center is False :
9611+ if ( center is False and win_length ) or ( center and win_length and length ) :
96029612 # For some reason Pytorch raises an error https://github.com/pytorch/audio/issues/427#issuecomment-1829593033
96039613 with pytest .raises (RuntimeError , match = "istft\(.*\) window overlap add min: 1" ):
96049614 TorchBaseTest .run_compare_torch (
@@ -9607,7 +9617,7 @@ def forward(self, x):
96079617 backend = backend ,
96089618 compute_unit = compute_unit
96099619 )
9610- elif length is not None and return_complex is True :
9620+ elif length and return_complex :
96119621 with pytest .raises (ValueError , match = "New var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>` not a subtype of existing var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>`" ):
96129622 TorchBaseTest .run_compare_torch (
96139623 input_shape ,
0 commit comments