File tree Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -508,8 +508,9 @@ def test_model_forward_intermediates(model_name, batch_size):
508
508
spatial_axis = get_spatial_dim (output_fmt )
509
509
import math
510
510
511
+ inpt = torch .randn ((batch_size , * input_size ))
511
512
output , intermediates = model .forward_intermediates (
512
- torch . randn (( batch_size , * input_size )) ,
513
+ inpt ,
513
514
output_fmt = output_fmt ,
514
515
)
515
516
assert len (expected_channels ) == len (intermediates )
@@ -521,6 +522,9 @@ def test_model_forward_intermediates(model_name, batch_size):
521
522
assert o .shape [0 ] == batch_size
522
523
assert not torch .isnan (o ).any ()
523
524
525
+ output2 = model .forward_features (inpt )
526
+ assert torch .allclose (output , output2 )
527
+
524
528
525
529
def _create_fx_model (model , train = False ):
526
530
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
You can’t perform that action at this time.
0 commit comments