@@ -132,9 +132,58 @@ def _pinned_memory_tensors(self):
         finally:
             pinned_dict = None
 
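+    # Move a single tensor to the onload device; with record_stream enabled,
+    # mark the tensor on the given stream so its memory is not reclaimed
+    # while an async copy is still in flight.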
+    def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
+        tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+        if self.record_stream and current_stream is not None:
+            tensor.data.record_stream(current_stream)
+
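+    # Transfer every parameter and buffer in the group, reading from the
+    # pinned-memory copies when provided and from the tensors' own storage
+    # otherwise.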
+    def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
+        for group_module in self.modules:
+            for param in group_module.parameters():
+                source = pinned_memory[param] if pinned_memory else param.data
+                self._transfer_tensor_to_device(param, source, current_stream)
+            for buffer in group_module.buffers():
+                source = pinned_memory[buffer] if pinned_memory else buffer.data
+                self._transfer_tensor_to_device(buffer, source, current_stream)
+
+        for param in self.parameters:
+            source = pinned_memory[param] if pinned_memory else param.data
+            self._transfer_tensor_to_device(param, source, current_stream)
+
+        for buffer in self.buffers:
+            source = pinned_memory[buffer] if pinned_memory else buffer.data
+            self._transfer_tensor_to_device(buffer, source, current_stream)
+
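+    # Reload the group from its safetensors file. With an active transfer
+    # stream, stage the tensors through pinned CPU memory for async copies;
+    # otherwise let safetensors place them on the onload device directly.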
+    def _onload_from_disk(self, current_stream):
+        if self.stream is not None:
+            loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
+
+            for key, tensor_obj in self.key_to_tensor.items():
+                self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
+
+            with self._pinned_memory_tensors() as pinned_memory:
+                for key, tensor_obj in self.key_to_tensor.items():
+                    self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
+
+            self.cpu_param_dict.clear()
+
+        else:
+            onload_device = (
+                self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+            )
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+            for key, tensor_obj in self.key_to_tensor.items():
+                tensor_obj.data = loaded_tensors[key]
+
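+    # Transfer the group from CPU memory, staging through pinned buffers only
+    # when a transfer stream is active.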
+    def _onload_from_memory(self, current_stream):
+        if self.stream is not None:
+            with self._pinned_memory_tensors() as pinned_memory:
+                self._process_tensors_from_modules(pinned_memory, current_stream)
+        else:
+            self._process_tensors_from_modules(None, current_stream)
+
     @torch.compiler.disable()
     def onload_(self):
-        r"""Onloads the group of modules to the onload_device."""
         torch_accelerator_module = (
             getattr(torch, torch.accelerator.current_accelerator().type)
             if hasattr(torch, "accelerator")
@@ -172,67 +221,30 @@ def onload_(self):
             self.stream.synchronize()
 
         with context:
-            if self.stream is not None:
-                with self._pinned_memory_tensors() as pinned_memory:
-                    for group_module in self.modules:
-                        for param in group_module.parameters():
-                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
-                            if self.record_stream:
-                                param.data.record_stream(current_stream)
-                        for buffer in group_module.buffers():
-                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-                            if self.record_stream:
-                                buffer.data.record_stream(current_stream)
-
-                    for param in self.parameters:
-                        param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            param.data.record_stream(current_stream)
-
-                    for buffer in self.buffers:
-                        buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            buffer.data.record_stream(current_stream)
-
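+            # Dispatch on the offload backend: disk-backed groups reload from
+            # their safetensors file, the rest come back from CPU memory.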
+            if self.offload_to_disk_path:
+                self._onload_from_disk(current_stream)
             else:
-                for group_module in self.modules:
-                    for param in group_module.parameters():
-                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-                    for buffer in group_module.buffers():
-                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-
-                for param in self.parameters:
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-                    if self.record_stream:
-                        buffer.data.record_stream(current_stream)
-
-    @torch.compiler.disable()
-    def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
-        if self.offload_to_disk_path:
-            # TODO: we can potentially optimize this code path by checking if _all_ the desired
-            # safetensors files exist on disk and, if so, skip this step entirely, reducing IO
-            # overhead. Currently, we just check if the given `safetensors_file_path` exists and,
-            # if not, we perform a write.
-            # Check if the file has been saved in this session or if it already exists on disk.
-            if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
-                os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
-                tensors_to_save = {
-                    key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
-                }
-                safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
-
-            # The group is now considered offloaded to disk for the rest of the session.
-            self._is_offloaded_to_disk = True
-
-            # We do this to free up the RAM that is still holding onto the tensor data.
-            for tensor_obj in self.tensor_to_key.keys():
-                tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
-            return
-
+                self._onload_from_memory(current_stream)
+
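+    # Write the group to its safetensors file on first offload, then replace
+    # the live tensor storage with empty placeholders to release RAM.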
+    def _offload_to_disk(self):
+        # TODO: we can potentially optimize this code path by checking if _all_ the desired
+        # safetensors files exist on disk and, if so, skip this step entirely, reducing IO
+        # overhead. Currently, we just check if the given `safetensors_file_path` exists and,
+        # if not, we perform a write.
+        # Check if the file has been saved in this session or if it already exists on disk.
+        if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
+            os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+            tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
+            safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+
+        # The group is now considered offloaded to disk for the rest of the session.
+        self._is_offloaded_to_disk = True
+
+        # We do this to free up the RAM that is still holding onto the tensor data.
+        for tensor_obj in self.tensor_to_key.keys():
+            tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
+
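+    # Move the group's parameters and buffers back to the offload device.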
+    def _offload_to_memory(self):
         torch_accelerator_module = (
             getattr(torch, torch.accelerator.current_accelerator().type)
             if hasattr(torch, "accelerator")
@@ -257,6 +269,14 @@ def offload_(self):
         for buffer in self.buffers:
             buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
 
+    @torch.compiler.disable()
+    def offload_(self):
+        r"""Offloads the group of modules to the offload_device."""
+        if self.offload_to_disk_path:
+            self._offload_to_disk()
+        else:
+            self._offload_to_memory()
+
 
 class GroupOffloadingHook(ModelHook):
     r"""