(fix) remove sampler_is_batch_sampler code in prepare_data_loader(..) #3469
base: main
Conversation
src/accelerate/data_loader.py
raise ValueError(
    "Should not pass a BatchSampler to the dataloader `sampler` argument. As per the pytorch>2.1.0 documentation, please pass this to `batch_sampler` instead."
)
Can you please show the exact paragraph/section of the PyTorch docs that states this? You reference many issues, but I don't see where you found this.
Thanks for the quick reply! Yes, it's referenced under Motivation #5.
Official PyTorch documentation for reference, showing the two arguments sampler and batch_sampler:
"Users may use the sampler argument to specify a custom Sampler object that at each time yields the next index/key to fetch. A custom Sampler that yields a list of batch indices at a time can be passed as the batch_sampler argument."
And in the docs they write that batch_sampler is "mutually exclusive with sampler":
batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
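For illustration only, a minimal sketch of the two arguments described in the quoted docs (this is not code from the PR; the dataset and variable names are invented):

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))

# `sampler` yields one index at a time; DataLoader builds the batches itself.
loader_from_sampler = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=4)

# `batch_sampler` yields a list of indices at a time; per the docs it is
# mutually exclusive with batch_size, shuffle, sampler, and drop_last.
batches = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)
loader_from_batch_sampler = DataLoader(dataset, batch_sampler=batches)
```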
If we pass a BatchSampler (which yields a list of indices) as an argument to sampler, we will get a list of lists of indices, because PyTorch will convert whatever was passed in the sampler argument to "batch mode".
Maybe it's legal behavior for some very custom parallel processing on batches of batches, but it seems rare and pretty unlikely to me. Should I convert the ValueError to a warning?
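To make the list-of-lists point concrete, here is a small sketch of the wrapping that DataLoader effectively performs internally when it puts a sampler into batch mode (sampler objects only, no dataset involved):

```python
from torch.utils.data import BatchSampler, SequentialSampler

# A BatchSampler already yields lists of indices.
inner = BatchSampler(SequentialSampler(range(8)), batch_size=4, drop_last=False)
print(list(inner))   # [[0, 1, 2, 3], [4, 5, 6, 7]]

# If that BatchSampler is handed to DataLoader's `sampler` argument together
# with a batch_size, DataLoader wraps it in another BatchSampler, so the
# "indices" handed to the dataset become lists of lists.
outer = BatchSampler(inner, batch_size=2, drop_last=False)
print(list(outer))   # [[[0, 1, 2, 3], [4, 5, 6, 7]]]
```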
Can you rebase again? The diff doesn't look correct.
Force-pushed from 17fc17a to e228780.
Yeah I think I fixed it now.
Re @muellerzr's very valid concern, I changed this to a warning and referenced the PyTorch documentation.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Some tests are still failing, @suzyahyah.
Force-pushed from e228780 to b90f9a6.
Force-pushed from b90f9a6 to b60fd09.
Thanks, @SunMarc. The tests were failing because I misunderstood and had wrongly changed the logic in the function. After rebasing against main, the tests run successfully (without hardware accelerators):
Quality checks run:
Thanks for your work! I feel it is better if we first add a deprecation message instead of breaking everything right now. WDYT @suzyahyah? This will help users fix their code. Also, it would be nice to fix the HF docs, otherwise users will still use the sampler arg by default.
sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
if sampler_is_batch_sampler:
Let's keep the isinstance(dataloader.sampler, BatchSampler) check.
if isinstance(dataloader.sampler, BatchSampler):
    logger.warning(
        "BatchSampler was passed to the sampler argument. "
        "If you have a custom Sampler that yields a list of batch indices at a time, please pass it as the batch_sampler argument instead. "
        "For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"
    )
The warning is nice. Maybe we should also add a deprecation message saying that we won't allow passing a BatchSampler to sampler anymore?
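One possible shape for such a deprecation notice, sketched here with a hypothetical helper name and wording (not code from the PR):

```python
import warnings

from torch.utils.data import BatchSampler


def _warn_if_batch_sampler_passed_as_sampler(dataloader):
    # Hypothetical helper illustrating the suggestion above: warn now and
    # announce that the behavior will be removed in a future release.
    if isinstance(dataloader.sampler, BatchSampler):
        warnings.warn(
            "Passing a BatchSampler to the `sampler` argument is deprecated and will "
            "no longer be supported in a future release of Accelerate. Pass it to "
            "`batch_sampler` instead.",
            FutureWarning,
        )
```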
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
LMK if you plan to finish the PR, or I will leave it as a feature request / bug to fix @suzyahyah
What does this PR do?
This PR fixes various confusions w.r.t. `torch.utils.data.DataLoader`, `torch.utils.data.BatchSampler`, and `accelerate/data_loader.py::prepare_data_loader(..)`.
Related issues: #3322 (Edit: not quite this), #3014, #2091
Motivation
`accelerate/data_loader.py` has various patches allowing a `BatchSampler` object to be passed as an argument to `sampler`. In the code, it allows this behavior using `sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)`.
Allowing this is unintuitive for developers, as it directly conflicts with the `torch.utils.data.DataLoader` documentation; Accelerate should only wrap the PyTorch dataloader and not change or allow different logic or arguments.
It also permits various unintended behaviors in other libraries relying on Accelerate. For instance, when passing a custom BatchSampler to the dataloader argument in the HuggingFace Trainer, the `sampler` kwarg is currently used, regardless of whether the sampler is a BatchSampler or just a RandomSampler: https://github.com/huggingface/transformers/blob/3b07ca78bb696825feee3e976795fab58f2b6d0c/src/transformers/trainer.py#L1026 (I'll be making a separate PR in HuggingFace on this).
I believe this is due to a misunderstanding stemming from the issue "Error in prepared DataLoader with BatchSampler" (#679). As @sgugger initially suspected, this is probably a typo or misunderstanding from HuggingFace.
This should not be allowed behavior, based on the following basic test case, which throws this error from datasets/formatting/formatting.py:
TypeError: int() argument must be a string, a bytes-like object or a real number, not 'list'
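The test case itself is not reproduced here; the snippet below is only a rough sketch of the kind of misuse being described, assuming a Hugging Face `datasets.Dataset` (the exact traceback may vary across library versions):

```python
from datasets import Dataset
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

ds = Dataset.from_dict({"x": list(range(8))})
batch_sampler = BatchSampler(SequentialSampler(ds), batch_size=4, drop_last=False)

# Misuse: the BatchSampler is handed to `sampler`, so the already-batched
# indices get batched again and the dataset is indexed with a list of lists.
bad_loader = DataLoader(ds, sampler=batch_sampler, batch_size=2)
next(iter(bad_loader))  # expected to raise the TypeError quoted above
```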
Official PyTorch documentation for reference:
This PR:
- Removes the `sampler_is_batch_sampler` code, simplifying the code.
- ~~Immediately throws an error~~ Edit: throws a warning if a `BatchSampler` had been passed as an argument to `sampler`.
Tests
Passes all tests in tests/test_data_loader.py, but does not introduce any new tests. Open to suggestions.
Considerations
The PR immediately throws an error if a BatchSampler had been passed as an argument to sampler. Technically this error should be thrown earlier, or in `torch.utils.data.Sampler`, but we can make it explicit that the problem is not coming from the Accelerate library, since it had previously been allowed.
Upon reading the torch.utils.data source code (v2.6.0), I figured that `torch.utils.data.DataLoader` will attempt to construct a BatchSampler from a Sampler if `batch_sampler=None`, and will also construct a default sampler if `sampler=None`. This means we should always be able to recover a `dataloader.batch_sampler` when wrapping Accelerate around an already constructed PyTorch dataloader, and there is no need to check `sampler_is_batch_sampler` again if the intent is just to get "`new_batch_sampler`".
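A quick illustration of that default construction in plain PyTorch, independent of Accelerate:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))

# Neither sampler nor batch_sampler is passed: DataLoader builds a default
# RandomSampler (because shuffle=True) and wraps it in a BatchSampler.
loader = DataLoader(dataset, batch_size=4, shuffle=True)
print(type(loader.sampler).__name__)        # RandomSampler
print(type(loader.batch_sampler).__name__)  # BatchSampler
```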
".Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
@SunMarc @BenjaminBossan @zach-huggingface @muellerzr