PyTorch's DataLoader has a num_workers argument that fetches items from your dataset in parallel using multiprocessing, but this requires the dataset to be picklable so Python can distribute it across multiple processes.
Currently, if you try to use multiple workers for a muspy dataset, you get the following error:
AttributeError: Can't pickle local object 'Dataset.to_pytorch_dataset.<locals>.TorchRepresentationDataset'
There's more context on the pickle issue in this Stack Overflow thread.
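To illustrate the underlying pickle limitation without muspy or PyTorch, here is a minimal sketch using only the standard library. The names make_local_dataset, LocalDataset, ModuleLevelDataset and is_picklable are made up for this example and are not part of muspy.

```python
import pickle


def make_local_dataset():
    # Hypothetical stand-in for a class defined inside a method such as
    # to_pytorch_dataset: its qualified name contains "<locals>", so pickle
    # cannot look the class up by name.
    class LocalDataset:
        def __len__(self):
            return 3

    return LocalDataset()


class ModuleLevelDataset:
    # Defined at module level, so pickle can find it by its qualified name.
    def __len__(self):
        return 3


def is_picklable(obj):
    """Return True if pickle.dumps(obj) succeeds, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError):
        # The exact exception type for local classes varies by Python version.
        return False


if __name__ == "__main__":
    print(type(make_local_dataset()).__qualname__)
    print(is_picklable(make_local_dataset()))
    print(is_picklable(ModuleLevelDataset()))
```

Run as a script, the instance of the locally defined class should fail to pickle while the module-level one succeeds, which is the same failure mode as the error above.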
Here's a minimal reproducible example to test it out yourself:
import muspy
import torch
haydn = muspy.HaydnOp20Dataset("data/", download_and_extract=True).convert()
dataset = haydn.to_pytorch_dataset(representation="pianoroll")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=2)
batch = next(iter(dataloader))
I'm happy to open a PR with a fix for this. It mostly involves moving TorchRepresentationDataset and TorchMusicFactoryDataset so that they are defined outside of to_pytorch_dataset.
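As a rough sketch of what that restructuring could look like (not muspy's actual code), the wrapper class lives at module level and the conversion function only instantiates it. The class name RepresentationDataset, the simplified to_pytorch_dataset signature, and the items/factory parameters are all assumptions for illustration. The sketch relies on plain __getitem__ and __len__ so that it runs without PyTorch; a real fix would presumably keep subclassing torch.utils.data.Dataset.

```python
import pickle


class RepresentationDataset:
    """Hypothetical module-level stand-in for TorchRepresentationDataset."""

    def __init__(self, items, factory):
        self.items = items
        # The factory must itself be picklable (a named function, not a lambda).
        self.factory = factory

    def __getitem__(self, index):
        # Convert one item to the target representation on access.
        return self.factory(self.items[index])

    def __len__(self):
        return len(self.items)


def to_pytorch_dataset(items, factory):
    # Simplified, assumed signature: the function no longer defines the class,
    # it only creates an instance of the module-level one.
    return RepresentationDataset(items, factory)


if __name__ == "__main__":
    dataset = to_pytorch_dataset([1, 2, 3], str)
    # Round-trip through pickle, as a worker process would need to.
    restored = pickle.loads(pickle.dumps(dataset))
    print(len(restored), restored[0])
```

Because the class is now reachable by its qualified name, instances can be sent to worker processes, provided everything they hold (here items and factory) is picklable as well.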