Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit fdcf7cf

Browse files
committedMay 6, 2025·
Check forward_intermediates features against forward_features output
1 parent c8c4f25 commit fdcf7cf

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed
 

‎tests/test_models.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,8 +508,9 @@ def test_model_forward_intermediates(model_name, batch_size):
508508
spatial_axis = get_spatial_dim(output_fmt)
509509
import math
510510

511+
inpt = torch.randn((batch_size, *input_size))
511512
output, intermediates = model.forward_intermediates(
512-
torch.randn((batch_size, *input_size)),
513+
inpt,
513514
output_fmt=output_fmt,
514515
)
515516
assert len(expected_channels) == len(intermediates)
@@ -521,6 +522,9 @@ def test_model_forward_intermediates(model_name, batch_size):
521522
assert o.shape[0] == batch_size
522523
assert not torch.isnan(o).any()
523524

525+
output2 = model.forward_features(inpt)
526+
assert torch.allclose(output, output2)
527+
524528

525529
def _create_fx_model(model, train=False):
526530
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode

0 commit comments

Comments
 (0)
Please sign in to comment.