
enable big_model_inference on xpu #3595


Merged (3 commits, May 30, 2025)

Changes from 2 commits
4 changes: 2 additions & 2 deletions benchmarks/big_model_inference/README.md
@@ -13,7 +13,7 @@ pip install transformers
To reproduce or test a new setup, run

```py
-python inference_acc.py model_name
+python big_model_inference.py model_name
```

This script supports `gpt-j-6b`, `gpt-neox`, `opt` (30B version) and `T0pp` out of the box, but you can specify any valid checkpoint for `model_name`.
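For example, assuming you want the 30B OPT model, you could pass its Hub id explicitly (the checkpoint below is only an illustration; any valid id works):

```py
python big_model_inference.py facebook/opt-30b
```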
@@ -43,4 +43,4 @@ Note on the results:

You will also note that Accelerate does not use any more GPU and CPU RAM than necessary:
- peak GPU memory is exactly the size of the model put on a given GPU
- peak CPU memory is either the size of the biggest checkpoint shard or the part of the model offloaded on CPU, whichever is bigger.
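For example, with a hypothetical 24 GiB model split evenly across two GPUs, a largest checkpoint shard of 5 GiB, and nothing offloaded, you would expect roughly 12 GiB of peak memory on each GPU and about 5 GiB of peak CPU RAM; if 8 GiB of the model were offloaded to CPU instead, peak CPU RAM would rise to about 8 GiB.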
28 changes: 17 additions & 11 deletions benchmarks/big_model_inference/measures_util.py
@@ -19,6 +19,12 @@
import torch


+from accelerate.test_utils.testing import get_backend
+
+torch_device_type, _, _ = get_backend()
+torch_accelerator_module = getattr(torch, torch_device_type, torch.cuda)


class PeakCPUMemory:
def __init__(self):
self.process = psutil.Process()
@@ -54,16 +60,16 @@ def start_measure():
measures = {"time": time.time()}

gc.collect()
-torch.cuda.empty_cache()
+torch_accelerator_module.empty_cache()

# CPU mem
measures["cpu"] = psutil.Process().memory_info().rss
cpu_peak_tracker.start()

# GPU mem
-for i in range(torch.cuda.device_count()):
-measures[str(i)] = torch.cuda.memory_allocated(i)
-torch.cuda.reset_peak_memory_stats()
+for i in range(torch_accelerator_module.device_count()):
+measures[str(i)] = torch_accelerator_module.memory_allocated(i)
+torch_accelerator_module.reset_peak_memory_stats()

return measures

@@ -73,26 +79,26 @@ def end_measure(start_measures):
measures = {"time": time.time() - start_measures["time"]}

gc.collect()
-torch.cuda.empty_cache()
+torch_accelerator_module.empty_cache()

# CPU mem
measures["cpu"] = (psutil.Process().memory_info().rss - start_measures["cpu"]) / 2**20
measures["cpu-peak"] = (cpu_peak_tracker.stop() - start_measures["cpu"]) / 2**20

# GPU mem
-for i in range(torch.cuda.device_count()):
-measures[str(i)] = (torch.cuda.memory_allocated(i) - start_measures[str(i)]) / 2**20
-measures[f"{i}-peak"] = (torch.cuda.max_memory_allocated(i) - start_measures[str(i)]) / 2**20
+for i in range(torch_accelerator_module.device_count()):
+measures[str(i)] = (torch_accelerator_module.memory_allocated(i) - start_measures[str(i)]) / 2**20
+measures[f"{i}-peak"] = (torch_accelerator_module.max_memory_allocated(i) - start_measures[str(i)]) / 2**20

return measures


def log_measures(measures, description):
print(f"{description}:")
print(f"- Time: {measures['time']:.2f}s")
-for i in range(torch.cuda.device_count()):
-print(f"- GPU {i} allocated: {measures[str(i)]:.2f}MiB")
+for i in range(torch_accelerator_module.device_count()):
+print(f"- {torch_device_type} {i} allocated: {measures[str(i)]:.2f}MiB")
peak = measures[f"{i}-peak"]
-print(f"- GPU {i} peak: {peak:.2f}MiB")
+print(f"- {torch_device_type} {i} peak: {peak:.2f}MiB")
print(f"- CPU RAM allocated: {measures['cpu']:.2f}MiB")
print(f"- CPU RAM peak: {measures['cpu-peak']:.2f}MiB")
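The changes above all follow one device-agnostic pattern: resolve the accelerator type once with `get_backend()`, then route every memory call through the matching `torch` submodule (`torch.xpu` on XPU machines, `torch.cuda` otherwise, with `torch.cuda` as the `getattr` fallback). A minimal standalone sketch of that pattern, assuming a PyTorch build where `torch.xpu`/`torch.cuda` expose the usual memory helpers:

```py
import torch

from accelerate.test_utils.testing import get_backend

# get_backend() returns the device type string first, e.g. "cuda", "xpu" or "cpu"
device_type, _, _ = get_backend()
# Fall back to torch.cuda if torch has no submodule named after the device type
accelerator_module = getattr(torch, device_type, torch.cuda)

if device_type != "cpu":
    accelerator_module.empty_cache()
    for i in range(accelerator_module.device_count()):
        mib = accelerator_module.memory_allocated(i) / 2**20
        print(f"{device_type}:{i} allocated: {mib:.2f} MiB")
```

Keeping the module lookup in one place means the measurement helpers themselves never need to know which backend they are running on.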