diff --git a/README.md b/README.md
index 9a22ddfe..2ec32877 100644
--- a/README.md
+++ b/README.md
@@ -3,17 +3,18 @@
# TorchCodec
TorchCodec is a Python library for decoding video and audio data into PyTorch
-tensors, on CPU and CUDA GPU. It aims to be fast, easy to use, and well
-integrated into the PyTorch ecosystem. If you want to use PyTorch to train ML
-models on videos and audio, TorchCodec is how you turn these into data.
+tensors, on CPU and CUDA GPU. It also supports audio encoding, and video
+encoding will come soon! It aims to be fast, easy to use, and well integrated
+into the PyTorch ecosystem. If you want to use PyTorch to train ML models on
+videos and audio, TorchCodec is how you turn these into data.
We achieve these capabilities through:
* Pythonic APIs that mirror Python and PyTorch conventions.
-* Relying on [FFmpeg](https://www.ffmpeg.org/) to do the decoding. TorchCodec
- uses the version of FFmpeg you already have installed. FFmpeg is a mature
- library with broad coverage available on most systems. It is, however, not
- easy to use. TorchCodec abstracts FFmpeg's complexity to ensure it is used
+* Relying on [FFmpeg](https://www.ffmpeg.org/) to do the decoding / encoding.
+ TorchCodec uses the version of FFmpeg you already have installed. FFmpeg is a
+ mature library with broad coverage available on most systems. It is, however,
+ not easy to use. TorchCodec abstracts FFmpeg's complexity to ensure it is used
correctly and efficiently.
* Returning data as PyTorch tensors, ready to be fed into PyTorch transforms
or used directly to train models.
diff --git a/docs/source/api_ref_decoders.rst b/docs/source/api_ref_decoders.rst
index bb55cfae..0ae159c3 100644
--- a/docs/source/api_ref_decoders.rst
+++ b/docs/source/api_ref_decoders.rst
@@ -7,8 +7,8 @@ torchcodec.decoders
.. currentmodule:: torchcodec.decoders
-For a video decoder tutorial, see: :ref:`sphx_glr_generated_examples_basic_example.py`.
-For an audio decoder tutorial, see: :ref:`sphx_glr_generated_examples_audio_decoding.py`.
+For a video decoder tutorial, see: :ref:`sphx_glr_generated_examples_decoding_basic_example.py`.
+For an audio decoder tutorial, see: :ref:`sphx_glr_generated_examples_decoding_audio_decoding.py`.
.. autosummary::
diff --git a/docs/source/api_ref_encoders.rst b/docs/source/api_ref_encoders.rst
new file mode 100644
index 00000000..52c7295b
--- /dev/null
+++ b/docs/source/api_ref_encoders.rst
@@ -0,0 +1,18 @@
+.. _encoders:
+
+===================
+torchcodec.encoders
+===================
+
+.. currentmodule:: torchcodec.encoders
+
+
+For an audio encoder tutorial, see: :ref:`sphx_glr_generated_examples_encoding_audio_encoding.py`.
+
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: class.rst
+
+ AudioEncoder
diff --git a/docs/source/api_ref_samplers.rst b/docs/source/api_ref_samplers.rst
index c72c4d82..0cec98e6 100644
--- a/docs/source/api_ref_samplers.rst
+++ b/docs/source/api_ref_samplers.rst
@@ -6,7 +6,7 @@ torchcodec.samplers
.. currentmodule:: torchcodec.samplers
-For a tutorial, see: :ref:`sphx_glr_generated_examples_sampling.py`.
+For a tutorial, see: :ref:`sphx_glr_generated_examples_decoding_sampling.py`.
.. autosummary::
:toctree: generated/
diff --git a/docs/source/conf.py b/docs/source/conf.py
index b9d3eb58..b14dc4f4 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -68,18 +68,24 @@ class CustomGalleryExampleSortKey:
def __init__(self, src_dir):
self.src_dir = src_dir
- order = [
- "basic_example.py",
- "audio_decoding.py",
- "basic_cuda_example.py",
- "file_like.py",
- "approximate_mode.py",
- "sampling.py",
- ]
-
def __call__(self, filename):
+ if "examples/decoding" in self.src_dir:
+ order = [
+ "basic_example.py",
+ "audio_decoding.py",
+ "basic_cuda_example.py",
+ "file_like.py",
+ "approximate_mode.py",
+ "sampling.py",
+ ]
+ else:
+ assert "examples/encoding" in self.src_dir
+ order = [
+ "audio_encoding.py",
+ ]
+
try:
- return self.order.index(filename)
+ return order.index(filename)
except ValueError as e:
raise ValueError(
"Looks like you added an example in the examples/ folder?"
diff --git a/docs/source/index.rst b/docs/source/index.rst
index d6d31990..55f3edf2 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -2,21 +2,25 @@ Welcome to the TorchCodec documentation!
========================================
TorchCodec is a Python library for decoding video and audio data into PyTorch
-tensors, on CPU and CUDA GPU. It aims to be fast, easy to use, and well
-integrated into the PyTorch ecosystem. If you want to use PyTorch to train ML
-models on videos and audio, TorchCodec is how you turn these into data.
+tensors, on CPU and CUDA GPU. It also supports audio encoding, and video encoding will come soon!
+It aims to be fast, easy to use, and well integrated into the PyTorch ecosystem.
+If you want to use PyTorch to train ML models on videos and audio, TorchCodec is
+how you turn these into data.
We achieve these capabilities through:
* Pythonic APIs that mirror Python and PyTorch conventions.
-* Relying on `FFmpeg <https://www.ffmpeg.org/>`_ to do the decoding. TorchCodec
- uses the version of FFmpeg you already have installed. FMPEG is a mature
- library with broad coverage available on most systems. It is, however, not
- easy to use. TorchCodec abstracts FFmpeg's complexity to ensure it is used
- correctly and efficiently.
+* Relying on `FFmpeg <https://www.ffmpeg.org/>`_ to do the decoding / encoding.
+ TorchCodec uses the version of FFmpeg you already have installed. FFmpeg is a
+ mature library with broad coverage available on most systems. It is, however,
+ not easy to use. TorchCodec abstracts FFmpeg's complexity to ensure it is
+ used correctly and efficiently.
* Returning data as PyTorch tensors, ready to be fed into PyTorch transforms
or used directly to train models.
+Installation instructions
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
.. grid:: 3
.. grid-item-card:: :octicon:`file-code;1em`
@@ -27,10 +31,15 @@ We achieve these capabilities through:
How to install TorchCodec
+Decoding
+^^^^^^^^
+
+.. grid:: 3
+
.. grid-item-card:: :octicon:`file-code;1em`
Getting Started with TorchCodec
:img-top: _static/img/card-background.svg
- :link: generated_examples/basic_example.html
+ :link: generated_examples/decoding/basic_example.html
:link-type: url
A simple video decoding example
@@ -38,7 +47,7 @@ We achieve these capabilities through:
.. grid-item-card:: :octicon:`file-code;1em`
Audio Decoding
:img-top: _static/img/card-background.svg
- :link: generated_examples/audio_decoding.html
+ :link: generated_examples/decoding/audio_decoding.html
:link-type: url
A simple audio decoding example
@@ -46,7 +55,7 @@ We achieve these capabilities through:
.. grid-item-card:: :octicon:`file-code;1em`
GPU decoding
:img-top: _static/img/card-background.svg
- :link: generated_examples/basic_cuda_example.html
+ :link: generated_examples/decoding/basic_cuda_example.html
:link-type: url
A simple example demonstrating CUDA GPU decoding
@@ -54,7 +63,7 @@ We achieve these capabilities through:
.. grid-item-card:: :octicon:`file-code;1em`
Streaming video
:img-top: _static/img/card-background.svg
- :link: generated_examples/file_like.html
+ :link: generated_examples/decoding/file_like.html
:link-type: url
How to efficiently decode videos from the cloud
@@ -62,11 +71,24 @@ We achieve these capabilities through:
.. grid-item-card:: :octicon:`file-code;1em`
Clip sampling
:img-top: _static/img/card-background.svg
- :link: generated_examples/sampling.html
+ :link: generated_examples/decoding/sampling.html
:link-type: url
How to sample regular and random clips from a video
+Encoding
+^^^^^^^^
+
+.. grid:: 3
+
+ .. grid-item-card:: :octicon:`file-code;1em`
+ Audio Encoding
+ :img-top: _static/img/card-background.svg
+ :link: generated_examples/encoding/audio_encoding.html
+ :link-type: url
+
+ How to encode audio samples
+
.. toctree::
:maxdepth: 1
:caption: TorchCodec documentation
@@ -92,4 +114,5 @@ We achieve these capabilities through:
api_ref_torchcodec
api_ref_decoders
+ api_ref_encoders
api_ref_samplers
diff --git a/examples/decoding/README.rst b/examples/decoding/README.rst
new file mode 100644
index 00000000..16381679
--- /dev/null
+++ b/examples/decoding/README.rst
@@ -0,0 +1,2 @@
+Decoding
+--------
diff --git a/examples/approximate_mode.py b/examples/decoding/approximate_mode.py
similarity index 100%
rename from examples/approximate_mode.py
rename to examples/decoding/approximate_mode.py
diff --git a/examples/audio_decoding.py b/examples/decoding/audio_decoding.py
similarity index 100%
rename from examples/audio_decoding.py
rename to examples/decoding/audio_decoding.py
diff --git a/examples/basic_cuda_example.py b/examples/decoding/basic_cuda_example.py
similarity index 100%
rename from examples/basic_cuda_example.py
rename to examples/decoding/basic_cuda_example.py
diff --git a/examples/basic_example.py b/examples/decoding/basic_example.py
similarity index 100%
rename from examples/basic_example.py
rename to examples/decoding/basic_example.py
diff --git a/examples/file_like.py b/examples/decoding/file_like.py
similarity index 99%
rename from examples/file_like.py
rename to examples/decoding/file_like.py
index a327f4c8..f0d03288 100644
--- a/examples/file_like.py
+++ b/examples/decoding/file_like.py
@@ -96,7 +96,7 @@ def bench(f, average_over=10, warmup=2):
# the :class:`~torchcodec.decoders.VideoDecoder` class to ``"approximate"``. We do
# this to avoid scanning the entire video during initialization, which would
# require downloading the entire video even if we only want to decode the first
-# frame. See :ref:`sphx_glr_generated_examples_approximate_mode.py` for more.
+# frame. See :ref:`sphx_glr_generated_examples_decoding_approximate_mode.py` for more.
def decode_from_existing_download():
diff --git a/examples/sampling.py b/examples/decoding/sampling.py
similarity index 99%
rename from examples/sampling.py
rename to examples/decoding/sampling.py
index 7cd55b8d..8fcd261e 100644
--- a/examples/sampling.py
+++ b/examples/decoding/sampling.py
@@ -61,7 +61,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
# Sampling clips from a video always starts by creating a
# :class:`~torchcodec.decoders.VideoDecoder` object. If you're not already
# familiar with :class:`~torchcodec.decoders.VideoDecoder`, take a quick look
-# at: :ref:`sphx_glr_generated_examples_basic_example.py`.
+# at: :ref:`sphx_glr_generated_examples_decoding_basic_example.py`.
from torchcodec.decoders import VideoDecoder
# You can also pass a path to a local file!
diff --git a/examples/encoding/README.rst b/examples/encoding/README.rst
new file mode 100644
index 00000000..1f6fadf0
--- /dev/null
+++ b/examples/encoding/README.rst
@@ -0,0 +1,2 @@
+Encoding
+--------
diff --git a/examples/encoding/audio_encoding.py b/examples/encoding/audio_encoding.py
new file mode 100644
index 00000000..1ff88bba
--- /dev/null
+++ b/examples/encoding/audio_encoding.py
@@ -0,0 +1,91 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+========================================
+Encoding audio samples with AudioEncoder
+========================================
+
+In this example, we'll learn how to encode audio samples to a file or to raw
+bytes using the :class:`~torchcodec.encoders.AudioEncoder` class.
+"""
+
+# %%
+# Let's first generate some samples to be encoded. The data to be encoded could
+# also just come from an :class:`~torchcodec.decoders.AudioDecoder`!
+import torch
+from IPython.display import Audio as play_audio
+
+
+def make_sinewave() -> tuple[torch.Tensor, int]:
+ freq_A = 440 # Hz
+ sample_rate = 16000 # Hz
+ duration_seconds = 3 # seconds
+ t = torch.arange(int(sample_rate * duration_seconds), dtype=torch.float32) / sample_rate
+ return torch.sin(2 * torch.pi * freq_A * t), sample_rate
+
+
+samples, sample_rate = make_sinewave()
+
+print(f"Encoding samples with {samples.shape = } and {sample_rate = }")
+play_audio(samples, rate=sample_rate)
+
+# %%
+# We first instantiate an :class:`~torchcodec.encoders.AudioEncoder`. We pass it
+# the samples to be encoded. The samples must be a 2D tensor of shape
+# ``(num_channels, num_samples)``, or in this case, a 1D tensor where
+# ``num_channels`` is assumed to be 1. The values must be float values
+# normalized in ``[-1, 1]``: this is also what the
+# :class:`~torchcodec.decoders.AudioDecoder` would return.
+#
+# .. note::
+#
+# The ``sample_rate`` parameter corresponds to the sample rate of the
+# *input*, not the desired encoded sample rate.
+from torchcodec.encoders import AudioEncoder
+
+encoder = AudioEncoder(samples=samples, sample_rate=sample_rate)
+
+
+# %%
+# :class:`~torchcodec.encoders.AudioEncoder` supports encoding samples into a
+# file via the :meth:`~torchcodec.encoders.AudioEncoder.to_file` method, or to
+# raw bytes via :meth:`~torchcodec.encoders.AudioEncoder.to_tensor`. For the
+# purpose of this tutorial we'll use
+# :meth:`~torchcodec.encoders.AudioEncoder.to_tensor`, so that we can easily
+# re-decode the encoded samples and check their properties. The
+# :meth:`~torchcodec.encoders.AudioEncoder.to_file` method works very similarly.
+
+encoded_samples = encoder.to_tensor(format="mp3")
+print(f"{encoded_samples.shape = }, {encoded_samples.dtype = }")
+
+
+# %%
+# That's it!
+#
+# Now that we have our encoded data, we can decode it back, to make sure it
+# looks and sounds as expected:
+from torchcodec.decoders import AudioDecoder
+
+samples_back = AudioDecoder(encoded_samples).get_all_samples()
+
+print(samples_back)
+play_audio(samples_back.data, rate=samples_back.sample_rate)
+
+# %%
+# The encoder supports some encoding options that allow you to change how the
+# data is encoded. For example, we can decide to encode our mono data (1
+# channel) into stereo data (2 channels):
+encoded_samples = encoder.to_tensor(format="wav", num_channels=2)
+
+stereo_samples_back = AudioDecoder(encoded_samples).get_all_samples()
+
+print(stereo_samples_back)
+play_audio(stereo_samples_back.data, rate=stereo_samples_back.sample_rate)
+
+# %%
+# Check the docstring of the encoding methods to learn about the different
+# encoding options.
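Review note on the round trip in examples/encoding/audio_encoding.py: the same "encode to in-memory bytes, decode back, compare" check can be sketched with the Python standard library alone. This is only a stand-in that assumes 16-bit PCM in a WAV container; it does not use TorchCodec or FFmpeg, and none of the names below come from the TorchCodec API.

```python
import io
import math
import struct
import wave

sample_rate = 16000  # Hz
freq = 440  # Hz
num_samples = sample_rate * 3  # 3 seconds of audio

# Float samples normalized in [-1, 1], as the tutorial requires.
samples = [math.sin(2 * math.pi * freq * n / sample_rate) for n in range(num_samples)]

# "Encode": quantize to signed 16-bit PCM and write a WAV container to memory.
pcm = struct.pack(f"<{num_samples}h", *(round(s * 32767) for s in samples))
buffer = io.BytesIO()
with wave.open(buffer, "wb") as f:
    f.setnchannels(1)  # mono
    f.setsampwidth(2)  # 2 bytes per sample
    f.setframerate(sample_rate)
    f.writeframes(pcm)
encoded = buffer.getvalue()  # raw encoded bytes

# "Decode": read the bytes back and rescale to floats in [-1, 1].
with wave.open(io.BytesIO(encoded), "rb") as f:
    info = (f.getnchannels(), f.getframerate(), f.getnframes())
    decoded_pcm = f.readframes(f.getnframes())
decoded = [v / 32767 for v in struct.unpack(f"<{num_samples}h", decoded_pcm)]

# The only loss here is the 16-bit quantization error.
max_error = max(abs(a - b) for a, b in zip(samples, decoded))
print(info, max_error < 1 / 32767)
```

A lossy format such as mp3 would not satisfy such a tight error bound, which is why the tutorial inspects and plays back the decoded samples rather than comparing them exactly.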
diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h
index b12e67b9..3fa9caee 100644
--- a/src/torchcodec/_core/Encoder.h
+++ b/src/torchcodec/_core/Encoder.h
@@ -9,11 +9,6 @@ class AudioEncoder {
public:
~AudioEncoder();
- // TODO-ENCODING: document in public docs that bit_rate value is only
- // best-effort, matching to the closest supported bit_rate. I.e. passing 1 is
- // like passing 0, which results in choosing the minimum supported bit rate.
- // Passing 44_100 could result in output being 44000 if only 44000 is
- // supported.
AudioEncoder(
const torch::Tensor& samples,
// TODO-ENCODING: update this comment when we support an output sample
diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py
index 525c7ac8..cc886761 100644
--- a/src/torchcodec/_frame.py
+++ b/src/torchcodec/_frame.py
@@ -60,7 +60,7 @@ class FrameBatch(Iterable):
The ``data`` tensor is typically 4D for sequences of frames (NHWC or NCHW),
or 5D for sequences of clips, as returned by the :ref:`samplers
- <sphx_glr_generated_examples_sampling.py>`. When ``data`` is 4D (resp. 5D)
+ <sphx_glr_generated_examples_decoding_sampling.py>`. When ``data`` is 4D (resp. 5D)
the ``pts_seconds`` and ``duration_seconds`` tensors are 1D (resp. 2D).
.. note::
diff --git a/src/torchcodec/decoders/_audio_decoder.py b/src/torchcodec/decoders/_audio_decoder.py
index 54d7e458..472a387a 100644
--- a/src/torchcodec/decoders/_audio_decoder.py
+++ b/src/torchcodec/decoders/_audio_decoder.py
@@ -26,7 +26,8 @@ class AudioDecoder:
Returned samples are float samples normalized in [-1, 1]
Args:
- source (str, ``Pathlib.path``, bytes, ``torch.Tensor`` or file-like object): The source of the video:
+ source (str, ``Pathlib.path``, bytes, ``torch.Tensor`` or file-like
+ object): The source of the video or audio:
- If ``str``: a local path or a URL to a video or audio file.
- If ``Pathlib.path``: a path to a local video or audio file.
@@ -34,7 +35,7 @@ class AudioDecoder:
- If file-like object: we read video data from the object on demand. The object must
expose the methods `read(self, size: int) -> bytes` and
`seek(self, offset: int, whence: int) -> bytes`. Read more in:
- :ref:`sphx_glr_generated_examples_file_like.py`.
+ :ref:`sphx_glr_generated_examples_decoding_file_like.py`.
stream_index (int, optional): Specifies which stream in the file to decode samples from.
Note that this index is absolute across all media types. If left unspecified, then
the :term:`best stream` is used.
diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py
index fff0fe9f..9b5b9f91 100644
--- a/src/torchcodec/decoders/_video_decoder.py
+++ b/src/torchcodec/decoders/_video_decoder.py
@@ -30,7 +30,7 @@ class VideoDecoder:
- If file-like object: we read video data from the object on demand. The object must
expose the methods `read(self, size: int) -> bytes` and
`seek(self, offset: int, whence: int) -> bytes`. Read more in:
- :ref:`sphx_glr_generated_examples_file_like.py`.
+ :ref:`sphx_glr_generated_examples_decoding_file_like.py`.
stream_index (int, optional): Specifies which stream in the video to decode frames from.
Note that this index is absolute across all media types. If left unspecified, then
the :term:`best stream` is used.
@@ -59,7 +59,7 @@ class VideoDecoder:
accurate as it uses the file's metadata to calculate where i
probably is. Default: "exact".
Read more about this parameter in:
- :ref:`sphx_glr_generated_examples_approximate_mode.py`
+ :ref:`sphx_glr_generated_examples_decoding_approximate_mode.py`
Attributes:
diff --git a/src/torchcodec/encoders/_audio_encoder.py b/src/torchcodec/encoders/_audio_encoder.py
index 3ad03912..b83d0ad9 100644
--- a/src/torchcodec/encoders/_audio_encoder.py
+++ b/src/torchcodec/encoders/_audio_encoder.py
@@ -8,6 +8,16 @@
class AudioEncoder:
+ """An audio encoder.
+
+ Args:
+ samples (``torch.Tensor``): The samples to encode. This must be a 2D
+ tensor of shape ``(num_channels, num_samples)``, or a 1D tensor in
+ which case ``num_channels = 1`` is assumed. Values must be float
+ values in ``[-1, 1]``.
+ sample_rate (int): The sample rate of the **input** ``samples``.
+ """
+
def __init__(self, samples: Tensor, *, sample_rate: int):
# Some of these checks are also done in C++: it's OK, they're cheap, and
# doing them here allows to surface them when the AudioEncoder is
@@ -16,8 +26,11 @@ def __init__(self, samples: Tensor, *, sample_rate: int):
raise ValueError(
f"Expected samples to be a Tensor, got {type(samples) = }."
)
+ if samples.ndim == 1:
+ # make it 2D and assume 1 channel
+ samples = samples[None, :]
if samples.ndim != 2:
- raise ValueError(f"Expected 2D samples, got {samples.shape = }.")
+ raise ValueError(f"Expected 1D or 2D samples, got {samples.shape = }.")
if samples.dtype != torch.float32:
raise ValueError(f"Expected float32 samples, got {samples.dtype = }.")
if sample_rate <= 0:
@@ -33,6 +46,20 @@ def to_file(
bit_rate: Optional[int] = None,
num_channels: Optional[int] = None,
) -> None:
+ """Encode samples into a file.
+
+ Args:
+ dest (str or ``pathlib.Path``): The path to the output file, e.g.
+ ``audio.mp3``. The extension of the file determines the audio
+ format and container.
+ bit_rate (int, optional): The output bit rate. Encoders typically
+ support a finite set of bit rate values, so ``bit_rate`` will be
+ matched to one of those supported values. The default is chosen
+ by FFmpeg.
+ num_channels (int, optional): The number of channels of the encoded
+ output samples. By default, the number of channels of the input
+ ``samples`` is used.
+ """
_core.encode_audio_to_file(
samples=self._samples,
sample_rate=self._sample_rate,
@@ -48,6 +75,22 @@ def to_tensor(
bit_rate: Optional[int] = None,
num_channels: Optional[int] = None,
) -> Tensor:
+ """Encode samples into raw bytes, as a 1D uint8 Tensor.
+
+ Args:
+ format (str): The format of the encoded samples, e.g. "mp3", "wav"
+ or "flac".
+ bit_rate (int, optional): The output bit rate. Encoders typically
+ support a finite set of bit rate values, so ``bit_rate`` will be
+ matched to one of those supported values. The default is chosen
+ by FFmpeg.
+ num_channels (int, optional): The number of channels of the encoded
+ output samples. By default, the number of channels of the input
+ ``samples`` is used.
+
+ Returns:
+ Tensor: The raw encoded bytes as a 1D uint8 Tensor.
+ """
return _core.encode_audio_to_tensor(
samples=self._samples,
sample_rate=self._sample_rate,
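Review note on the new 1D handling in `AudioEncoder.__init__`: the rule can be illustrated on plain shape tuples, without PyTorch. `normalize_samples_shape` is a hypothetical helper written for this note, not part of TorchCodec; it only mirrors the order of the checks in the diff (promote 1D first, then require 2D).

```python
def normalize_samples_shape(shape):
    """Return the (num_channels, num_samples) shape the encoder would work with."""
    shape = tuple(shape)
    if len(shape) == 1:
        # 1D input: assume a single channel, like samples[None, :] in the diff.
        shape = (1, *shape)
    if len(shape) != 2:
        raise ValueError(f"Expected 1D or 2D samples, got {shape = }.")
    return shape


print(normalize_samples_shape((1000,)))    # 1D mono input is promoted to 2D
print(normalize_samples_shape((2, 1000)))  # 2D input is passed through unchanged

try:
    normalize_samples_shape((3, 4, 5))
except ValueError as e:
    print(e)
```

Because the promotion happens before the dimensionality check, a 1D tensor never reaches the error branch, which is what the updated `test_bad_input` (now using a 3D tensor) and the new `test_1d_samples` rely on.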
diff --git a/test/test_encoders.py b/test/test_encoders.py
index bf5f9cc6..b54a6e82 100644
--- a/test/test_encoders.py
+++ b/test/test_encoders.py
@@ -26,8 +26,8 @@ def decode(self, source) -> torch.Tensor:
def test_bad_input(self):
with pytest.raises(ValueError, match="Expected samples to be a Tensor"):
AudioEncoder(samples=123, sample_rate=32_000)
- with pytest.raises(ValueError, match="Expected 2D samples"):
- AudioEncoder(samples=torch.rand(10), sample_rate=32_000)
+ with pytest.raises(ValueError, match="Expected 1D or 2D samples"):
+ AudioEncoder(samples=torch.rand(3, 4, 5), sample_rate=32_000)
with pytest.raises(ValueError, match="Expected float32 samples"):
AudioEncoder(
samples=torch.rand(10, 10, dtype=torch.float64), sample_rate=32_000
@@ -263,3 +263,13 @@ def test_num_channels(
if num_channels_output is None:
num_channels_output = num_channels_input
assert self.decode(encoded_source).shape[0] == num_channels_output
+
+ def test_1d_samples(self):
+ # smoke test making sure 1D samples are supported
+ samples_1d, sample_rate = torch.rand(1000), 16_000
+ samples_2d = samples_1d[None, :]
+
+ torch.testing.assert_close(
+ AudioEncoder(samples_1d, sample_rate=sample_rate).to_tensor("wav"),
+ AudioEncoder(samples_2d, sample_rate=sample_rate).to_tensor("wav"),
+ )