(fix) remove sampler_is_batch_sampler code in prepare_data_loader(..) #3469

Open: suzyahyah wants to merge 3 commits into main from fix/data_loader_batch_sampler

Conversation


@suzyahyah suzyahyah commented Mar 31, 2025

What does this PR do?

This PR fixes several points of confusion around torch.utils.data.DataLoader, torch.utils.data.BatchSampler, and accelerate/data_loader.py::prepare_data_loader(..).

#3322 (Edit: not quite this one)
#3014
#2091

Motivation

  1. accelerate/data_loader.py has various patches allowing for a BatchSampler object to be passed as an argument to sampler.

  2. In the code, this behavior is enabled via sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler) (see the simplified sketch after this list).

  3. Allowing this is unintuitive for developers because it directly conflicts with the torch.utils.data.DataLoader documentation; accelerate should only wrap the PyTorch dataloader, not change or extend its logic or arguments.

  4. It also permits unintended behavior in other libraries that rely on accelerate. For instance, when a custom BatchSampler is passed to the dataloader argument in the HuggingFace Trainer, it is currently forwarded through the sampler kwarg, regardless of whether it is a BatchSampler or just a RandomSampler. https://github.com/huggingface/transformers/blob/3b07ca78bb696825feee3e976795fab58f2b6d0c/src/transformers/trainer.py#L1026 (I'll open a separate PR in HuggingFace for this.)

  5. I believe this is due to a misunderstanding stemming from Issue #679 (Error in prepared DataLoader with BatchSampler). As @sgugger initially suspected, this was probably a typo or misunderstanding on the HuggingFace side.
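
To make the pattern in points 1 and 2 concrete, here is a simplified sketch of the special-casing being removed (illustrative only, not the exact accelerate source; _pick_batch_sampler is a hypothetical helper name):

from torch.utils.data import BatchSampler, DataLoader

def _pick_batch_sampler(dataloader: DataLoader):
    # prepare_data_loader-style special-casing: if the object passed as `sampler`
    # is actually a BatchSampler, use it in place of the real batch_sampler.
    sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
    if sampler_is_batch_sampler:
        # `sampler` already yields lists of indices, so it is treated as the
        # batch sampler even though it was passed through the wrong argument.
        return dataloader.sampler
    return dataloader.batch_sampler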

This should not be allowed behavior, as shown by the following basic test case, which throws TypeError: int() argument must be a string, a bytes-like object or a real number, not 'list' from datasets/formatting/formatting.py:

from torch.utils.data import BatchSampler, DataLoader, RandomSampler
from datasets import load_dataset

ds = load_dataset("wikitext", "wikitext-103-raw-v1", split='validation')

train_dataloader = DataLoader(
    ds,
    sampler=BatchSampler(RandomSampler(ds), batch_size=32, drop_last=False),
    num_workers=0,  # Adjust based on your setup
    pin_memory=True,
)

for batch in train_dataloader:
    print(batch)

Official PyTorch documentation for reference:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, 
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

This PR

  1. Reverts the historical PR Fix DataLoader with samplers that are batch samplers #687, which supports the wrong logic.
  2. Removes all traces of sampler_is_batch_sampler, simplifying the code.
  3. Immediately throws an error (Edit: now throws a warning) if a BatchSampler has been passed as the sampler argument.

Tests

Passes all tests in tests/test_data_loader.py, but does not introduce any new tests. Open to suggestions.

Considerations

  1. The PR immediately throws an error (later changed to a warning) if a BatchSampler has been passed as the sampler argument. Technically this error should be thrown earlier, in torch.utils.data, but raising it here makes it explicit that the problem is not coming from the Accelerate library, since this usage had previously been allowed.

  2. Upon reading the torch.utils.data source code (v2.6.0), I found that torch.utils.data.DataLoader will construct a BatchSampler from the given Sampler if batch_sampler=None, and will also construct a default sampler if sampler=None.

This means we should always be able to recover dataloader.batch_sampler when wrapping accelerate around an already constructed PyTorch dataloader, so there is no need to check sampler_is_batch_sampler again if the intent is just to get "new_batch_sampler" (see the sketch below).
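
For illustration, a minimal check of this behavior (a sketch assuming a recent PyTorch version; the toy list dataset is only for demonstration):

from torch.utils.data import DataLoader, RandomSampler

data = list(range(10))

# Only `sampler` is given; DataLoader builds a BatchSampler around it.
dl = DataLoader(data, sampler=RandomSampler(data), batch_size=4)
print(type(dl.batch_sampler))  # <class 'torch.utils.data.sampler.BatchSampler'>

# Neither is given; DataLoader builds a default SequentialSampler and a BatchSampler.
dl2 = DataLoader(data, batch_size=4)
print(type(dl2.sampler), type(dl2.batch_sampler))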

Before submitting

Who can review?

@SunMarc @BenjaminBossan @zach-huggingface @muellerzr

Comment on lines 1153 to 1159
raise ValueError(
    "Should not pass a BatchSampler to the dataloader sampler argument. As per pytorch>2.1.0 documentation, please pass this to batch_sampler instead"
)
Contributor


Can you please show the exact paragraph/section of the pytorch docs that states this? You reference many issues, but I don't see where you found this.

Author

@suzyahyah suzyahyah Mar 31, 2025


Thanks for the quick reply! Yes it’s referenced under motivation #5.

The official PyTorch documentation for reference, showing the two arguments sampler and batch_sampler:

"users may use the sampler argument to specify a custom Sampler object that at each time yields the next index/key to fetch. A custom Sampler that yields a list of batch indices at a time can be passed as the batch_sampler argument."

And in the docs they write that this is "mutually exclusive with sampler":

batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
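
For concreteness, a minimal example of the documented batch_sampler usage (a sketch, not part of the PR diff; the toy list dataset is only for illustration):

from torch.utils.data import BatchSampler, DataLoader, RandomSampler

data = list(range(100))

# Documented usage: a sampler that yields lists of indices goes to `batch_sampler`;
# batch_size, shuffle, sampler, and drop_last must then be left at their defaults.
loader = DataLoader(
    data,
    batch_sampler=BatchSampler(RandomSampler(data), batch_size=32, drop_last=False),
)

for batch in loader:
    print(len(batch))  # 32 items per batch (the last batch may be smaller)
    break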

Author

@suzyahyah suzyahyah Mar 31, 2025


If we pass a BatchSampler (which yields a list of indices) as the sampler argument, we will get a list of lists of indices, because PyTorch will convert whatever was passed as sampler into "batch mode".

Maybe it's legal behavior for some very custom parallel processing on batches of batches, but it seems rare and pretty unlikely to me. Should I convert the ValueError to a warning?
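
A minimal sketch of that "batch of batches" effect (using a NumPy array as the dataset so that a list of indices is accepted by __getitem__):

import numpy as np
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

data = np.arange(8)
batch_sampler = BatchSampler(SequentialSampler(data), batch_size=4, drop_last=False)

# Passed as `sampler`, the BatchSampler's lists of indices get batched again,
# so each "index" handed to the dataset is itself a list of indices.
loader = DataLoader(data, sampler=batch_sampler, batch_size=2)
print(next(iter(loader)).shape)  # torch.Size([2, 4]) -- a batch of batches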

@SunMarc
Member

SunMarc commented Apr 1, 2025

Can you rebase again? The diff doesn't look correct.

@suzyahyah suzyahyah force-pushed the fix/data_loader_batch_sampler branch from 17fc17a to e228780 on April 1, 2025, 14:41
@suzyahyah
Author

Can you rebase again? The diff doesn't look correct.

Yeah I think I fixed it now.

I don't see where you found this

Re @muellerzr's very valid concern: I changed this to a warning and referenced the PyTorch documentation.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@SunMarc
Member

SunMarc commented Apr 9, 2025

Some tests are still failing @suzyahyah

@suzyahyah suzyahyah force-pushed the fix/data_loader_batch_sampler branch from e228780 to b90f9a6 on April 12, 2025, 11:07
@suzyahyah suzyahyah force-pushed the fix/data_loader_batch_sampler branch from b90f9a6 to b60fd09 on April 12, 2025, 11:20
@suzyahyah
Author

Thanks, @SunMarc

The tests were failing because I misunderstood and had wrongly changed the logic in src/accelerate/data_loader.py::set_sampler(). I have reverted the logic in that function back to the main branch.

After rebasing against main, the tests run successfully (w/o hardware accelerators):

make test_core
make test_prod
pytest tests/test_data_loader.py
make test

Quality checks run:

make style
make quality

Member

@SunMarc SunMarc left a comment


Thanks for your work! I feel like it is better if we first add a deprecation message instead of breaking everything right now, WDYT @suzyahyah? This will help users fix their code. Also, it would be nice to fix the HF docs, otherwise users will still use the sampler arg by default.

Comment on lines -980 to -981
sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
if sampler_is_batch_sampler:
Member


let's keep the isinstance(dataloader.sampler, BatchSampler) check

Comment on lines +1154 to +1160
if isinstance(dataloader.sampler, BatchSampler):
    logger.warning(
        "BatchSampler was passed to the sampler argument. "
        "If you have a custom Sampler that yields a list of batch indices at a time, please pass it as the batch_sampler argument instead. "
        "For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"
    )

Member


The warning is nice; maybe we should also add a deprecation message saying that we won't allow passing BatchSampler to sampler anymore?
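
For illustration, one way such a deprecation path could look (hypothetical wording and helper name, not part of the PR):

import warnings

from torch.utils.data import BatchSampler, DataLoader

def _warn_if_batch_sampler_passed_as_sampler(dataloader: DataLoader) -> None:
    # Hypothetical sketch: keep the old behavior for now, but tell users that
    # passing a BatchSampler through `sampler` will stop being supported.
    if isinstance(dataloader.sampler, BatchSampler):
        warnings.warn(
            "Passing a BatchSampler as the `sampler` argument is deprecated; "
            "pass it as `batch_sampler` instead.",
            FutureWarning,
        )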

Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@SunMarc
Member

SunMarc commented May 12, 2025

LMK if you plan to finish the PR or I will leave it as a feature request / bug to fix @suzyahyah
