
Commit 3649d7b

DN6 and sayakpaul authored
Follow up for Group Offload to Disk (#11760)
* update
* update
* update

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent 10c36e0 commit 3649d7b

1 file changed: +81 −61 lines changed

src/diffusers/hooks/group_offloading.py

```diff
@@ -132,9 +132,58 @@ def _pinned_memory_tensors(self):
         finally:
             pinned_dict = None
 
+    def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
+        tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+        if self.record_stream and current_stream is not None:
+            tensor.data.record_stream(current_stream)
+
+    def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
+        for group_module in self.modules:
+            for param in group_module.parameters():
+                source = pinned_memory[param] if pinned_memory else param.data
+                self._transfer_tensor_to_device(param, source, current_stream)
+            for buffer in group_module.buffers():
+                source = pinned_memory[buffer] if pinned_memory else buffer.data
+                self._transfer_tensor_to_device(buffer, source, current_stream)
+
+        for param in self.parameters:
+            source = pinned_memory[param] if pinned_memory else param.data
+            self._transfer_tensor_to_device(param, source, current_stream)
+
+        for buffer in self.buffers:
+            source = pinned_memory[buffer] if pinned_memory else buffer.data
+            self._transfer_tensor_to_device(buffer, source, current_stream)
+
+    def _onload_from_disk(self, current_stream):
+        if self.stream is not None:
+            loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
+
+            for key, tensor_obj in self.key_to_tensor.items():
+                self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
+
+            with self._pinned_memory_tensors() as pinned_memory:
+                for key, tensor_obj in self.key_to_tensor.items():
+                    self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
+
+            self.cpu_param_dict.clear()
+
+        else:
+            onload_device = (
+                self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+            )
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+            for key, tensor_obj in self.key_to_tensor.items():
+                tensor_obj.data = loaded_tensors[key]
+
+    def _onload_from_memory(self, current_stream):
+        if self.stream is not None:
+            with self._pinned_memory_tensors() as pinned_memory:
+                self._process_tensors_from_modules(pinned_memory, current_stream)
+        else:
+            self._process_tensors_from_modules(None, current_stream)
+
     @torch.compiler.disable()
     def onload_(self):
-        r"""Onloads the group of modules to the onload_device."""
         torch_accelerator_module = (
             getattr(torch, torch.accelerator.current_accelerator().type)
             if hasattr(torch, "accelerator")
```
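The helpers added above fold four nearly identical transfer loops into one code path. For orientation, here is a minimal, self-contained sketch (illustrative names, not diffusers internals; assumes a CUDA device) of the pinned-memory pattern that `_transfer_tensor_to_device` encapsulates: copy a pinned CPU tensor on a dedicated transfer stream with `non_blocking=True`, then `record_stream` the consumer stream so the caching allocator will not recycle the memory while that stream may still be reading it.

```python
import torch

def async_onload(cpu_tensor: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper, not part of diffusers.
    transfer_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.current_stream()
    pinned = cpu_tensor.pin_memory()  # page-locked memory enables a truly async H2D copy
    with torch.cuda.stream(transfer_stream):
        gpu_tensor = pinned.to("cuda", non_blocking=True)  # copy runs on transfer_stream
    gpu_tensor.record_stream(compute_stream)  # mark the stream that will consume the tensor
    compute_stream.wait_stream(transfer_stream)  # order later compute after the copy
    return gpu_tensor

if torch.cuda.is_available():
    print(async_onload(torch.randn(4, 4)).device)
```

This also shows why `_transfer_tensor_to_device` guards on both `self.record_stream` and `current_stream is not None`: recording is only meaningful when a separate stream is in play.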
```diff
@@ -172,67 +221,30 @@ def onload_(self):
                 self.stream.synchronize()
 
         with context:
-            if self.stream is not None:
-                with self._pinned_memory_tensors() as pinned_memory:
-                    for group_module in self.modules:
-                        for param in group_module.parameters():
-                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
-                            if self.record_stream:
-                                param.data.record_stream(current_stream)
-                        for buffer in group_module.buffers():
-                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-                            if self.record_stream:
-                                buffer.data.record_stream(current_stream)
-
-                    for param in self.parameters:
-                        param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            param.data.record_stream(current_stream)
-
-                    for buffer in self.buffers:
-                        buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            buffer.data.record_stream(current_stream)
-
+            if self.offload_to_disk_path:
+                self._onload_from_disk(current_stream)
             else:
-                for group_module in self.modules:
-                    for param in group_module.parameters():
-                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-                    for buffer in group_module.buffers():
-                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-
-                for param in self.parameters:
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-                    if self.record_stream:
-                        buffer.data.record_stream(current_stream)
-
-    @torch.compiler.disable()
-    def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
-        if self.offload_to_disk_path:
-            # TODO: we can potentially optimize this code path by checking if the _all_ the desired
-            # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
-            # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
-            # we perform a write.
-            # Check if the file has been saved in this session or if it already exists on disk.
-            if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
-                os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
-                tensors_to_save = {
-                    key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
-                }
-                safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
-
-            # The group is now considered offloaded to disk for the rest of the session.
-            self._is_offloaded_to_disk = True
-
-            # We do this to free up the RAM which is still holding the up tensor data.
-            for tensor_obj in self.tensor_to_key.keys():
-                tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
-            return
-
+                self._onload_from_memory(current_stream)
+
+    def _offload_to_disk(self):
+        # TODO: we can potentially optimize this code path by checking if the _all_ the desired
+        # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
+        # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
+        # we perform a write.
+        # Check if the file has been saved in this session or if it already exists on disk.
+        if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
+            os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+            tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
+            safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+
+        # The group is now considered offloaded to disk for the rest of the session.
+        self._is_offloaded_to_disk = True
+
+        # We do this to free up the RAM which is still holding the up tensor data.
+        for tensor_obj in self.tensor_to_key.keys():
+            tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
+
+    def _offload_to_memory(self):
         torch_accelerator_module = (
             getattr(torch, torch.accelerator.current_accelerator().type)
             if hasattr(torch, "accelerator")
```
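The disk branch extracted above is, at its core, a safetensors round trip plus a placeholder swap to release memory. A hedged sketch of that lifecycle (the path and tensor keys below are made up for illustration):

```python
import os
import torch
from safetensors.torch import load_file, save_file

path = "/tmp/offload_demo/group_00.safetensors"  # assumed scratch location
tensors = {"weight": torch.randn(8, 8), "bias": torch.randn(8)}

# Offload: write the group once per session, mirroring the existence check above.
if not os.path.exists(path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    save_file(tensors, path)

# Release the live storage by swapping in uninitialized placeholders of the
# same shape/dtype, analogous to the torch.empty_like() swap in _offload_to_disk.
tensors = {key: torch.empty_like(value) for key, value in tensors.items()}

# Onload: read the values back, either to CPU (for the pinned/stream path)
# or directly onto the target device (the synchronous path above).
restored = load_file(path, device="cpu")
assert restored["weight"].shape == (8, 8)
```

Note the TODO retained in `_offload_to_disk`: the write is skipped only when this group's own file already exists, not when all groups' files do.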
```diff
@@ -257,6 +269,14 @@ def offload_(self):
         for buffer in self.buffers:
             buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
 
+    @torch.compiler.disable()
+    def offload_(self):
+        r"""Offloads the group of modules to the offload_device."""
+        if self.offload_to_disk_path:
+            self._offload_to_disk()
+        else:
+            self._offload_to_memory()
+
 
 class GroupOffloadingHook(ModelHook):
     r"""
```
