CPU-Memory keeps accumulating during trainer.predict #19398

Open
surajpaib opened this issue Feb 2, 2024 · 6 comments · May be fixed by #20730

@surajpaib

surajpaib commented Feb 2, 2024

Bug description

This is very similar to closed issue #15656

I am running prediction with the PL Trainer on 3D images, which are huge, and my process keeps getting killed when a large number of samples are to be predicted. I found #15656 and expected it to be the solution, but setting return_predictions=False does not fix the memory accumulation.

What seems to work instead is adding a gc.collect() in the predict_loop. This keeps CPU memory usage constant as would be expected.

It seems like setting return_predictions=False should stop the memory accumulation but I'm confused as to why the gc.collect() is needed.

This is where the gc.collect() is applied: https://github.com/project-lighter/lighter/blob/07018bb2c66c0c8848bab748299e2c2d21c7d185/lighter/callbacks/writer/base.py#L120

I've also attached scalene memory logs comparing the return_predictions=False run with the gc.collect() run. As you can see, there is no memory growth with gc.collect().

Would you be able to provide any intuition on this? It would be much appreciated!

What version are you seeing the problem on?

v2.1

How to reproduce the bug

No response

Error messages and logs

gc_collect.pdf
return_predictions_false.pdf

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

cc @lantiga @Borda

@surajpaib surajpaib added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Feb 2, 2024
@awaelchli
Contributor

@surajpaib In addition to your gc.collect() call, I see you do

trainer.predict_loop._predictions = [[] for _ in range(trainer.predict_loop.num_dataloaders)]

but based on your description (return_predictions=False), this should already be an empty list. Can you confirm? In any case, I can't tell why it is necessary, but if you want we can add the gc.collect() call in the loop, provided it doesn't impact the iteration speed / throughput (it might be expensive in certain situations).
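The cost of a full gc.collect() grows with the number of container objects the collector tracks, so it is worth measuring inside the actual prediction process. A minimal timing sketch using only the standard library; `time_gc_collect` is a hypothetical helper name, not part of Lightning:

```python
import gc
import time


def time_gc_collect(repeats: int = 5) -> float:
    """Return the mean wall-clock time in seconds of a full gc.collect() call."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        gc.collect()  # full collection across all generations
        durations.append(time.perf_counter() - start)
    return sum(durations) / len(durations)


mean_seconds = time_gc_collect()
print(f"mean gc.collect() time: {mean_seconds * 1e3:.3f} ms")
```

Comparing this number with the duration of one predict step gives a rough idea of the throughput impact for a given workload.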

@awaelchli awaelchli added trainer: predict performance and removed needs triage Waiting to be triaged by maintainers labels Feb 3, 2024
@ibro45

ibro45 commented Feb 3, 2024

return_predictions=False wasn't working without gc.collect(). Since we needed to call gc.collect() anyway, we figured let's just clean the predictions manually right there too via trainer.predict_loop._predictions = [[] for _ in range(trainer.predict_loop.num_dataloaders)] and not deal with return_predictions until it's fixed.

@surajpaib
Author

> return_predictions=False wasn't working without gc.collect(). Since we needed to call gc.collect() anyway, we figured let's just clean the predictions manually right there too via trainer.predict_loop._predictions = [[] for _ in range(trainer.predict_loop.num_dataloaders)] and not deal with return_predictions until it's fixed.

To add to this: once gc.collect() is added, clearing trainer.predict_loop._predictions makes only a minor difference in memory usage over time.

Since each of our batch inferences takes a long time (3D images), gc.collect() doesn't seem to have much influence on iteration speed by comparison. This would need additional testing for the general case, though.

What I still don't get is how memory accumulates when return_predictions is set to False. I would assume that no predictions are collected in that case and that memory therefore does not grow, but that doesn't seem to be the case.
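One mechanism that would be consistent with gc.collect() helping, although nothing in this thread confirms it is the cause here, is reference cycles. Objects caught in a cycle are not freed by reference counting alone and stay alive until the cyclic collector runs. A minimal pure-Python sketch of that behaviour; `Node` is a hypothetical stand-in for any object that holds a large buffer:

```python
import gc
import weakref


class Node:
    """Hypothetical object holding a large buffer and a reference to itself."""

    def __init__(self):
        self.buffer = bytearray(1_000_000)
        self.me = self  # self-reference creates a reference cycle


gc.disable()  # keep the demo deterministic: no automatic collection in between
node = Node()
probe = weakref.ref(node)  # lets us observe whether the object is still alive

del node
alive_before = probe() is not None  # the cycle keeps the refcount above zero

gc.collect()
alive_after = probe() is not None  # the cyclic collector has reclaimed it

gc.enable()
print(alive_before, alive_after)  # prints: True False
```

If something similar happens to per-batch outputs, the automatic collector would eventually free them, but possibly too late for very large 3D volumes.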

@PierreSerr

I have the same issue with PyTorch Lightning 2.5.0 and Torch 2.5.1 in a DDP setup.
Using your solution (clearing trainer.predict_loop._predictions and calling gc.collect()) in on_predict_batch_end worked.
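The workaround described in this thread can be sketched as a callback. This is an illustration, not an official fix: `FreePredictionMemory` and its `every_n_batches` parameter are hypothetical names, `_predictions` is a private attribute that may change between Lightning releases, and the `object` fallback exists only so the sketch can be read and exercised without Lightning installed.

```python
import gc

try:
    from lightning.pytorch.callbacks import Callback
except ImportError:  # fallback so the sketch runs without Lightning installed
    Callback = object


class FreePredictionMemory(Callback):
    """Hypothetical callback mirroring the workaround discussed in this thread."""

    def __init__(self, every_n_batches: int = 1):
        super().__init__()
        # Collect less often if gc.collect() turns out to be expensive.
        self.every_n_batches = every_n_batches

    def on_predict_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        if (batch_idx + 1) % self.every_n_batches != 0:
            return
        loop = trainer.predict_loop
        # Private attribute: reset the per-dataloader prediction lists.
        loop._predictions = [[] for _ in range(loop.num_dataloaders)]
        # Force a full collection so unreachable objects are freed right away.
        gc.collect()
```

Assuming `model` and `dataloader` already exist, it would be registered with `Trainer(callbacks=[FreePredictionMemory()])` before calling `trainer.predict(model, dataloader, return_predictions=False)`.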

@ved1beta ved1beta linked a pull request Apr 20, 2025 that will close this issue

stale bot commented Apr 28, 2025

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

@stale stale bot added the won't fix This will not be worked on label Apr 28, 2025
@ivme

ivme commented May 14, 2025

I had a similar issue. Clearing trainer.predict_loop._predictions and calling gc.collect() per @surajpaib's suggestion resolved it.

@stale stale bot removed the won't fix This will not be worked on label May 14, 2025

5 participants