1
+ import json
2
+ import os
1
3
import re
2
4
import subprocess
5
+ from pathlib import Path
3
6
4
7
import pytest
5
8
import torch
16
19
)
17
20
18
21
22
@pytest.fixture
def with_ffmpeg_debug_logs():
    # Fixture that sets the ffmpeg logs to DEBUG mode for the duration of the
    # requesting test, through the TORCHCODEC_FFMPEG_LOG_LEVEL environment
    # variable.
    #
    # On teardown the variable is put back exactly as it was found: restored to
    # its previous value if it had one, and removed again if it was not set.
    # (Restoring a made-up default such as "QUIET" would leave the variable
    # defined for every test that runs afterwards.)
    env_var = "TORCHCODEC_FFMPEG_LOG_LEVEL"
    previous_log_level = os.environ.get(env_var)  # None means "was not set"
    os.environ[env_var] = "DEBUG"
    try:
        yield
    finally:
        # try/finally so the environment is restored even if the generator is
        # closed abnormally.
        if previous_log_level is None:
            os.environ.pop(env_var, None)
        else:
            os.environ[env_var] = previous_log_level
29
+
30
+
31
def _ffprobe_audio_frames(path):
    # Run `ffprobe -show_frames` on the first audio stream ("a:0") of `path`
    # and return the parsed "frames" entry of its JSON output.
    # check=True makes a failing ffprobe raise CalledProcessError instead of
    # handing us empty or partial output.
    completed = subprocess.run(
        [
            "ffprobe",
            "-v",
            "error",
            "-hide_banner",
            "-select_streams",
            "a:0",
            "-show_frames",
            "-of",
            "json",
            str(path),
        ],
        check=True,
        capture_output=True,
        text=True,
    )
    return json.loads(completed.stdout)["frames"]


def validate_frames_properties(*, actual: Path, expected: Path):
    # actual and expected are files containing encoded audio data. We call
    # `ffprobe` on both, and assert that the frame properties match (pts,
    # duration, etc.)
    #
    # Raises AssertionError on the first mismatch, and
    # subprocess.CalledProcessError if ffprobe itself fails on either file.

    frames_actual = _ffprobe_audio_frames(actual)
    frames_expected = _ffprobe_audio_frames(expected)

    # frames_actual and frames_expected are both a list of dicts, each dict
    # corresponds to a frame and each key-value pair corresponds to a frame
    # property like pts, nb_samples, etc., similar to the AVFrame fields.
    for frames in (frames_actual, frames_expected):
        assert isinstance(frames, list)
        assert all(isinstance(d, dict) for d in frames)

    assert len(frames_actual) > 3  # arbitrary sanity check
    assert len(frames_actual) == len(frames_expected)

    # non-exhaustive list of the props we want to test for:
    required_props = (
        "pts",
        "pts_time",
        "sample_fmt",
        "nb_samples",
        "channels",
        "duration",
        "duration_time",
    )
    # The FFmpeg version cannot change while we iterate, so look it up once.
    # NOTE(review): presumably older ffprobe versions do not report all of the
    # required props, hence the >= 6 gate -- confirm.
    check_required_props = get_ffmpeg_major_version() >= 6

    for frame_index, (d_actual, d_expected) in enumerate(
        zip(frames_actual, frames_expected)
    ):
        if check_required_props:
            assert all(required_prop in d_expected for required_prop in required_props)

        for prop, expected_value in d_expected.items():
            if prop == "pkt_pos":
                # pkt_pos is the position of the packet *in bytes* in its
                # stream. We don't always match FFmpeg exactly on this,
                # typically on compressed formats like mp3. It's probably
                # because we are not writing the exact same headers, or
                # something like this. In any case, this doesn't seem to be
                # critical.
                continue
            # Fail with a readable message rather than a bare KeyError when
            # our output lacks a property that the reference has.
            assert prop in d_actual, (
                f"\nComparing: {actual}\nagainst reference: {expected},\n"
                f"the {prop} property is missing at frame {frame_index}"
            )
            assert (
                d_actual[prop] == expected_value
            ), f"\nComparing: {actual}\nagainst reference: {expected},\nthe {prop} property is different at frame {frame_index}:"
97
+
98
+
19
99
class TestAudioEncoder :
20
100
21
101
def decode(self, source):
    # Decode all the audio samples of `source` with AudioDecoder.
    #
    # A TestContainerFile is first converted to the string form of its
    # `.path`; any other source is handed to AudioDecoder unchanged.
    #
    # Returns the object produced by `get_all_samples()`, NOT a bare tensor:
    # callers that want the waveform must read `.data` on the result, and the
    # same object also exposes the `.pts_seconds`, `.duration_seconds` and
    # `.sample_rate` fields that the tests compare.
    if isinstance(source, TestContainerFile):
        source = str(source.path)
    return AudioDecoder(source).get_all_samples()
25
105
26
106
def test_bad_input (self ):
27
107
with pytest .raises (ValueError , match = "Expected samples to be a Tensor" ):
@@ -63,12 +143,12 @@ def test_bad_input_parametrized(self, method, tmp_path):
63
143
else dict (format = "mp3" )
64
144
)
65
145
66
- decoder = AudioEncoder (self .decode (NASA_AUDIO_MP3 ), sample_rate = 10 )
146
+ decoder = AudioEncoder (self .decode (NASA_AUDIO_MP3 ). data , sample_rate = 10 )
67
147
with pytest .raises (RuntimeError , match = "invalid sample rate=10" ):
68
148
getattr (decoder , method )(** valid_params )
69
149
70
150
decoder = AudioEncoder (
71
- self .decode (NASA_AUDIO_MP3 ), sample_rate = NASA_AUDIO_MP3 .sample_rate
151
+ self .decode (NASA_AUDIO_MP3 ). data , sample_rate = NASA_AUDIO_MP3 .sample_rate
72
152
)
73
153
with pytest .raises (RuntimeError , match = "bit_rate=-1 must be >= 0" ):
74
154
getattr (decoder , method )(** valid_params , bit_rate = - 1 )
@@ -81,7 +161,7 @@ def test_bad_input_parametrized(self, method, tmp_path):
81
161
getattr (decoder , method )(** valid_params )
82
162
83
163
decoder = AudioEncoder (
84
- self .decode (NASA_AUDIO_MP3 ), sample_rate = NASA_AUDIO_MP3 .sample_rate
164
+ self .decode (NASA_AUDIO_MP3 ). data , sample_rate = NASA_AUDIO_MP3 .sample_rate
85
165
)
86
166
for num_channels in (0 , 3 ):
87
167
with pytest .raises (
@@ -101,7 +181,7 @@ def test_round_trip(self, method, format, tmp_path):
101
181
pytest .skip ("Swresample with FFmpeg 4 doesn't work on wav files" )
102
182
103
183
asset = NASA_AUDIO_MP3
104
- source_samples = self .decode (asset )
184
+ source_samples = self .decode (asset ). data
105
185
106
186
encoder = AudioEncoder (source_samples , sample_rate = asset .sample_rate )
107
187
@@ -116,7 +196,7 @@ def test_round_trip(self, method, format, tmp_path):
116
196
117
197
rtol , atol = (0 , 1e-4 ) if format == "wav" else (None , None )
118
198
torch .testing .assert_close (
119
- self .decode (encoded_source ), source_samples , rtol = rtol , atol = atol
199
+ self .decode (encoded_source ). data , source_samples , rtol = rtol , atol = atol
120
200
)
121
201
122
202
@pytest .mark .skipif (in_fbcode (), reason = "TODO: enable ffmpeg CLI" )
@@ -125,7 +205,17 @@ def test_round_trip(self, method, format, tmp_path):
125
205
@pytest .mark .parametrize ("num_channels" , (None , 1 , 2 ))
126
206
@pytest .mark .parametrize ("format" , ("mp3" , "wav" , "flac" ))
127
207
@pytest .mark .parametrize ("method" , ("to_file" , "to_tensor" ))
128
- def test_against_cli (self , asset , bit_rate , num_channels , format , method , tmp_path ):
208
+ def test_against_cli (
209
+ self ,
210
+ asset ,
211
+ bit_rate ,
212
+ num_channels ,
213
+ format ,
214
+ method ,
215
+ tmp_path ,
216
+ capfd ,
217
+ with_ffmpeg_debug_logs ,
218
+ ):
129
219
# Encodes samples with our encoder and with the FFmpeg CLI, and checks
130
220
# that both decoded outputs are equal
131
221
@@ -144,14 +234,25 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa
144
234
check = True ,
145
235
)
146
236
147
- encoder = AudioEncoder (self .decode (asset ), sample_rate = asset .sample_rate )
237
+ encoder = AudioEncoder (self .decode (asset ).data , sample_rate = asset .sample_rate )
238
+
148
239
params = dict (bit_rate = bit_rate , num_channels = num_channels )
149
240
if method == "to_file" :
150
241
encoded_by_us = tmp_path / f"output.{ format } "
151
242
encoder .to_file (dest = str (encoded_by_us ), ** params )
152
243
else :
153
244
encoded_by_us = encoder .to_tensor (format = format , ** params )
154
245
246
+ captured = capfd .readouterr ()
247
+ if format == "wav" :
248
+ assert "Timestamps are unset in a packet" not in captured .err
249
+ if format == "mp3" :
250
+ assert "Queue input is backward in time" not in captured .err
251
+ if format in ("flac" , "wav" ):
252
+ assert "Encoder did not produce proper pts" not in captured .err
253
+ if format in ("flac" , "mp3" ):
254
+ assert "Application provided invalid" not in captured .err
255
+
155
256
if format == "wav" :
156
257
rtol , atol = 0 , 1e-4
157
258
elif format == "mp3" and asset is SINE_MONO_S32 and num_channels == 2 :
@@ -162,12 +263,22 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa
162
263
rtol , atol = 0 , 1e-3
163
264
else :
164
265
rtol , atol = None , None
266
+ samples_by_us = self .decode (encoded_by_us )
267
+ samples_by_ffmpeg = self .decode (encoded_by_ffmpeg )
165
268
torch .testing .assert_close (
166
- self . decode ( encoded_by_ffmpeg ) ,
167
- self . decode ( encoded_by_us ) ,
269
+ samples_by_us . data ,
270
+ samples_by_ffmpeg . data ,
168
271
rtol = rtol ,
169
272
atol = atol ,
170
273
)
274
+ assert samples_by_us .pts_seconds == samples_by_ffmpeg .pts_seconds
275
+ assert samples_by_us .duration_seconds == samples_by_ffmpeg .duration_seconds
276
+ assert samples_by_us .sample_rate == samples_by_ffmpeg .sample_rate
277
+
278
+ if method == "to_file" :
279
+ validate_frames_properties (actual = encoded_by_us , expected = encoded_by_ffmpeg )
280
+ else :
281
+ assert method == "to_tensor" , "wrong test parametrization!"
171
282
172
283
@pytest .mark .parametrize ("asset" , (NASA_AUDIO_MP3 , SINE_MONO_S32 ))
173
284
@pytest .mark .parametrize ("bit_rate" , (None , 0 , 44_100 , 999_999_999 ))
@@ -179,7 +290,7 @@ def test_to_tensor_against_to_file(
179
290
if get_ffmpeg_major_version () == 4 and format == "wav" :
180
291
pytest .skip ("Swresample with FFmpeg 4 doesn't work on wav files" )
181
292
182
- encoder = AudioEncoder (self .decode (asset ), sample_rate = asset .sample_rate )
293
+ encoder = AudioEncoder (self .decode (asset ). data , sample_rate = asset .sample_rate )
183
294
184
295
params = dict (bit_rate = bit_rate , num_channels = num_channels )
185
296
encoded_file = tmp_path / f"output.{ format } "
@@ -189,7 +300,7 @@ def test_to_tensor_against_to_file(
189
300
)
190
301
191
302
torch .testing .assert_close (
192
- self .decode (encoded_file ), self .decode (encoded_tensor )
303
+ self .decode (encoded_file ). data , self .decode (encoded_tensor ). data
193
304
)
194
305
195
306
def test_encode_to_tensor_long_output (self ):
@@ -205,7 +316,7 @@ def test_encode_to_tensor_long_output(self):
205
316
INITIAL_TENSOR_SIZE = 10_000_000
206
317
assert encoded_tensor .numel () > INITIAL_TENSOR_SIZE
207
318
208
- torch .testing .assert_close (self .decode (encoded_tensor ), samples )
319
+ torch .testing .assert_close (self .decode (encoded_tensor ). data , samples )
209
320
210
321
def test_contiguity (self ):
211
322
# Ensure that 2 waveforms with the same values are encoded in the same
@@ -262,4 +373,4 @@ def test_num_channels(
262
373
263
374
if num_channels_output is None :
264
375
num_channels_output = num_channels_input
265
- assert self .decode (encoded_source ).shape [0 ] == num_channels_output
376
+ assert self .decode (encoded_source ).data . shape [0 ] == num_channels_output
0 commit comments