PyTorch datasets don't support multiprocessing #74

@jeremyjordan

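PyTorch's DataLoader has a num_workers argument that fetches items from your dataset in parallel using multiprocessing, but this requires the dataset to be picklable so that Python can send it to the worker processes. Pickling is required when workers are started with the "spawn" start method, which is the default on Windows and, since Python 3.8, on macOS.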
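
A quick way to check this requirement before setting num_workers is a pickle round trip. The sketch below assumes that a successful pickle.dumps/pickle.loads pair is a reasonable proxy for what the "spawn" start method does when it hands the dataset to a worker; is_picklable and make_local_function are hypothetical helpers, not part of PyTorch or muspy.

```python
import pickle
from typing import Any


def is_picklable(obj: Any) -> bool:
    """Return True if obj survives a pickle round trip (hypothetical helper)."""
    try:
        pickle.loads(pickle.dumps(obj))
    except (pickle.PicklingError, AttributeError, TypeError):
        # PicklingError: e.g. lambdas; AttributeError: local objects on older
        # Python versions; TypeError: e.g. locks or generators.
        return False
    return True


def make_local_function():
    # A function defined inside another function cannot be pickled by name.
    def local(x):
        return x

    return local


print(is_picklable([1, 2, 3]))              # True
print(is_picklable(len))                    # True
print(is_picklable(lambda x: x))            # False
print(is_picklable(make_local_function()))  # False
```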

Currently, if you try to use multiple workers with a muspy dataset, you get the following error:

```
AttributeError: Can't pickle local object 'Dataset.to_pytorch_dataset.<locals>.TorchRepresentationDataset'
```

There's more context on the pickling issue in this Stack Overflow thread.
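
The error comes from how pickle handles classes: it stores a class by reference, as its module plus qualified name, and a class defined inside a function has a qualified name containing <locals> that cannot be looked up again. The sketch below reproduces this without muspy or PyTorch; make_dataset, LocalDataset and pickle_error are hypothetical names standing in for to_pytorch_dataset and its nested classes.

```python
import pickle


def make_dataset(items):
    # Hypothetical factory that, like to_pytorch_dataset, defines its
    # dataset class inside the function body.
    class LocalDataset:
        def __init__(self, items):
            self.items = items

        def __len__(self):
            return len(self.items)

        def __getitem__(self, index):
            return self.items[index]

    return LocalDataset(items)


def pickle_error(obj):
    """Return the exception raised while pickling obj, or None on success."""
    try:
        pickle.dumps(obj)
    except (AttributeError, pickle.PicklingError) as exc:
        # Older Python versions raise AttributeError here (as in the error
        # message above); newer ones may raise PicklingError instead.
        return exc
    return None


dataset = make_dataset([1, 2, 3])
print(type(dataset).__qualname__)  # make_dataset.<locals>.LocalDataset
print(repr(pickle_error(dataset)))
print(pickle_error([1, 2, 3]))     # None
```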

Here's a minimal reproducible example to test it out yourself:

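```python
import muspy
import torch

if __name__ == "__main__":  # required when workers use the "spawn" start method
    # Download the Haydn Op. 20 dataset and convert it to muspy Music objects.
    haydn = muspy.HaydnOp20Dataset("data/", download_and_extract=True).convert()
    # Wrap it as a PyTorch dataset that yields piano-roll representations.
    dataset = haydn.to_pytorch_dataset(representation="pianoroll")
    # num_workers > 0 makes the DataLoader send the dataset to worker processes.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=2)
    batch = next(iter(dataloader))
```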

I'm happy to open a PR with a fix for this. It mostly involves moving TorchRepresentationDataset and TorchMusicFactoryDataset out of to_pytorch_dataset and defining them at module level, so that pickle can locate them by name.
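
A minimal sketch of the proposed restructuring, under stated assumptions: RepresentationDataset and scale are hypothetical names, not muspy's actual classes, and the nested lists stand in for converted music objects. Because DataLoader accepts any map-style dataset that implements __len__ and __getitem__, the sketch avoids a torch dependency; a real fix would presumably subclass torch.utils.data.Dataset.

```python
import pickle
from functools import partial
from typing import Any, Callable, Sequence


class RepresentationDataset:
    """Hypothetical module-level, map-style dataset.

    Defined at module scope, so pickle can save the class by reference
    (module name plus "RepresentationDataset") and re-import it in a worker.
    """

    def __init__(self, items: Sequence[Any], convert: Callable[[Any], Any]) -> None:
        self.items = items
        # The converter is pickled along with the dataset, so it must be a
        # module-level function or a functools.partial of one, not a lambda.
        self.convert = convert

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, index: int) -> Any:
        return self.convert(self.items[index])


def scale(values: Sequence[int], factor: int) -> list:
    # Hypothetical stand-in for a real conversion such as piano-roll encoding.
    return [v * factor for v in values]


dataset = RepresentationDataset([[1, 2], [3, 4]], partial(scale, factor=10))
restored = pickle.loads(pickle.dumps(dataset))
print(len(restored), restored[1])  # 2 [30, 40]
```

With a class structured like this, passing the dataset to torch.utils.data.DataLoader with num_workers=2 should no longer hit the pickling error, provided the calling script keeps its top-level code under an if __name__ == "__main__": guard when the "spawn" start method is in use.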