Skip to content

Commit 709686e

Browse files
committed
Turn BytesContext into FromTensorContext
1 parent fe61d91 commit 709686e

File tree

4 files changed

+62
-96
lines changed

4 files changed

+62
-96
lines changed

src/torchcodec/_core/AVIOBytesContext.cpp

Lines changed: 51 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -9,117 +9,87 @@
99

1010
namespace facebook::torchcodec {
1111

12-
AVIOBytesContext::AVIOBytesContext(const void* data, int64_t dataSize)
13-
: dataContext_{static_cast<const uint8_t*>(data), dataSize, 0} {
14-
TORCH_CHECK(data != nullptr, "Video data buffer cannot be nullptr!");
15-
TORCH_CHECK(dataSize > 0, "Video data size must be positive");
16-
createAVIOContext(&read, nullptr, &seek, &dataContext_);
17-
}
12+
namespace {
13+
14+
constexpr int64_t INITIAL_TENSOR_SIZE = 10'000'000; // 10 MB
15+
constexpr int64_t MAX_TENSOR_SIZE = 320'000'000; // 320 MB
16+
//
1817

1918
// The signature of this function is defined by FFMPEG.
20-
int AVIOBytesContext::read(void* opaque, uint8_t* buf, int buf_size) {
21-
auto dataContext = static_cast<DataContext*>(opaque);
19+
int read(void* opaque, uint8_t* buf, int buf_size) {
20+
auto tensorContext = static_cast<TensorContext*>(opaque);
2221
TORCH_CHECK(
23-
dataContext->current <= dataContext->size,
22+
tensorContext->current <= tensorContext->data.numel(),
2423
"Tried to read outside of the buffer: current=",
25-
dataContext->current,
24+
tensorContext->current,
2625
", size=",
27-
dataContext->size);
26+
tensorContext->data.numel());
2827

2928
int64_t numBytesRead = std::min(
30-
static_cast<int64_t>(buf_size), dataContext->size - dataContext->current);
29+
static_cast<int64_t>(buf_size),
30+
tensorContext->data.numel() - tensorContext->current);
3131

3232
TORCH_CHECK(
3333
numBytesRead >= 0,
3434
"Tried to read negative bytes: numBytesRead=",
3535
numBytesRead,
3636
", size=",
37-
dataContext->size,
37+
tensorContext->data.numel(),
3838
", current=",
39-
dataContext->current);
39+
tensorContext->current);
4040

4141
if (numBytesRead == 0) {
4242
return AVERROR_EOF;
4343
}
4444

45-
std::memcpy(buf, dataContext->data + dataContext->current, numBytesRead);
46-
dataContext->current += numBytesRead;
45+
std::memcpy(
46+
buf,
47+
tensorContext->data.data_ptr<uint8_t>() + tensorContext->current,
48+
numBytesRead);
49+
tensorContext->current += numBytesRead;
4750
return numBytesRead;
4851
}
4952

5053
// The signature of this function is defined by FFMPEG.
51-
int64_t AVIOBytesContext::seek(void* opaque, int64_t offset, int whence) {
52-
auto dataContext = static_cast<DataContext*>(opaque);
53-
int64_t ret = -1;
54-
55-
switch (whence) {
56-
case AVSEEK_SIZE:
57-
ret = dataContext->size;
58-
break;
59-
case SEEK_SET:
60-
dataContext->current = offset;
61-
ret = offset;
62-
break;
63-
default:
64-
break;
65-
}
66-
67-
return ret;
68-
}
69-
70-
AVIOToTensorContext::AVIOToTensorContext()
71-
: dataContext_{
72-
torch::empty(
73-
{AVIOToTensorContext::INITIAL_TENSOR_SIZE},
74-
{torch::kUInt8}),
75-
0} {
76-
createAVIOContext(nullptr, &write, &seek, &dataContext_);
77-
}
78-
79-
// The signature of this function is defined by FFMPEG.
80-
int AVIOToTensorContext::write(void* opaque, const uint8_t* buf, int buf_size) {
81-
auto dataContext = static_cast<DataContext*>(opaque);
54+
int write(void* opaque, const uint8_t* buf, int buf_size) {
55+
auto tensorContext = static_cast<TensorContext*>(opaque);
8256

8357
int64_t bufSize = static_cast<int64_t>(buf_size);
84-
if (dataContext->current + bufSize > dataContext->outputTensor.numel()) {
58+
if (tensorContext->current + bufSize > tensorContext->data.numel()) {
8559
TORCH_CHECK(
86-
dataContext->outputTensor.numel() * 2 <=
87-
AVIOToTensorContext::MAX_TENSOR_SIZE,
60+
tensorContext->data.numel() * 2 <= MAX_TENSOR_SIZE,
8861
"We tried to allocate an output encoded tensor larger than ",
89-
AVIOToTensorContext::MAX_TENSOR_SIZE,
62+
MAX_TENSOR_SIZE,
9063
" bytes. If you think this should be supported, please report.");
9164

9265
// We double the size of the output tensor. Calling cat() may not be the
9366
// most efficient, but it's simple.
94-
dataContext->outputTensor =
95-
torch::cat({dataContext->outputTensor, dataContext->outputTensor});
67+
tensorContext->data =
68+
torch::cat({tensorContext->data, tensorContext->data});
9669
}
9770

9871
TORCH_CHECK(
99-
dataContext->current + bufSize <= dataContext->outputTensor.numel(),
72+
tensorContext->current + bufSize <= tensorContext->data.numel(),
10073
"Re-allocation of the output tensor didn't work. ",
10174
"This should not happen, please report on TorchCodec bug tracker");
10275

103-
uint8_t* outputTensorData = dataContext->outputTensor.data_ptr<uint8_t>();
104-
std::memcpy(outputTensorData + dataContext->current, buf, bufSize);
105-
dataContext->current += bufSize;
76+
uint8_t* outputTensorData = tensorContext->data.data_ptr<uint8_t>();
77+
std::memcpy(outputTensorData + tensorContext->current, buf, bufSize);
78+
tensorContext->current += bufSize;
10679
return buf_size;
10780
}
10881

10982
// The signature of this function is defined by FFMPEG.
110-
// Note: This `seek()` implementation is very similar to that of
111-
// AVIOBytesContext. We could consider merging both classes, or do some kind of
112-
// refac, but this doesn't seem worth it ATM.
113-
int64_t AVIOToTensorContext::seek(void* opaque, int64_t offset, int whence) {
114-
auto dataContext = static_cast<DataContext*>(opaque);
83+
int64_t seek(void* opaque, int64_t offset, int whence) {
84+
auto tensorContext = static_cast<TensorContext*>(opaque);
11585
int64_t ret = -1;
11686

11787
switch (whence) {
11888
case AVSEEK_SIZE:
119-
ret = dataContext->outputTensor.numel();
89+
ret = tensorContext->data.numel();
12090
break;
12191
case SEEK_SET:
122-
dataContext->current = offset;
92+
tensorContext->current = offset;
12393
ret = offset;
12494
break;
12595
default:
@@ -129,9 +99,24 @@ int64_t AVIOToTensorContext::seek(void* opaque, int64_t offset, int whence) {
12999
return ret;
130100
}
131101

102+
} // namespace
103+
104+
AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data)
105+
: tensorContext_{data, 0} {
106+
TORCH_CHECK(data.numel() > 0, "data must not be empty");
107+
TORCH_CHECK(data.is_contiguous(), "data must be contiguous");
108+
TORCH_CHECK(data.scalar_type() == torch::kUInt8, "data must be kUInt8");
109+
createAVIOContext(&read, nullptr, &seek, &tensorContext_);
110+
}
111+
112+
AVIOToTensorContext::AVIOToTensorContext()
113+
: tensorContext_{torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}), 0} {
114+
createAVIOContext(nullptr, &write, &seek, &tensorContext_);
115+
}
116+
132117
torch::Tensor AVIOToTensorContext::getOutputTensor() {
133-
return dataContext_.outputTensor.narrow(
134-
/*dim=*/0, /*start=*/0, /*length=*/dataContext_.current);
118+
return tensorContext_.data.narrow(
119+
/*dim=*/0, /*start=*/0, /*length=*/tensorContext_.current);
135120
}
136121

137122
} // namespace facebook::torchcodec

src/torchcodec/_core/AVIOBytesContext.h

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,19 @@
1111

1212
namespace facebook::torchcodec {
1313

14+
struct TensorContext {
15+
torch::Tensor data;
16+
int64_t current;
17+
};
18+
1419
// For Decoding: enables users to pass in the entire video or audio as bytes.
1520
// Our read and seek functions then traverse the bytes in memory.
16-
class AVIOBytesContext : public AVIOContextHolder {
21+
class AVIOFromTensorContext : public AVIOContextHolder {
1722
public:
18-
explicit AVIOBytesContext(const void* data, int64_t dataSize);
23+
explicit AVIOFromTensorContext(torch::Tensor data);
1924

2025
private:
21-
struct DataContext {
22-
const uint8_t* data;
23-
int64_t size;
24-
int64_t current;
25-
};
26-
27-
static int read(void* opaque, uint8_t* buf, int buf_size);
28-
static int64_t seek(void* opaque, int64_t offset, int whence);
29-
30-
DataContext dataContext_;
26+
TensorContext tensorContext_;
3127
};
3228

3329
// For Encoding: used to encode into an output uint8 (bytes) tensor.
@@ -37,18 +33,7 @@ class AVIOToTensorContext : public AVIOContextHolder {
3733
torch::Tensor getOutputTensor();
3834

3935
private:
40-
struct DataContext {
41-
torch::Tensor outputTensor;
42-
int64_t current;
43-
};
44-
45-
static constexpr int64_t INITIAL_TENSOR_SIZE = 10'000'000; // 10MB
46-
static constexpr int64_t MAX_TENSOR_SIZE = 320'000'000; // 320 MB
47-
static int write(void* opaque, const uint8_t* buf, int buf_size);
48-
// We need to expose seek() for some formats like mp3.
49-
static int64_t seek(void* opaque, int64_t offset, int whence);
50-
51-
DataContext dataContext_;
36+
TensorContext tensorContext_;
5237
};
5338

5439
} // namespace facebook::torchcodec

src/torchcodec/_core/custom_ops.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,15 +196,13 @@ at::Tensor create_from_tensor(
196196
TORCH_CHECK(
197197
video_tensor.scalar_type() == torch::kUInt8,
198198
"video_tensor must be kUInt8");
199-
void* data = video_tensor.mutable_data_ptr();
200-
size_t length = video_tensor.numel();
201199

202200
SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact;
203201
if (seek_mode.has_value()) {
204202
realSeek = seekModeFromString(seek_mode.value());
205203
}
206204

207-
auto contextHolder = std::make_unique<AVIOBytesContext>(data, length);
205+
auto contextHolder = std::make_unique<AVIOFromTensorContext>(video_tensor);
208206

209207
std::unique_ptr<SingleStreamDecoder> uniqueDecoder =
210208
std::make_unique<SingleStreamDecoder>(std::move(contextHolder), realSeek);

test/test_decoders.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ def seek(self, offset: int, whence: int) -> bytes:
9393
decoder = Decoder(source)
9494
assert isinstance(decoder.metadata, _core._metadata.StreamMetadata)
9595

96-
9796
@pytest.mark.parametrize("Decoder", (VideoDecoder, AudioDecoder))
9897
def test_create_fails(self, Decoder):
9998
with pytest.raises(TypeError, match="Unknown source type"):
@@ -139,10 +138,9 @@ def test_create_bytes_ownership(self):
139138
decoder = VideoDecoder(f.read())
140139

141140
assert decoder[0] is not None
142-
assert decoder[len(decoder)//2] is not None
141+
assert decoder[len(decoder) // 2] is not None
143142
assert decoder[-1] is not None
144143

145-
146144
def test_create_fails(self):
147145
with pytest.raises(ValueError, match="Invalid seek mode"):
148146
VideoDecoder(NASA_VIDEO.path, seek_mode="blah")

0 commit comments

Comments
 (0)