diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py
index 0d11b642c13..3aed2cf0a30 100644
--- a/torchvision/prototype/datasets/_builtin/imagenet.py
+++ b/torchvision/prototype/datasets/_builtin/imagenet.py
@@ -3,8 +3,16 @@
 import re
 from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast
 
-from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, Filter, Demultiplexer
-from torchdata.datapipes.iter import TarArchiveReader
+from torchdata.datapipes.iter import (
+    IterDataPipe,
+    LineReader,
+    IterKeyZipper,
+    Mapper,
+    Filter,
+    Demultiplexer,
+    TarArchiveReader,
+    Enumerator,
+)
 from torchvision.prototype.datasets.utils import (
     Dataset,
     DatasetConfig,
@@ -16,7 +24,6 @@
     INFINITE_BUFFER_SIZE,
     BUILTIN_DIR,
     path_comparator,
-    Enumerator,
     getitem,
     read_mat,
     hint_sharding,
diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py
index c2bc9c3cdd0..b8bd88a257d 100644
--- a/torchvision/prototype/datasets/_builtin/mnist.py
+++ b/torchvision/prototype/datasets/_builtin/mnist.py
@@ -6,25 +6,9 @@
 from typing import Any, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO, Union, Sequence
 
 import torch
-from torchdata.datapipes.iter import (
-    IterDataPipe,
-    Demultiplexer,
-    Mapper,
-    Zipper,
-)
-from torchvision.prototype.datasets.utils import (
-    Dataset,
-    DatasetConfig,
-    DatasetInfo,
-    HttpResource,
-    OnlineResource,
-)
-from torchvision.prototype.datasets.utils._internal import (
-    Decompressor,
-    INFINITE_BUFFER_SIZE,
-    hint_sharding,
-    hint_shuffling,
-)
+from torchdata.datapipes.iter import IterDataPipe, Demultiplexer, Mapper, Zipper, Decompressor
+from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource
+from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding, hint_shuffling
 from torchvision.prototype.features import Image, Label
 from torchvision.prototype.utils._internal import fromfile
 
diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py
index 3ed40f63ff0..a74c47e1b49 100644
--- a/torchvision/prototype/datasets/utils/_internal.py
+++ b/torchvision/prototype/datasets/utils/_internal.py
@@ -1,9 +1,4 @@
-import enum
 import functools
-import gzip
-import lzma
-import os
-import os.path
 import pathlib
 import pickle
 from typing import BinaryIO
@@ -16,7 +11,6 @@
     TypeVar,
     Iterator,
     Dict,
-    Optional,
     IO,
     Sized,
 )
@@ -35,11 +29,9 @@
     "BUILTIN_DIR",
     "read_mat",
     "MappingIterator",
-    "Enumerator",
     "getitem",
     "path_accessor",
     "path_comparator",
-    "Decompressor",
     "read_flo",
     "hint_sharding",
 ]
@@ -75,15 +67,6 @@ def __iter__(self) -> Iterator[Union[Tuple[K, D], D]]:
         yield from iter(mapping.values() if self.drop_key else mapping.items())
 
 
-class Enumerator(IterDataPipe[Tuple[int, D]]):
-    def __init__(self, datapipe: IterDataPipe[D], start: int = 0) -> None:
-        self.datapipe = datapipe
-        self.start = start
-
-    def __iter__(self) -> Iterator[Tuple[int, D]]:
-        yield from enumerate(self.datapipe, self.start)
-
-
 def _getitem_closure(obj: Any, *, items: Sequence[Any]) -> Any:
     for item in items:
         obj = obj[item]
@@ -123,50 +106,6 @@ def path_comparator(getter: Union[str, Callable[[pathlib.Path], D]], value: D) -
     return functools.partial(_path_comparator_closure, accessor=path_accessor(getter), value=value)
 
 
-class CompressionType(enum.Enum):
-    GZIP = "gzip"
-    LZMA = "lzma"
-
-
-class Decompressor(IterDataPipe[Tuple[str, BinaryIO]]):
-    types = CompressionType
-
-    _DECOMPRESSORS: Dict[CompressionType, Callable[[BinaryIO], BinaryIO]] = {
-        types.GZIP: lambda file: cast(BinaryIO, gzip.GzipFile(fileobj=file)),
-        types.LZMA: lambda file: cast(BinaryIO, lzma.LZMAFile(file)),
-    }
-
-    def __init__(
-        self,
-        datapipe: IterDataPipe[Tuple[str, BinaryIO]],
-        *,
-        type: Optional[Union[str, CompressionType]] = None,
-    ) -> None:
-        self.datapipe = datapipe
-        if isinstance(type, str):
-            type = self.types(type.upper())
-        self.type = type
-
-    def _detect_compression_type(self, path: str) -> CompressionType:
-        if self.type:
-            return self.type
-
-        # TODO: this needs to be more elaborate
-        ext = os.path.splitext(path)[1]
-        if ext == ".gz":
-            return self.types.GZIP
-        elif ext == ".xz":
-            return self.types.LZMA
-        else:
-            raise RuntimeError("FIXME")
-
-    def __iter__(self) -> Iterator[Tuple[str, BinaryIO]]:
-        for path, file in self.datapipe:
-            type = self._detect_compression_type(path)
-            decompressor = self._DECOMPRESSORS[type]
-            yield path, decompressor(file)
-
-
 class PicklerDataPipe(IterDataPipe):
     def __init__(self, source_datapipe: IterDataPipe[Tuple[str, IO[bytes]]]) -> None:
         self.source_datapipe = source_datapipe