Refactor prototype datasets #5778

Merged
38 commits from prototype-datasets-inheritance were merged into main on Apr 7, 2022.

Commits
1916bd7
Merge branch 'main' into prototype-datasets-inheritance
pmeier Feb 24, 2022
0aae427
refactor prototype datasets to inherit from IterDataPipe (#5448)
pmeier Feb 24, 2022
d5c3a43
Merge branch 'main' into prototype-datasets-inheritance
pmeier Mar 8, 2022
e7e921e
Merge branch 'main'
pmeier Mar 29, 2022
9acb8f9
Merge branch 'main'
pmeier Mar 31, 2022
c5f6c11
fix imagenet
pmeier Mar 31, 2022
49655b2
Merge branch 'main'
pmeier Apr 4, 2022
9f12ef4
fix prototype datasets data loading tests (#5711)
pmeier Apr 5, 2022
b514955
Merge branch 'main' into prototype-datasets-inheritance
pmeier Apr 5, 2022
aca4164
migrate VOC prototype dataset (#5743)
pmeier Apr 5, 2022
dead87d
migrate CIFAR prototype datasets (#5751)
pmeier Apr 6, 2022
6a0592f
migrate country211 prototype dataset (#5753)
pmeier Apr 6, 2022
2ed549d
migrate CLEVR prototype dataset (#5752)
pmeier Apr 6, 2022
42bc682
migrate coco prototype (#5473)
pmeier Apr 6, 2022
27104fe
Migrate PCAM prototype dataset (#5745)
NicolasHug Apr 6, 2022
291be31
Migrate DTD prototype dataset (#5757)
NicolasHug Apr 6, 2022
217616b
Migrate GTSRB prototype dataset (#5746)
NicolasHug Apr 6, 2022
2612c4c
migrate CelebA prototype dataset (#5750)
pmeier Apr 6, 2022
6de6ec4
Migrate Food101 prototype dataset (#5758)
NicolasHug Apr 6, 2022
ebe9006
Migrate Fer2013 prototype dataset (#5759)
NicolasHug Apr 6, 2022
8194b17
Migrate EuroSAT prototype dataset (#5760)
NicolasHug Apr 6, 2022
4c9cbab
Migrate Semeion prototype dataset (#5761)
NicolasHug Apr 6, 2022
5cd5722
migrate caltech prototype datasets (#5749)
pmeier Apr 6, 2022
70cd406
Migrate Oxford Pets prototype dataset (#5764)
NicolasHug Apr 6, 2022
ccfcaa5
migrate mnist prototype datasets (#5480)
pmeier Apr 6, 2022
9ea341a
Migrate Stanford Cars prototype dataset (#5767)
NicolasHug Apr 6, 2022
3b10147
fix category file generation (#5770)
pmeier Apr 6, 2022
1691e72
migrate cub200 prototype dataset (#5765)
pmeier Apr 6, 2022
2a212b8
Migrate USPS prototype dataset (#5771)
NicolasHug Apr 6, 2022
0b66ed6
migrate SBD prototype dataset (#5772)
pmeier Apr 6, 2022
b3c8384
Migrate SVHN prototype dataset (#5769)
NicolasHug Apr 6, 2022
fb56882
Merge branch 'main' into prototype-datasets-inheritance
pmeier Apr 6, 2022
1199144
add test to enforce __len__ is working on prototype datasets (#5742)
pmeier Apr 6, 2022
8e7987a
reactivate special dataset tests
pmeier Apr 6, 2022
5062a32
add missing annotation
pmeier Apr 6, 2022
3be12c7
Cleanup prototype dataset implementation (#5774)
NicolasHug Apr 7, 2022
4c73a5e
Merge branch 'main' into prototype-datasets-inheritance
pmeier Apr 7, 2022
cd36d06
update prototype dataset README (#5777)
pmeier Apr 7, 2022
329 changes: 184 additions & 145 deletions test/builtin_dataset_mocks.py

Large diffs are not rendered by default.

70 changes: 46 additions & 24 deletions test/test_prototype_builtin_datasets.py
@@ -7,9 +7,10 @@
import torch
from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
from torch.utils.data import DataLoader
from torch.utils.data.graph import traverse
from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.datapipes.iter import IterDataPipe, Shuffler, ShardingFilter
from torchdata.datapipes.iter import Shuffler, ShardingFilter
from torchvision._utils import sequence_to_str
from torchvision.prototype import transforms, datasets
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
@@ -42,14 +43,24 @@ def test_coverage():

@pytest.mark.filterwarnings("error")
class TestCommon:
@pytest.mark.parametrize("name", datasets.list_datasets())
def test_info(self, name):
try:
info = datasets.info(name)
except ValueError:
raise AssertionError("No info available.") from None

if not (isinstance(info, dict) and all(isinstance(key, str) for key in info.keys())):
raise AssertionError("Info should be a dictionary with string keys.")

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_smoke(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)

dataset = datasets.load(dataset_mock.name, **config)

if not isinstance(dataset, IterDataPipe):
raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")
if not isinstance(dataset, datasets.utils.Dataset):
raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.")

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_sample(self, test_home, dataset_mock, config):
@@ -76,24 +87,7 @@ def test_num_samples(self, test_home, dataset_mock, config):

dataset = datasets.load(dataset_mock.name, **config)

num_samples = 0
for _ in dataset:
num_samples += 1

assert num_samples == mock_info["num_samples"]

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_decoding(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)

dataset = datasets.load(dataset_mock.name, **config)

undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
if undecoded_features:
raise AssertionError(
f"The values of key(s) "
f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded."
)
assert len(list(dataset)) == mock_info["num_samples"]

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
@@ -116,14 +110,36 @@ def test_transformable(self, test_home, dataset_mock, config):

next(iter(dataset.map(transforms.Identity())))

@pytest.mark.parametrize("only_datapipe", [False, True])
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_serializable(self, test_home, dataset_mock, config):
def test_traversable(self, test_home, dataset_mock, config, only_datapipe):
dataset_mock.prepare(test_home, config)
dataset = datasets.load(dataset_mock.name, **config)

traverse(dataset, only_datapipe=only_datapipe)

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_serializable(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
dataset = datasets.load(dataset_mock.name, **config)

pickle.dumps(dataset)

@pytest.mark.parametrize("num_workers", [0, 1])
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_data_loader(self, test_home, dataset_mock, config, num_workers):
dataset_mock.prepare(test_home, config)
dataset = datasets.load(dataset_mock.name, **config)

dl = DataLoader(
dataset,
batch_size=2,
num_workers=num_workers,
collate_fn=lambda batch: batch,
)

next(iter(dl))

# TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
# that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
# contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
@@ -132,7 +148,6 @@ def test_serializable(self, test_home, dataset_mock, config):
def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):

dataset_mock.prepare(test_home, config)

dataset = datasets.load(dataset_mock.name, **config)

if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)):
@@ -160,6 +175,13 @@ def test_infinite_buffer_size(self, test_home, dataset_mock, config):
# resolved
assert dp.buffer_size == INFINITE_BUFFER_SIZE

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_has_length(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
dataset = datasets.load(dataset_mock.name, **config)

assert len(dataset) > 0


@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
class TestQMNIST:
@@ -186,7 +208,7 @@ class TestGTSRB:
def test_label_matches_path(self, test_home, dataset_mock, config):
# We read the labels from the csv files instead. But for the trainset, the labels are also part of the path.
# This test makes sure that they're both the same
if config.split != "train":
if config["split"] != "train":
return

dataset_mock.prepare(test_home, config)
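The TODO in test_has_annotations above notes that the relative order of the Shuffler and the ShardingFilter is not yet enforced. A minimal sketch of such an ordering check, assuming traverse() returns the nested {datapipe: predecessors} mapping it does at the time of this PR; the helper names below are hypothetical and not part of this PR:

from torch.utils.data.graph import traverse
from torchdata.datapipes.iter import Shuffler, ShardingFilter

def _find_predecessors(graph, dp_type):
    # Recursively search the nested {datapipe: predecessors} mapping and
    # return the predecessor subgraph of the first datapipe of the given type.
    for dp, predecessors in graph.items():
        if isinstance(dp, dp_type):
            return predecessors
        match = _find_predecessors(predecessors, dp_type)
        if match is not None:
            return match
    return None

def assert_shuffling_before_sharding(dataset):
    graph = traverse(dataset)
    sharding_predecessors = _find_predecessors(graph, ShardingFilter)
    assert sharding_predecessors is not None, "no ShardingFilter in the graph"
    # The Shuffler has to appear upstream of the ShardingFilter, i.e. among
    # the ShardingFilter's predecessors.
    assert _find_predecessors(sharding_predecessors, Shuffler) is not None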
231 changes: 0 additions & 231 deletions test/test_prototype_datasets_api.py

This file was deleted.

20 changes: 19 additions & 1 deletion test/test_prototype_datasets_utils.py
@@ -5,7 +5,7 @@
import torch
from datasets_utils import make_fake_flo_file
from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
from torchvision.prototype.datasets.utils import HttpResource, GDriveResource
from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset
from torchvision.prototype.datasets.utils._internal import read_flo, fromfile


@@ -101,3 +101,21 @@ def preprocess_sentinel(path):
assert redirected_resource.file_name == file_name
assert redirected_resource.sha256 == sha256_sentinel
assert redirected_resource._preprocess is preprocess_sentinel


def test_missing_dependency_error():
class DummyDataset(Dataset):
def __init__(self):
super().__init__(root="root", dependencies=("fake_dependency",))

def _resources(self):
pass

def _datapipe(self, resource_dps):
pass

def __len__(self):
pass

with pytest.raises(ModuleNotFoundError, match="depends on the third-party package 'fake_dependency'"):
DummyDataset()
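
For context, a minimal usage sketch of the API surface these tests exercise; the dataset name and config below are illustrative, any registered dataset works the same way:

from torch.utils.data import DataLoader
from torchvision.prototype import datasets

# Registered datasets can be discovered and inspected without loading anything.
print(datasets.list_datasets())
info = datasets.info("gtsrb")  # a dict with string keys

# load() now returns a datasets.utils.Dataset, which is an IterDataPipe with a
# working __len__, so it plugs directly into a vanilla DataLoader.
dataset = datasets.load("gtsrb", split="train")
assert len(dataset) > 0

dl = DataLoader(dataset, batch_size=2, collate_fn=lambda batch: batch)
batch = next(iter(dl))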