add LSUN prototype dataset #5390

Open · wants to merge 9 commits into main
Changes from 1 commit
46 changes: 46 additions & 0 deletions test/builtin_dataset_mocks.py
@@ -9,13 +9,15 @@
import pathlib
import pickle
import random
import string
import xml.etree.ElementTree as ET
from collections import defaultdict, Counter

import numpy as np
import PIL.Image
import pytest
import torch
from common_utils import get_tmp_dir
from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file
from torch.nn.functional import one_hot
from torch.testing import make_tensor as _make_tensor
@@ -1340,3 +1342,47 @@ def pcam(info, root, config):
            compressed_file.write(compressed_data)

    return num_images


@register_mock
def lsun(info, root, config):
    def make_lmdb(path):
        import lmdb

        hexdigits_lowercase = string.digits + string.ascii_lowercase[:6]

        num_samples = torch.randint(1, 4, size=()).item()
        format = "png"

        with get_tmp_dir() as tmp_dir:
            files = create_image_folder(tmp_dir, "tmp", lambda idx: f"{idx}.{format}", num_samples)

            values = []
            for file in files:
                buffer = io.BytesIO()
                PIL.Image.open(file).save(buffer, format)
                buffer.seek(0)
                values.append(buffer.read())

        with lmdb.open(str(path)) as env, env.begin(write=True) as txn:
            for value in values:
                key = "".join(random.choice(hexdigits_lowercase) for _ in range(40)).encode()
                txn.put(key, value)

        return num_samples

    if config.split == "test":
        names = ["test_lmdb"]
    else:
        names = [f"{category}_{config.split}_lmdb" for category in info.categories]

    num_samples = 0
    for name in names:
        data_folder = root / name
        data_folder.mkdir()

        num_samples += make_lmdb(data_folder)

        make_zip(root, data_folder.with_suffix(".zip").name)

    return num_samples
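
For reference, a short sketch (not part of the diff) of how one of the mocked LMDB folders could be read back to sanity-check its contents. It assumes `data_folder` is one of the directories created above, before it is zipped, and only uses the public `lmdb` and `PIL` APIs:

def count_mock_images(data_folder):
    # Hypothetical helper: open the mocked LMDB read-only and check that every stored value decodes to an image.
    import io
    import lmdb
    import PIL.Image

    count = 0
    with lmdb.open(str(data_folder), readonly=True) as env, env.begin(write=False) as txn:
        for _key, value in txn.cursor():
            PIL.Image.open(io.BytesIO(value)).verify()  # raises if the bytes are not a valid image
            count += 1
    return count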
1 change: 1 addition & 0 deletions torchvision/prototype/datasets/_builtin/__init__.py
@@ -8,6 +8,7 @@
from .fer2013 import FER2013
from .gtsrb import GTSRB
from .imagenet import ImageNet
from .lsun import Lsun
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .oxford_iiit_pet import OxfordIITPet
from .pcam import PCAM
186 changes: 186 additions & 0 deletions torchvision/prototype/datasets/_builtin/lsun.py
@@ -0,0 +1,186 @@
import functools
import io
import pathlib
import re
from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator

import torch
from torchdata.datapipes.iter import IterDataPipe, Mapper, OnDiskCacheHolder, Concater, IterableWrapper
from torchvision.prototype.datasets.utils import (
    Dataset,
    DatasetConfig,
    DatasetInfo,
    HttpResource,
    OnlineResource,
    DatasetType,
)
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.features import Label

# We need lmdb.Environment as an annotation, but lmdb is an optional requirement and may not be importable.
try:
    import lmdb

    Environment = lmdb.Environment
except ImportError:
    Environment = Any


class LmdbKeyExtractor(IterDataPipe[Tuple[str, bytes]]):
    def __init__(self, datapipe: IterDataPipe[str]) -> None:
        self.datapipe = datapipe

    def __iter__(self) -> Iterator[Tuple[str, bytes]]:
        import lmdb

        for path in self.datapipe:
            with lmdb.open(path, readonly=True) as env:
                with env.begin(write=False) as txn:
                    keys = b"\n".join(key for key in txn.cursor().iternext(keys=True, values=False))
                    yield path, keys


class LmdbLoader(IterDataPipe[Tuple[Environment, bytes]]):
    def __init__(self, datapipe: IterDataPipe[str]) -> None:
        self.datapipe = datapipe

    def __iter__(self) -> Iterator[Tuple[Environment, bytes]]:  # type: ignore[valid-type]
        import lmdb

        for cache_path in self.datapipe:
            env = lmdb.open(str(pathlib.Path(cache_path).parent), readonly=True)

            with open(cache_path, "rb") as file:
                for key in file:
                    yield env, key.strip()


class LmdbReader(IterDataPipe):
    def __init__(self, datapipe: IterDataPipe[Tuple[Environment, bytes]]):
        self.datapipe = datapipe

    def __iter__(self) -> Iterator[Tuple[str, bytes, io.BytesIO]]:
        for env, key in self.datapipe:
            with env.begin(write=False) as txn:
                yield env.path(), key, io.BytesIO(txn.get(key))


class LsunHttpResource(HttpResource):
    def __init__(self, *args: Any, extract: bool = True, **kwargs: Any) -> None:
        super().__init__(*args, extract=extract, **kwargs)

    def _loader(self, path: pathlib.Path) -> IterDataPipe[str]:
        # LMDB datasets cannot be loaded through an open file handle, but have to be loaded through the path of the
        # parent directory.
        return IterableWrapper([str(next(path.rglob("data.mdb")).parent)])
Review comment from the author on lines +69 to +72:
@NicolasHug this would have been supported with the loader parameter we removed in #5282. But I guess since this is (currently) the only case that needs this, we can special case it.



class Lsun(Dataset):
    def _make_info(self) -> DatasetInfo:
        return DatasetInfo(
            "lsun",
            type=DatasetType.IMAGE,
            categories=(
                "bedroom",
                "bridge",
                "church_outdoor",
                "classroom",
                "conference_room",
                "dining_room",
                "kitchen",
                "living_room",
                "restaurant",
                "tower",
            ),
            valid_options=dict(split=("train", "val", "test")),
            dependencies=("lmdb",),
            homepage="https://www.yf.io/p/lsun",
        )

    _CHECKSUMS = {
        ("train", "bedroom"): "",
        ("train", "bridge"): "",
        ("train", "church_outdoor"): "",
        ("train", "classroom"): "",
        ("train", "conference_room"): "",
        ("train", "dining_room"): "",
        ("train", "kitchen"): "",
        ("train", "living_room"): "",
        ("train", "restaurant"): "",
        ("train", "tower"): "",
        ("val", "bedroom"): "5d022e781b241c25ec2e1f1f769afcdb8091d7fd58362667aec03137b8114b12",
        ("val", "bridge"): "83216a2974d6068c2e1d18086006e7380ff58540216f955ce87fe049b460cb0d",
        ("val", "church_outdoor"): "34635b7547a3e51a15f942a4a4082dd6bc9cca381a953515cb2275c0eed50584",
        ("val", "classroom"): "5e0e9a375d94091dfe1fa3be87d4a92f41c03f1c0b8e376acc7e05651de512d7",
        ("val", "conference_room"): "927c94df52e10b9b374748c2b83b28b5860e946b3186dfd587985e274834650f",
        ("val", "dining_room"): "bd604d4b91bb5a9611d4e0b85475efd20758390d1a4eb57b53973fcbb5aa8ab6",
        ("val", "kitchen"): "329165f35ec61c4cf49f809246de300b8baad3ffcbda1ac30c27bdd32c84369a",
        ("val", "living_room"): "30a23d9a3db5414e9c97865f60ffb2ee973bfa658a23dbca7188ea514c97c9fc",
        ("val", "restaurant"): "efaa7bcb898ad6cb73b07b89fec3a9c670f4622912eea22fab3986c2cf9a1c20",
        ("val", "tower"): "7f5257847bc01f4e40d4a1b3e24dd8fcd37063f12ca8cf31e726c2ee0b1ae104",
    }

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        url_root = "http://dl.yf.io/lsun/scenes"
        if config.split == "test":
            return [
                LsunHttpResource(
                    f"{url_root}/test_lmdb.zip",
                    sha256="5ee4f929363f26d1f3c7db6e40e3f7a8415cf777b3c5527f5f38bf3e9520ff22",
                )
            ]
        else:
            return [
                LsunHttpResource(
                    f"{url_root}/{category}_{config.split}_lmdb.zip",
                    sha256=self._CHECKSUMS[(config.split, category)],
                )
                for category in self.categories
            ]

    _FOLDER_PATTERN = re.compile(r"(?P<category>\w*?)_(?P<split>(train|val))_lmdb")

    def _collate_and_decode_sample(
        self,
        data: Tuple[str, bytes, io.BytesIO],
        *,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        path, key, buffer = data

        match = self._FOLDER_PATTERN.match(pathlib.Path(path).parent.name)
        if match:
            category = match["category"]
            label = Label(self.categories.index(category), category=category)
        else:
            label = None

        return dict(
            path=path,
            key=key,
            image=decoder(buffer) if decoder else buffer,
            label=label,
        )

    def _filepath_fn(self, path: str) -> str:
        return str(pathlib.Path(path).joinpath("keys.cache"))

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        dp = Concater(*resource_dps)

        # LMDB datasets are indexed, but extracting all keys is expensive. Since we need them for shuffling, we cache
        # the keys on disk and subsequently only read them from there.
        dp = OnDiskCacheHolder(dp, filepath_fn=self._filepath_fn)
        dp = LmdbKeyExtractor(dp).end_caching(mode="wb", same_filepath_fn=True, skip_read=True)

        dp = LmdbLoader(dp)
        dp = hint_sharding(dp)
        dp = hint_shuffling(dp)
        dp = LmdbReader(dp)
        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
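
A minimal usage sketch (not part of the diff), assuming the prototype datasets.load entry point and that the requested resources can be downloaded or are already cached locally:

from torchvision.prototype import datasets

# Config options such as the split are passed as keyword arguments.
dataset = datasets.load("lsun", split="val")
for sample in dataset:
    # Each sample is the dict produced by _collate_and_decode_sample above.
    image, label = sample["image"], sample["label"]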
12 changes: 8 additions & 4 deletions torchvision/prototype/datasets/utils/_internal.py
@@ -129,12 +129,16 @@ def path_accessor(getter: Union[str, Callable[[pathlib.Path], D]]) -> Callable[[
    return functools.partial(_path_accessor_closure, getter=getter)


-def _path_comparator_closure(data: Tuple[str, Any], *, accessor: Callable[[Tuple[str, Any]], D], value: D) -> bool:
-    return accessor(data) == value
+def _path_comparator_closure(
+    data: Tuple[str, Any], *, accessor: Callable[[Tuple[str, Any]], D], value: D, inv: bool
+) -> bool:
+    return (accessor(data) == value) ^ inv


-def path_comparator(getter: Union[str, Callable[[pathlib.Path], D]], value: D) -> Callable[[Tuple[str, Any]], bool]:
-    return functools.partial(_path_comparator_closure, accessor=path_accessor(getter), value=value)
+def path_comparator(
+    getter: Union[str, Callable[[pathlib.Path], D]], value: D, *, inv: bool = False
+) -> Callable[[Tuple[str, Any]], bool]:
+    return functools.partial(_path_comparator_closure, accessor=path_accessor(getter), value=value, inv=inv)


class CompressionType(enum.Enum):
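
To illustrate the new inv flag (a hypothetical usage, not part of the diff), inverting the comparison with path_comparator from this module makes it easy to drop the keys.cache files produced by the on-disk cache while keeping everything else:

from torchdata.datapipes.iter import Filter, IterableWrapper

# Datapipes in this module yield (path, handle) tuples; the handles are irrelevant here.
dp = IterableWrapper([("archive/keys.cache", None), ("archive/data.mdb", None)])
# Keep every entry whose file name is NOT "keys.cache".
dp = Filter(dp, path_comparator("name", "keys.cache", inv=True))
assert [path for path, _ in dp] == ["archive/data.mdb"]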