Skip to content

Commit 3be12c7

Browse files
NicolasHugpmeier
andauthored
Cleanup prototype dataset implementation (#5774)
* Remove Dataset2 class * Move read_categories_file out of DatasetInfo * Remove FrozenBunch and FrozenMapping * Remove test_prototype_datasets_api.py and move missing dep test somewhere else * ufmt * Let read_categories_file accept names instead of paths * Mypy * flake8 * fix category file reading Co-authored-by: Philip Meier <[email protected]>
1 parent 5062a32 commit 3be12c7

31 files changed

+121
-607
lines changed

test/builtin_dataset_mocks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def prepare(self, home, config):
6868

6969
mock_info = self._parse_mock_info(self.mock_data_fn(root, config))
7070

71-
with unittest.mock.patch.object(datasets.utils.Dataset2, "__init__"):
71+
with unittest.mock.patch.object(datasets.utils.Dataset, "__init__"):
7272
required_file_names = {
7373
resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources()
7474
}

test/test_prototype_builtin_datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_smoke(self, test_home, dataset_mock, config):
5959

6060
dataset = datasets.load(dataset_mock.name, **config)
6161

62-
if not isinstance(dataset, datasets.utils.Dataset2):
62+
if not isinstance(dataset, datasets.utils.Dataset):
6363
raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.")
6464

6565
@parametrize_dataset_mocks(DATASET_MOCKS)

test/test_prototype_datasets_api.py

Lines changed: 0 additions & 231 deletions
This file was deleted.

test/test_prototype_datasets_utils.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
from datasets_utils import make_fake_flo_file
77
from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
8-
from torchvision.prototype.datasets.utils import HttpResource, GDriveResource
8+
from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset
99
from torchvision.prototype.datasets.utils._internal import read_flo, fromfile
1010

1111

@@ -101,3 +101,21 @@ def preprocess_sentinel(path):
101101
assert redirected_resource.file_name == file_name
102102
assert redirected_resource.sha256 == sha256_sentinel
103103
assert redirected_resource._preprocess is preprocess_sentinel
104+
105+
106+
def test_missing_dependency_error():
107+
class DummyDataset(Dataset):
108+
def __init__(self):
109+
super().__init__(root="root", dependencies=("fake_dependency",))
110+
111+
def _resources(self):
112+
pass
113+
114+
def _datapipe(self, resource_dps):
115+
pass
116+
117+
def __len__(self):
118+
pass
119+
120+
with pytest.raises(ModuleNotFoundError, match="depends on the third-party package 'fake_dependency'"):
121+
DummyDataset()

torchvision/prototype/datasets/_api.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
from typing import Any, Dict, List, Callable, Type, Optional, Union, TypeVar
33

44
from torchvision.prototype.datasets import home
5-
from torchvision.prototype.datasets.utils import Dataset2
5+
from torchvision.prototype.datasets.utils import Dataset
66
from torchvision.prototype.utils._internal import add_suggestion
77

88

99
T = TypeVar("T")
10-
D = TypeVar("D", bound=Type[Dataset2])
10+
D = TypeVar("D", bound=Type[Dataset])
1111

1212
BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {}
1313

@@ -56,7 +56,7 @@ def info(name: str) -> Dict[str, Any]:
5656
return find(BUILTIN_INFOS, name)
5757

5858

59-
def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset2:
59+
def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset:
6060
dataset_cls = find(BUILTIN_DATASETS, name)
6161

6262
if root is None:

torchvision/prototype/datasets/_builtin/caltech.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,26 @@
99
Filter,
1010
IterKeyZipper,
1111
)
12-
from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource
12+
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
1313
from torchvision.prototype.datasets.utils._internal import (
1414
INFINITE_BUFFER_SIZE,
1515
read_mat,
1616
hint_sharding,
1717
hint_shuffling,
18-
BUILTIN_DIR,
18+
read_categories_file,
1919
)
2020
from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage
2121

2222
from .._api import register_dataset, register_info
2323

2424

25-
CALTECH101_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "caltech101.categories"))
26-
27-
2825
@register_info("caltech101")
2926
def _caltech101_info() -> Dict[str, Any]:
30-
return dict(categories=CALTECH101_CATEGORIES)
27+
return dict(categories=read_categories_file("caltech101"))
3128

3229

3330
@register_dataset("caltech101")
34-
class Caltech101(Dataset2):
31+
class Caltech101(Dataset):
3532
"""
3633
- **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech101
3734
- **dependencies**:
@@ -152,16 +149,13 @@ def _generate_categories(self) -> List[str]:
152149
return sorted({pathlib.Path(path).parent.name for path, _ in dp})
153150

154151

155-
CALTECH256_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "caltech256.categories"))
156-
157-
158152
@register_info("caltech256")
159153
def _caltech256_info() -> Dict[str, Any]:
160-
return dict(categories=CALTECH256_CATEGORIES)
154+
return dict(categories=read_categories_file("caltech256"))
161155

162156

163157
@register_dataset("caltech256")
164-
class Caltech256(Dataset2):
158+
class Caltech256(Dataset):
165159
"""
166160
- **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech256
167161
"""

torchvision/prototype/datasets/_builtin/celeba.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
IterKeyZipper,
1111
)
1212
from torchvision.prototype.datasets.utils import (
13-
Dataset2,
13+
Dataset,
1414
GDriveResource,
1515
OnlineResource,
1616
)
@@ -68,7 +68,7 @@ def _info() -> Dict[str, Any]:
6868

6969

7070
@register_dataset(NAME)
71-
class CelebA(Dataset2):
71+
class CelebA(Dataset):
7272
"""
7373
- **homepage**: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
7474
"""

0 commit comments

Comments
 (0)