
feat: use datasets.IterableDataset shard if possible. #3583


Open
wants to merge 1 commit into main

Conversation

ValMystletainn

What does this PR do?

Add support for sharding a datasets.IterableDataset when it is passed to accelerator.prepare.
Use the dataset's n_shards (via its shard method) rather than IterableDatasetShard, to reduce the data-reading overhead and let each rank read only its own shards efficiently.

Fixes #3547
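For illustration, a minimal sketch of the idea (not the exact patch in this PR): the divisibility guard on n_shards and the helper name shard_iterable_dataset are my own assumptions, and PartialState is only used here to obtain the process count and index.

# Rough sketch, not the actual implementation in this PR. Instead of wrapping the
# dataset in IterableDatasetShard (where every rank iterates the full stream and
# keeps only its slice), ask datasets to split the underlying shards so each rank
# only reads its own files.
from accelerate import PartialState
from datasets import IterableDataset


def shard_iterable_dataset(dataset: IterableDataset) -> IterableDataset:
    state = PartialState()
    if state.num_processes > 1 and dataset.n_shards % state.num_processes == 0:
        # Each process keeps n_shards / num_processes of the underlying shards
        # and never reads the rest.
        return dataset.shard(num_shards=state.num_processes, index=state.process_index)
    # Otherwise fall back to the existing IterableDatasetShard-style wrapping.
    return dataset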

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@SunMarc

When `accelerator.prepare` is called on a `datasets.IterableDataset`, use the `shard` method to
split the dataset across the available processes. This allows for more efficient data loading
and processing, without the load-and-slice overhead of `IterableDatasetShard`.
@ValMystletainn
Author

I guess I should write a test for it and update the docs somewhere.

However, I read through tests/test_data_loader.py and couldn't find where to start a multi-process test suite for this function.

Member

@SunMarc SunMarc left a comment


Thanks! A test would be nice to have! You can put the test in test_distributed_data_loop.py. We run these tests on 2 GPUs. Check the test_distributed_data_loop test in the test_multi-gpu.py file.
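A rough sketch of what such a 2-process test could look like (hedged: the dataset construction and the assertion below are my own illustration, not the test added by this PR):

# Sketch only: build a small IterableDataset with 4 underlying shards, prepare a
# DataLoader on 2 processes, and check that gathering the per-rank elements
# recovers the whole dataset exactly once.
import torch
from torch.utils.data import DataLoader

from accelerate import Accelerator
from datasets import Dataset


def test_iterable_dataset_shard():
    accelerator = Accelerator()
    ds = Dataset.from_dict({"x": list(range(16))}).to_iterable_dataset(num_shards=4)

    dataloader = accelerator.prepare(DataLoader(ds, batch_size=2))

    local_values = []
    for batch in dataloader:
        local_values.extend(batch["x"].tolist())

    # Each rank should see a disjoint subset; gathering across the 2 processes
    # should recover every element exactly once.
    gathered = accelerator.gather(torch.tensor(local_values, device=accelerator.device))
    assert sorted(gathered.tolist()) == list(range(16))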

Comment on lines +1198 to +1199
if (
isinstance(new_dataset, getattr(sys.modules.get("datasets"), "IterableDataset", type(None)))
Member


That could work, but let's instead check whether datasets is available (is_datasets_available) and import the IterableDataset class from there to perform the check.

Author


I wrote it in this style rather than using

if is_datasets_available():
    from datasets import IterableDataset as DatasetsIterableDatasets
...

if isinstance(new_dataset, DatasetsIterableDatasets):
    ...

in order to reduce import overhead, in the same spirit as this code snippet: it checks whether the object is a torch.Tensor and skips importing the heavy pytorch package if there is no torch.Tensor object at all.

However, I think the import overhead of datasets is not as heavy as torch's, so if it helps readability and maintainability, I will change to this style:

if is_datasets_available():
    from datasets import IterableDataset as DatasetsIterableDatasets
...

So which do you prefer: the original version or the import-and-check version?
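For reference, the lazy-check pattern being referred to looks roughly like this (a generic illustration of the torch case, with a helper name of my own, not a quote of the accelerate source):

import sys


def is_torch_tensor_lazy(obj) -> bool:
    # Only look torch up if some other code has already imported it; if it has
    # not, sys.modules.get("torch") returns None, getattr falls back to
    # type(None), and the isinstance check fails without ever importing torch.
    return isinstance(obj, getattr(sys.modules.get("torch"), "Tensor", type(None)))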

Member


Yeah, the overhead should be fine. We only call this function once, so it shouldn't create a huge overhead. Please go with the import + check version.
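The agreed-upon shape would then be roughly the following (sketch only; the helper name is mine, and the exact placement inside data_loader.py is up to the patch):

from accelerate.utils import is_datasets_available

# Import the class once when the datasets package is installed; compared to the
# sys.modules trick this pays for a single import of datasets, which is cheap
# relative to importing torch.
if is_datasets_available():
    from datasets import IterableDataset as DatasetsIterableDataset
else:
    DatasetsIterableDataset = None


def is_datasets_iterable_dataset(obj) -> bool:
    return DatasetsIterableDataset is not None and isinstance(obj, DatasetsIterableDataset)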

Successfully merging this pull request may close these issues.

datasets Iterable Dataset sharding support