diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py
index 0946fd36be1..aa1652353f4 100644
--- a/test/builtin_dataset_mocks.py
+++ b/test/builtin_dataset_mocks.py
@@ -9,7 +9,6 @@
 import pathlib
 import pickle
 import random
-import unittest.mock
 import xml.etree.ElementTree as ET
 from collections import defaultdict, Counter
 
@@ -21,7 +20,6 @@
 from torch.nn.functional import one_hot
 from torch.testing import make_tensor as _make_tensor
 from torchvision.prototype import datasets
-from torchvision.prototype.utils._internal import sequence_to_str
 
 make_tensor = functools.partial(_make_tensor, device="cpu")
 make_scalar = functools.partial(make_tensor, ())
@@ -66,17 +64,17 @@ def prepare(self, home, config):
 
         mock_info = self._parse_mock_info(self.mock_data_fn(root, config))
 
-        with unittest.mock.patch.object(datasets.utils.Dataset2, "__init__"):
-            required_file_names = {
-                resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources()
-            }
-            available_file_names = {path.name for path in root.glob("*")}
-            missing_file_names = required_file_names - available_file_names
-            if missing_file_names:
-                raise pytest.UsageError(
-                    f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} "
-                    f"for {config}, but they were not created by the mock data function."
-                )
+        # with unittest.mock.patch.object(datasets.utils.Dataset2, "__init__"):
+        #     required_file_names = {
+        #         resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources()
+        #     }
+        #     available_file_names = {path.name for path in root.glob("*")}
+        #     missing_file_names = required_file_names - available_file_names
+        #     if missing_file_names:
+        #         raise pytest.UsageError(
+        #             f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} "
+        #             f"for {config}, but they were not created by the mock data function."
+        #         )
 
         return mock_info
 
diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py
index c7dff541dbe..daa446a31c3 100644
--- a/test/test_prototype_builtin_datasets.py
+++ b/test/test_prototype_builtin_datasets.py
@@ -10,6 +10,7 @@
 from torch.utils.data.graph import traverse
 from torchdata.datapipes.iter import Shuffler, ShardingFilter
 from torchvision.prototype import transforms, datasets
+from torchvision.prototype.datasets.utils._internal import TakerDataPipe
 from torchvision.prototype.utils._internal import sequence_to_str
 
 
@@ -20,7 +21,7 @@
 
 @pytest.fixture
 def test_home(mocker, tmp_path):
-    mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
+    mocker.patch("torchvision.prototype.datasets.utils._internal.home", return_value=str(tmp_path))
     yield tmp_path
 
 
@@ -51,8 +52,10 @@ def test_smoke(self, test_home, dataset_mock, config):
 
         dataset = datasets.load(dataset_mock.name, **config)
 
-        if not isinstance(dataset, datasets.utils.Dataset2):
-            raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.")
+        if not isinstance(dataset, TakerDataPipe):
+            raise AssertionError(
+                f"Loading the dataset should return a TakerDataPipe, but got {type(dataset)} instead."
+            )
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_sample(self, test_home, dataset_mock, config):
@@ -100,7 +103,6 @@ def test_transformable(self, test_home, dataset_mock, config):
 
         next(iter(dataset.map(transforms.Identity())))
 
-    @pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237")
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_serializable(self, test_home, dataset_mock, config):
         dataset_mock.prepare(test_home, config)
@@ -109,7 +111,6 @@ def test_serializable(self, test_home, dataset_mock, config):
 
         pickle.dumps(dataset)
 
-    @pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237")
     @parametrize_dataset_mocks(DATASET_MOCKS)
     @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
     def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
diff --git a/torchvision/prototype/datasets/__init__.py b/torchvision/prototype/datasets/__init__.py
index 44c66e422f2..5549f30e416 100644
--- a/torchvision/prototype/datasets/__init__.py
+++ b/torchvision/prototype/datasets/__init__.py
@@ -7,8 +7,8 @@
         "Note that you cannot install it with `pip install torchdata`, since this is another package."
     ) from error
 
+from ._home import home  # usort: skip
 from . import utils
-from ._home import home
 
 # Load this last, since some parts depend on the above being loaded first
 from ._api import list_datasets, info, load, register_info, register_dataset  # usort: skip
diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py
index 8f8bb53deb4..3e636524202 100644
--- a/torchvision/prototype/datasets/_api.py
+++ b/torchvision/prototype/datasets/_api.py
@@ -1,13 +1,11 @@
 import pathlib
-from typing import Any, Dict, List, Callable, Type, Optional, Union, TypeVar
+from typing import Any, Dict, List, Callable, Optional, Union, TypeVar
 
-from torchvision.prototype.datasets import home
-from torchvision.prototype.datasets.utils import Dataset2
+from torchvision.prototype.datasets.utils._internal import TakerDataPipe
 from torchvision.prototype.utils._internal import add_suggestion
 
 T = TypeVar("T")
-D = TypeVar("D", bound=Type[Dataset2])
 
 
 BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {}
 
@@ -23,10 +21,12 @@ def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]:
 BUILTIN_DATASETS = {}
 
 
-def register_dataset(name: str) -> Callable[[D], D]:
-    def wrapper(dataset_cls: D) -> D:
-        BUILTIN_DATASETS[name] = dataset_cls
-        return dataset_cls
+def register_dataset(
+    name: Optional[str] = None,
+) -> Callable[[Callable[..., TakerDataPipe]], Callable[..., TakerDataPipe]]:
+    def wrapper(dataset_fn: Callable[..., TakerDataPipe]) -> Callable[..., TakerDataPipe]:
+        BUILTIN_DATASETS[name or dataset_fn.__name__] = dataset_fn
+        return dataset_fn
 
     return wrapper
 
@@ -56,10 +56,6 @@ def info(name: str) -> Dict[str, Any]:
     return find(BUILTIN_INFOS, name)
 
 
-def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset2:
+def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> TakerDataPipe:
     dataset_cls = find(BUILTIN_DATASETS, name)
-
-    if root is None:
-        root = pathlib.Path(home()) / name
-
     return dataset_cls(root, **config)
diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py
index 1567ef29811..a3136464059 100644
--- a/torchvision/prototype/datasets/_builtin/__init__.py
+++ b/torchvision/prototype/datasets/_builtin/__init__.py
@@ -9,7 +9,7 @@
 from .eurosat import EuroSAT
 from .fer2013 import FER2013
 from .gtsrb import GTSRB
-from .imagenet import ImageNet
+from .imagenet import imagenet
 from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
 from .oxford_iiit_pet import OxfordIITPet
 from .pcam import PCAM
diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py
index 6f91d4c4a8d..4adbd3c9b44 100644
--- a/torchvision/prototype/datasets/_builtin/imagenet.py
+++ b/torchvision/prototype/datasets/_builtin/imagenet.py
@@ -1,14 +1,20 @@
+import functools
 import pathlib
 import re
 from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast, Union
 
-from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, Filter, Demultiplexer
-from torchdata.datapipes.iter import TarArchiveReader
+from torchdata.datapipes.iter import (
+    IterDataPipe,
+    LineReader,
+    IterKeyZipper,
+    Mapper,
+    Filter,
+    Demultiplexer,
+    TarArchiveReader,
+)
 from torchvision.prototype.datasets.utils import (
     DatasetInfo,
-    OnlineResource,
     ManualDownloadResource,
-    Dataset2,
 )
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
@@ -20,19 +26,23 @@
     hint_sharding,
     hint_shuffling,
     path_accessor,
+    TakerDataPipe,
+    verify_str_arg,
+    get_root,
 )
 from torchvision.prototype.features import Label, EncodedImage
 
 from .._api import register_dataset, register_info
 
-
 NAME = "imagenet"
 
+CATEGORIES, WNIDS = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories"))
+WNID_TO_CATEGORY = dict(zip(WNIDS, CATEGORIES))
+
 
 @register_info(NAME)
-def _info() -> Dict[str, Any]:
-    categories, wnids = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories"))
-    return dict(categories=categories, wnids=wnids)
+def info() -> Dict[str, Any]:
+    return dict(categories=CATEGORIES, wnids=WNIDS)
 
 
 class ImageNetResource(ManualDownloadResource):
@@ -40,156 +50,155 @@ def __init__(self, **kwargs: Any) -> None:
         super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs)
 
 
+def load_images_dp(root: pathlib.Path, *, split: str, **kwargs: Any) -> IterDataPipe[Tuple[str, BinaryIO]]:
+    name = "test_v10102019" if split == "test" else split
+    return ImageNetResource(
+        file_name=f"ILSVRC2012_img_{name}.tar",
+        sha256={
+            "train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
+            "val": "c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0",
+            "test_v10102019": "9cf7f8249639510f17d3d8a0deb47cd22a435886ba8e29e2b3223e65a4079eb4",
+        }[name],
+    ).load(root, **kwargs)
+
+
+def load_devkit_dp(root: pathlib.Path, **kwargs: Any) -> IterDataPipe[Tuple[str, BinaryIO]]:
+    return ImageNetResource(
+        file_name="ILSVRC2012_devkit_t12.tar.gz",
+        sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953",
+    ).load(root, **kwargs)
+
+
+TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")
+
+
+def prepare_train_data(data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
+    path = pathlib.Path(data[0])
+    wnid = cast(Match[str], TRAIN_IMAGE_NAME_PATTERN.match(path.name))["wnid"]
+    label = Label.from_category(WNID_TO_CATEGORY[wnid], categories=CATEGORIES)
+    return (label, wnid), data
+
+
+def prepare_test_data(data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]:
+    return None, data
+
+
+def classifiy_devkit(data: Tuple[str, BinaryIO]) -> Optional[int]:
+    return {
+        "meta.mat": 0,
+        "ILSVRC2012_validation_ground_truth.txt": 1,
+    }.get(pathlib.Path(data[0]).name)
+
+
+# Although the WordNet IDs (wnids) are unique, the corresponding human-readable categories are not. For example, both
+# 'n02012849' and 'n03126707' are labeled 'crane' while the first means the bird and the latter means the construction
+# equipment.
+WNID_MAP = {
+    "n03126707": "construction crane",
+    "n03710721": "tank suit",
+}
+
+
+def extract_categories_and_wnids(data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]:
+    synsets = read_mat(data[1], squeeze_me=True)["synsets"]
+    return [
+        (WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
+        for _, wnid, category, _, num_children, *_ in synsets
+        # if num_children > 0, we are looking at a superclass that has no direct instance
+        if num_children == 0
+    ]
+
+
+def imagenet_label_to_wnid(imagenet_label: str, *, wnids: List[str]) -> str:
+    return wnids[int(imagenet_label) - 1]
+
+
+VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
+
+
+def val_test_image_key(path: pathlib.Path) -> int:
+    return int(VAL_TEST_IMAGE_NAME_PATTERN.match(path.name)["id"])  # type: ignore[index]
+
+
+def prepare_val_data(
+    data: Tuple[Tuple[int, str], Tuple[str, BinaryIO]]
+) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
+    label_data, image_data = data
+    _, wnid = label_data
+    label = Label.from_category(WNID_TO_CATEGORY[wnid], categories=CATEGORIES)
+    return (label, wnid), image_data
+
+
+def prepare_sample(
+    data: Tuple[Optional[Tuple[Label, str]], Tuple[str, BinaryIO]],
+) -> Dict[str, Any]:
+    label_data, (path, buffer) = data
+
+    return dict(
+        dict(zip(("label", "wnid"), label_data if label_data else (None, None))),
+        path=path,
+        image=EncodedImage.from_file(buffer),
+    )
+
+
 @register_dataset(NAME)
-class ImageNet(Dataset2):
-    def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> None:
-        self._split = self._verify_str_arg(split, "split", {"train", "val", "test"})
-
-        info = _info()
-        categories, wnids = info["categories"], info["wnids"]
-        self._categories: List[str] = categories
-        self._wnids: List[str] = wnids
-        self._wnid_to_category = dict(zip(wnids, categories))
-
-        super().__init__(root)
-
-    _IMAGES_CHECKSUMS = {
-        "train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
-        "val": "c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0",
-        "test_v10102019": "9cf7f8249639510f17d3d8a0deb47cd22a435886ba8e29e2b3223e65a4079eb4",
-    }
-
-    def _resources(self) -> List[OnlineResource]:
-        name = "test_v10102019" if self._split == "test" else self._split
-        images = ImageNetResource(
-            file_name=f"ILSVRC2012_img_{name}.tar",
-            sha256=self._IMAGES_CHECKSUMS[name],
+def imagenet(root: Optional[Union[str, pathlib.Path]] = None, *, split: str = "train", **kwargs: Any) -> TakerDataPipe:
+    root = get_root(root, NAME)
+    verify_str_arg(split, "split", ["train", "val", "test"])
+
+    images_dp = load_images_dp(root, split=split, **kwargs)
+    if split == "train":
+        # the train archive is a tar of tars
+        images_dp = TarArchiveReader(images_dp)
+        images_dp = hint_sharding(images_dp)
+        images_dp = hint_shuffling(images_dp)
+        dp = Mapper(images_dp, prepare_train_data)
+    elif split == "test":
+        images_dp = hint_sharding(images_dp)
+        images_dp = hint_shuffling(images_dp)
+        dp = Mapper(images_dp, prepare_test_data)
+    else:  # split == "val"
+        devkit_dp = load_devkit_dp(root, **kwargs)
+
+        meta_dp, label_dp = Demultiplexer(
+            devkit_dp, 2, classifiy_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
         )
-        resources: List[OnlineResource] = [images]
-
-        if self._split == "val":
-            devkit = ImageNetResource(
-                file_name="ILSVRC2012_devkit_t12.tar.gz",
-                sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953",
-            )
-            resources.append(devkit)
-
-        return resources
-
-    _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")
-
-    def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
-        path = pathlib.Path(data[0])
-        wnid = cast(Match[str], self._TRAIN_IMAGE_NAME_PATTERN.match(path.name))["wnid"]
-        label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories)
-        return (label, wnid), data
-
-    def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]:
-        return None, data
-
-    def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
-        return {
-            "meta.mat": 0,
-            "ILSVRC2012_validation_ground_truth.txt": 1,
-        }.get(pathlib.Path(data[0]).name)
-
-    # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
-    # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
-    _WNID_MAP = {
-        "n03126707": "construction crane",
-        "n03710721": "tank suit",
-    }
-
-    def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]:
-        synsets = read_mat(data[1], squeeze_me=True)["synsets"]
-        return [
-            (self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
-            for _, wnid, category, _, num_children, *_ in synsets
-            # if num_children > 0, we are looking at a superclass that has no direct instance
-            if num_children == 0
-        ]
-
-    def _imagenet_label_to_wnid(self, imagenet_label: str) -> str:
-        return self._wnids[int(imagenet_label) - 1]
-
-    _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
-
-    def _val_test_image_key(self, path: pathlib.Path) -> int:
-        return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name)["id"])  # type: ignore[index]
-
-    def _prepare_val_data(
-        self, data: Tuple[Tuple[int, str], Tuple[str, BinaryIO]]
-    ) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
-        label_data, image_data = data
-        _, wnid = label_data
-        label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories)
-        return (label, wnid), image_data
-
-    def _prepare_sample(
-        self,
-        data: Tuple[Optional[Tuple[Label, str]], Tuple[str, BinaryIO]],
-    ) -> Dict[str, Any]:
-        label_data, (path, buffer) = data
-
-        return dict(
-            dict(zip(("label", "wnid"), label_data if label_data else (None, None))),
-            path=path,
-            image=EncodedImage.from_file(buffer),
+
+        meta_dp = Mapper(meta_dp, extract_categories_and_wnids)
+        _, wnids = zip(*next(iter(meta_dp)))
+
+        label_dp = LineReader(label_dp, decode=True, return_path=False)
+        label_dp = Mapper(label_dp, functools.partial(imagenet_label_to_wnid, wnids=wnids))
+        label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1)
+        label_dp = hint_sharding(label_dp)
+        label_dp = hint_shuffling(label_dp)
+
+        dp = IterKeyZipper(
+            label_dp,
+            images_dp,
+            key_fn=getitem(0),
+            ref_key_fn=path_accessor(val_test_image_key),
+            buffer_size=INFINITE_BUFFER_SIZE,
         )
+        dp = Mapper(dp, prepare_val_data)
 
-    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
-        if self._split in {"train", "test"}:
-            dp = resource_dps[0]
-
-            # the train archive is a tar of tars
-            if self._split == "train":
-                dp = TarArchiveReader(dp)
-
-            dp = hint_sharding(dp)
-            dp = hint_shuffling(dp)
self._split == "train" else self._prepare_test_data) - else: # config.split == "val": - images_dp, devkit_dp = resource_dps - - meta_dp, label_dp = Demultiplexer( - devkit_dp, 2, self._classifiy_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE - ) - - meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids) - _, wnids = zip(*next(iter(meta_dp))) - - label_dp = LineReader(label_dp, decode=True, return_path=False) - label_dp = Mapper(label_dp, self._imagenet_label_to_wnid) - label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1) - label_dp = hint_sharding(label_dp) - label_dp = hint_shuffling(label_dp) - - dp = IterKeyZipper( - label_dp, - images_dp, - key_fn=getitem(0), - ref_key_fn=path_accessor(self._val_test_image_key), - buffer_size=INFINITE_BUFFER_SIZE, - ) - dp = Mapper(dp, self._prepare_val_data) - - return Mapper(dp, self._prepare_sample) - - def __len__(self) -> int: - return { + dp = Mapper(dp, prepare_sample) + return TakerDataPipe( + dp, + num_take={ "train": 1_281_167, "val": 50_000, "test": 100_000, - }[self._split] + }[split], + ) + - def _generate_categories(self) -> List[Tuple[str, ...]]: - self._split = "val" - resources = self._resources() +def generate_categories(root: pathlib.Path, **kwargs: Any) -> List[Tuple[str, ...]]: + devkit_dp = load_devkit_dp(root, **kwargs) - devkit_dp = resources[1].load(self._root) - meta_dp = Filter(devkit_dp, path_comparator("name", "meta.mat")) - meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids) + meta_dp = Filter(devkit_dp, path_comparator("name", "meta.mat")) + meta_dp = Mapper(meta_dp, extract_categories_and_wnids) - categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp))) - categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1]) - return categories_and_wnids + categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp))) + categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1]) + return categories_and_wnids diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index 3ed40f63ff0..3ca578753e4 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -27,23 +27,11 @@ import torch.utils.data from torchdata.datapipes.iter import IoPathFileLister, IoPathFileOpener, IterDataPipe, ShardingFilter, Shuffler from torchdata.datapipes.utils import StreamWrapper +from torchvision.datasets.utils import verify_str_arg # noqa: F401 +from torchvision.prototype.datasets import home from torchvision.prototype.utils._internal import fromfile -__all__ = [ - "INFINITE_BUFFER_SIZE", - "BUILTIN_DIR", - "read_mat", - "MappingIterator", - "Enumerator", - "getitem", - "path_accessor", - "path_comparator", - "Decompressor", - "read_flo", - "hint_sharding", -] - K = TypeVar("K") D = TypeVar("D") @@ -258,3 +246,10 @@ def hint_sharding(datapipe: IterDataPipe) -> ShardingFilter: def hint_shuffling(datapipe: IterDataPipe[D]) -> Shuffler[D]: return Shuffler(datapipe, default=False, buffer_size=INFINITE_BUFFER_SIZE) + + +def get_root(root: Optional[Union[str, pathlib.Path]], name: str) -> pathlib.Path: + if root is None: + return pathlib.Path(home()) / name + else: + return pathlib.Path(root) diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py index b2ae175c551..d91a8b19aa6 100644 --- a/torchvision/prototype/datasets/utils/_resource.py +++ 
+++ b/torchvision/prototype/datasets/utils/_resource.py
@@ -2,7 +2,7 @@
 import hashlib
 import itertools
 import pathlib
-from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn
+from typing import Optional, Sequence, Tuple, Callable, BinaryIO, Any, Union, NoReturn
 from urllib.parse import urlparse
 
 from torchdata.datapipes.iter import (
@@ -55,7 +55,7 @@ def _extract(file: pathlib.Path) -> pathlib.Path:
 
     def _decompress(file: pathlib.Path) -> pathlib.Path:
         return pathlib.Path(_decompress(str(file), remove_finished=True))
 
-    def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
+    def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, BinaryIO]]:
         if path.is_dir():
             return FileOpener(FileLister(str(path), recursive=True), mode="rb")
 
@@ -75,7 +75,7 @@ def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
 
     def _guess_archive_loader(
         self, path: pathlib.Path
-    ) -> Optional[Callable[[IterDataPipe[Tuple[str, IO]]], IterDataPipe[Tuple[str, IO]]]]:
+    ) -> Optional[Callable[[IterDataPipe[Tuple[str, BinaryIO]]], IterDataPipe[Tuple[str, BinaryIO]]]]:
         try:
             _, archive_type, _ = _detect_file_type(path.name)
         except RuntimeError:
@@ -84,7 +84,7 @@ def _guess_archive_loader(
 
     def load(
         self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False
-    ) -> IterDataPipe[Tuple[str, IO]]:
+    ) -> IterDataPipe[Tuple[str, BinaryIO]]:
         root = pathlib.Path(root)
         path = root / self.file_name
         # Instead of the raw file, there might also be files with fewer suffixes after decompression or directories
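
Usage sketch (editor's addition, not part of the patch): a minimal example of how the functional ImageNet builder introduced above would be consumed once the patch is applied. It assumes the ImageNet archives have already been placed manually under the dataset home directory, since ImageNetResource is a manual-download resource; the sample keys come from prepare_sample() in the diff.

# load() resolves "imagenet" to the imagenet() builder registered via @register_dataset
# and forwards the keyword arguments; with this patch it returns the TakerDataPipe that
# imagenet() builds, so the result can be iterated like any other IterDataPipe.
from torchvision.prototype import datasets

dataset = datasets.load("imagenet", split="val")

for sample in dataset:
    # each sample is the dict produced by prepare_sample(): label, wnid, path, image
    print(sample["wnid"], sample["path"])
    break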