Skip to content

Commit ffb65f6

Browse files
authored
Encoder: properly set frame pts values (#726)
1 parent 62ee42d commit ffb65f6

File tree

4 files changed

+128
-17
lines changed

4 files changed

+128
-17
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,7 @@ void AudioEncoder::encode() {
293293
encodeInnerLoop(autoAVPacket, convertedAVFrame);
294294

295295
numEncodedSamples += numSamplesToEncode;
296-
// TODO-ENCODING set frame pts correctly, and test against it.
297-
// avFrame->pts += static_cast<int64_t>(numSamplesToEncode);
296+
avFrame->pts += static_cast<int64_t>(numSamplesToEncode);
298297
}
299298
TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong.");
300299

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ UniqueAVFrame convertAudioAVFrameSamples(
274274
convertedAVFrame,
275275
"Could not allocate frame for sample format conversion.");
276276

277+
convertedAVFrame->pts = srcAVFrame->pts;
277278
convertedAVFrame->format = static_cast<int>(outSampleFormat);
278279

279280
convertedAVFrame->sample_rate = outSampleRate;

src/torchcodec/_frame.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ class AudioSamples(Iterable):
125125
pts_seconds: float
126126
"""The :term:`pts` of the first sample, in seconds."""
127127
duration_seconds: float
128-
"""The duration of the sampleas, in seconds."""
128+
"""The duration of the samples, in seconds."""
129129
sample_rate: int
130130
"""The sample rate of the samples, in Hz."""
131131

test/test_encoders.py

Lines changed: 125 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import json
2+
import os
13
import re
24
import subprocess
5+
from pathlib import Path
36

47
import pytest
58
import torch
@@ -16,12 +19,89 @@
1619
)
1720

1821

22+
@pytest.fixture
def with_ffmpeg_debug_logs():
    """Run the requesting test with the FFmpeg log level set to DEBUG.

    Sets the ``TORCHCODEC_FFMPEG_LOG_LEVEL`` environment variable to
    ``"DEBUG"`` before the test body runs, and restores the environment to
    the exact state it was found in once the test is done.
    """
    env_var = "TORCHCODEC_FFMPEG_LOG_LEVEL"

    # Use None (rather than a substitute value) to remember "was not set", so
    # that teardown can tell the two situations apart.
    previous_log_level = os.environ.get(env_var)
    os.environ[env_var] = "DEBUG"
    yield
    if previous_log_level is None:
        # The variable did not exist before this fixture ran: remove it
        # instead of leaving a value behind for every test that runs later.
        # NOTE(review): this assumes an unset variable behaves like the
        # previously hard-coded "QUIET" fallback -- confirm against the code
        # that reads this variable.
        os.environ.pop(env_var, None)
    else:
        os.environ[env_var] = previous_log_level
29+
30+
31+
def _probe_audio_frames(path):
    """Return the list of per-frame property dicts that ffprobe reports for ``path``."""
    completed = subprocess.run(
        [
            "ffprobe",
            "-v",
            "error",
            "-hide_banner",
            "-select_streams",
            "a:0",
            "-show_frames",
            "-of",
            "json",
            f"{path}",
        ],
        check=True,
        capture_output=True,
        text=True,
    )
    return json.loads(completed.stdout)["frames"]


def validate_frames_properties(*, actual: Path, expected: Path):
    """Assert that two encoded audio files expose the same frame properties.

    ``ffprobe`` is called on both files and the resulting frames are compared
    one by one: every property reported for a frame of ``expected`` (pts,
    duration, etc.) must be present with the same value in the corresponding
    frame of ``actual``. The only exception is ``pkt_pos``, which is skipped.

    Args:
        actual: File containing the encoded audio data under test.
        expected: File containing the reference encoded audio data.

    Raises:
        AssertionError: If the frame counts or any compared property differ.
        subprocess.CalledProcessError: If ``ffprobe`` exits with an error.
    """
    frames_actual = _probe_audio_frames(actual)
    frames_expected = _probe_audio_frames(expected)

    # Both are expected to be a list of dicts, each dict corresponds to a
    # frame and each key-value pair corresponds to a frame property like pts,
    # nb_samples, etc., similar to the AVFrame fields.
    for frames in (frames_actual, frames_expected):
        assert isinstance(frames, list)
        assert all(isinstance(d, dict) for d in frames)

    assert len(frames_actual) > 3  # arbitrary sanity check
    assert len(frames_actual) == len(frames_expected)

    # non-exhaustive list of the props we want to test for:
    required_props = (
        "pts",
        "pts_time",
        "sample_fmt",
        "nb_samples",
        "channels",
        "duration",
        "duration_time",
    )

    # The version does not change between frames, so query it only once.
    # NOTE(review): presumably ffprobe from FFmpeg < 6 does not report all of
    # the required props, hence the version gate -- confirm.
    check_required_props = get_ffmpeg_major_version() >= 6

    for frame_index, (d_actual, d_expected) in enumerate(
        zip(frames_actual, frames_expected)
    ):
        if check_required_props:
            assert all(required_prop in d_expected for required_prop in required_props)

        for prop in d_expected:
            if prop == "pkt_pos":
                # pkt_pos is the position of the packet *in bytes* in its
                # stream. We don't always match FFmpeg exactly on this,
                # typically on compressed formats like mp3. It's probably
                # because we are not writing the exact same headers, or
                # something similar. In any case, this doesn't seem to be
                # critical.
                continue
            # Fail with context rather than with a bare KeyError when the
            # property is missing from our output altogether.
            assert (
                prop in d_actual
            ), f"\nComparing: {actual}\nagainst reference: {expected},\nthe {prop} property is missing at frame {frame_index}"
            assert (
                d_actual[prop] == d_expected[prop]
            ), f"\nComparing: {actual}\nagainst reference: {expected},\nthe {prop} property is different at frame {frame_index}:"
97+
98+
1999
class TestAudioEncoder:
20100

21101
def decode(self, source) -> torch.Tensor:
22102
if isinstance(source, TestContainerFile):
23103
source = str(source.path)
24-
return AudioDecoder(source).get_all_samples().data
104+
return AudioDecoder(source).get_all_samples()
25105

26106
def test_bad_input(self):
27107
with pytest.raises(ValueError, match="Expected samples to be a Tensor"):
@@ -63,12 +143,12 @@ def test_bad_input_parametrized(self, method, tmp_path):
63143
else dict(format="mp3")
64144
)
65145

66-
decoder = AudioEncoder(self.decode(NASA_AUDIO_MP3), sample_rate=10)
146+
decoder = AudioEncoder(self.decode(NASA_AUDIO_MP3).data, sample_rate=10)
67147
with pytest.raises(RuntimeError, match="invalid sample rate=10"):
68148
getattr(decoder, method)(**valid_params)
69149

70150
decoder = AudioEncoder(
71-
self.decode(NASA_AUDIO_MP3), sample_rate=NASA_AUDIO_MP3.sample_rate
151+
self.decode(NASA_AUDIO_MP3).data, sample_rate=NASA_AUDIO_MP3.sample_rate
72152
)
73153
with pytest.raises(RuntimeError, match="bit_rate=-1 must be >= 0"):
74154
getattr(decoder, method)(**valid_params, bit_rate=-1)
@@ -81,7 +161,7 @@ def test_bad_input_parametrized(self, method, tmp_path):
81161
getattr(decoder, method)(**valid_params)
82162

83163
decoder = AudioEncoder(
84-
self.decode(NASA_AUDIO_MP3), sample_rate=NASA_AUDIO_MP3.sample_rate
164+
self.decode(NASA_AUDIO_MP3).data, sample_rate=NASA_AUDIO_MP3.sample_rate
85165
)
86166
for num_channels in (0, 3):
87167
with pytest.raises(
@@ -101,7 +181,7 @@ def test_round_trip(self, method, format, tmp_path):
101181
pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
102182

103183
asset = NASA_AUDIO_MP3
104-
source_samples = self.decode(asset)
184+
source_samples = self.decode(asset).data
105185

106186
encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate)
107187

@@ -116,7 +196,7 @@ def test_round_trip(self, method, format, tmp_path):
116196

117197
rtol, atol = (0, 1e-4) if format == "wav" else (None, None)
118198
torch.testing.assert_close(
119-
self.decode(encoded_source), source_samples, rtol=rtol, atol=atol
199+
self.decode(encoded_source).data, source_samples, rtol=rtol, atol=atol
120200
)
121201

122202
@pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI")
@@ -125,7 +205,17 @@ def test_round_trip(self, method, format, tmp_path):
125205
@pytest.mark.parametrize("num_channels", (None, 1, 2))
126206
@pytest.mark.parametrize("format", ("mp3", "wav", "flac"))
127207
@pytest.mark.parametrize("method", ("to_file", "to_tensor"))
128-
def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_path):
208+
def test_against_cli(
209+
self,
210+
asset,
211+
bit_rate,
212+
num_channels,
213+
format,
214+
method,
215+
tmp_path,
216+
capfd,
217+
with_ffmpeg_debug_logs,
218+
):
129219
# Encodes samples with our encoder and with the FFmpeg CLI, and checks
130220
# that both decoded outputs are equal
131221

@@ -144,14 +234,25 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa
144234
check=True,
145235
)
146236

147-
encoder = AudioEncoder(self.decode(asset), sample_rate=asset.sample_rate)
237+
encoder = AudioEncoder(self.decode(asset).data, sample_rate=asset.sample_rate)
238+
148239
params = dict(bit_rate=bit_rate, num_channels=num_channels)
149240
if method == "to_file":
150241
encoded_by_us = tmp_path / f"output.{format}"
151242
encoder.to_file(dest=str(encoded_by_us), **params)
152243
else:
153244
encoded_by_us = encoder.to_tensor(format=format, **params)
154245

246+
captured = capfd.readouterr()
247+
if format == "wav":
248+
assert "Timestamps are unset in a packet" not in captured.err
249+
if format == "mp3":
250+
assert "Queue input is backward in time" not in captured.err
251+
if format in ("flac", "wav"):
252+
assert "Encoder did not produce proper pts" not in captured.err
253+
if format in ("flac", "mp3"):
254+
assert "Application provided invalid" not in captured.err
255+
155256
if format == "wav":
156257
rtol, atol = 0, 1e-4
157258
elif format == "mp3" and asset is SINE_MONO_S32 and num_channels == 2:
@@ -162,12 +263,22 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa
162263
rtol, atol = 0, 1e-3
163264
else:
164265
rtol, atol = None, None
266+
samples_by_us = self.decode(encoded_by_us)
267+
samples_by_ffmpeg = self.decode(encoded_by_ffmpeg)
165268
torch.testing.assert_close(
166-
self.decode(encoded_by_ffmpeg),
167-
self.decode(encoded_by_us),
269+
samples_by_us.data,
270+
samples_by_ffmpeg.data,
168271
rtol=rtol,
169272
atol=atol,
170273
)
274+
assert samples_by_us.pts_seconds == samples_by_ffmpeg.pts_seconds
275+
assert samples_by_us.duration_seconds == samples_by_ffmpeg.duration_seconds
276+
assert samples_by_us.sample_rate == samples_by_ffmpeg.sample_rate
277+
278+
if method == "to_file":
279+
validate_frames_properties(actual=encoded_by_us, expected=encoded_by_ffmpeg)
280+
else:
281+
assert method == "to_tensor", "wrong test parametrization!"
171282

172283
@pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32))
173284
@pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999))
@@ -179,7 +290,7 @@ def test_to_tensor_against_to_file(
179290
if get_ffmpeg_major_version() == 4 and format == "wav":
180291
pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
181292

182-
encoder = AudioEncoder(self.decode(asset), sample_rate=asset.sample_rate)
293+
encoder = AudioEncoder(self.decode(asset).data, sample_rate=asset.sample_rate)
183294

184295
params = dict(bit_rate=bit_rate, num_channels=num_channels)
185296
encoded_file = tmp_path / f"output.{format}"
@@ -189,7 +300,7 @@ def test_to_tensor_against_to_file(
189300
)
190301

191302
torch.testing.assert_close(
192-
self.decode(encoded_file), self.decode(encoded_tensor)
303+
self.decode(encoded_file).data, self.decode(encoded_tensor).data
193304
)
194305

195306
def test_encode_to_tensor_long_output(self):
@@ -205,7 +316,7 @@ def test_encode_to_tensor_long_output(self):
205316
INITIAL_TENSOR_SIZE = 10_000_000
206317
assert encoded_tensor.numel() > INITIAL_TENSOR_SIZE
207318

208-
torch.testing.assert_close(self.decode(encoded_tensor), samples)
319+
torch.testing.assert_close(self.decode(encoded_tensor).data, samples)
209320

210321
def test_contiguity(self):
211322
# Ensure that 2 waveforms with the same values are encoded in the same
@@ -262,4 +373,4 @@ def test_num_channels(
262373

263374
if num_channels_output is None:
264375
num_channels_output = num_channels_input
265-
assert self.decode(encoded_source).shape[0] == num_channels_output
376+
assert self.decode(encoded_source).data.shape[0] == num_channels_output

0 commit comments

Comments
 (0)