
Commit fd084dd

Commit message: update
Merge commit, 2 parents: 9445c4b + 425a715

3 files changed: +89, -19 lines


src/diffusers/loaders/lora_base.py

Lines changed: 25 additions & 0 deletions

@@ -934,6 +934,27 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
         Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
         you want to load multiple adapters and free some GPU memory.
 
+        After offloading the LoRA adapters to CPU, as long as the rest of the model is still on GPU, the LoRA adapters
+        can no longer be used for inference, as that would cause a device mismatch. Remember to set the device back to
+        GPU before using those LoRA adapters for inference.
+
+        ```python
+        >>> pipe.load_lora_weights(path_1, adapter_name="adapter-1")
+        >>> pipe.load_lora_weights(path_2, adapter_name="adapter-2")
+        >>> pipe.set_adapters("adapter-1")
+        >>> image_1 = pipe(**kwargs)
+        >>> # switch to adapter-2, offload adapter-1
+        >>> pipe.set_lora_device(adapter_names=["adapter-1"], device="cpu")
+        >>> pipe.set_lora_device(adapter_names=["adapter-2"], device="cuda:0")
+        >>> pipe.set_adapters("adapter-2")
+        >>> image_2 = pipe(**kwargs)
+        >>> # switch back to adapter-1, offload adapter-2
+        >>> pipe.set_lora_device(adapter_names=["adapter-2"], device="cpu")
+        >>> pipe.set_lora_device(adapter_names=["adapter-1"], device="cuda:0")
+        >>> pipe.set_adapters("adapter-1")
+        >>> ...
+        ```
+
         Args:
             adapter_names (`List[str]`):
                 List of adapters to send device to.
@@ -949,6 +970,10 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
                 for module in model.modules():
                     if isinstance(module, BaseTunerLayer):
                         for adapter_name in adapter_names:
+                            if adapter_name not in module.lora_A:
+                                # it is sufficient to check lora_A
+                                continue
+
                             module.lora_A[adapter_name].to(device)
                             module.lora_B[adapter_name].to(device)
                             # this is a param, not a module, so device placement is not in-place -> re-assign
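
For context, a minimal sketch of the failure mode this guard avoids (not part of the commit; the toy module and adapter names are invented). When two adapters target partly different layers, a given `BaseTunerLayer` may hold only one of the adapter names in its `lora_A` dict, so indexing it unconditionally raises a `KeyError`:

```python
# Hypothetical toy example; module and adapter names are illustrative, not from diffusers.
import torch.nn as nn
from peft import LoraConfig, inject_adapter_in_model
from peft.tuners.tuners_utils import BaseTunerLayer


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_a = nn.Linear(4, 4)
        self.proj_b = nn.Linear(4, 4)


model = Toy()
# The two adapters target different layers, mirroring the situation the commit fixes.
inject_adapter_in_model(LoraConfig(target_modules=["proj_a"]), model, adapter_name="adapter-0")
inject_adapter_in_model(LoraConfig(target_modules=["proj_b"]), model, adapter_name="adapter-1")

for module in model.modules():
    if isinstance(module, BaseTunerLayer):
        for adapter_name in ["adapter-0", "adapter-1"]:
            # Without this check, `module.lora_A[adapter_name]` raises a KeyError for
            # the adapter that does not target this particular layer.
            if adapter_name not in module.lora_A:
                continue
            module.lora_A[adapter_name].to("cpu")
            module.lora_B[adapter_name].to("cpu")
```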

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 15 additions & 17 deletions
@@ -1825,24 +1825,22 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
     lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
     lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
+    has_time_projection_weight = any(
+        k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
+    )
 
-    diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))]
-    if diff_keys:
-        for diff_k in diff_keys:
-            param = original_state_dict[diff_k]
-            # The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are upto 1e-3,
-            # and 2 of them are about 1.6e-2 [the case with AccVideo lora]). The low magnitudes mostly correspond
-            # to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
-            # is okay to ignore because they do not affect the model output in a significant manner.
-            threshold = 1.6e-2
-            absdiff = param.abs().max() - param.abs().min()
-            all_zero = torch.all(param == 0).item()
-            all_absdiff_lower_than_threshold = absdiff < threshold
-            if all_zero or all_absdiff_lower_than_threshold:
-                logger.debug(
-                    f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
-                )
-                original_state_dict.pop(diff_k)
+    for key in list(original_state_dict.keys()):
+        if key.endswith((".diff", ".diff_b")) and "norm" in key:
+            # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
+            # in future if needed and they are not zeroed.
+            original_state_dict.pop(key)
+            logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.")
+
+        if "time_projection" in key and not has_time_projection_weight:
+            # AccVideo lora has diff bias keys but not the weight keys. This causes a weird problem where
+            # our lora config adds the time proj lora layers, but we don't have the weights for them.
+            # CausVid lora has the weight keys and the bias keys.
+            original_state_dict.pop(key)
 
     # For the `diff_b` keys, we treat them as lora_bias.
     # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
tests/lora/test_lora_layers_sd.py

Lines changed: 49 additions & 2 deletions
@@ -121,7 +121,7 @@ def test_integration_move_lora_cpu(self):
 
         self.assertTrue(
             check_if_lora_correctly_set(pipe.unet),
-            "Lora not correctly set in text encoder",
+            "Lora not correctly set in unet",
         )
 
         # We will offload the first adapter in CPU and check if the offloading
@@ -188,7 +188,7 @@ def test_integration_move_lora_dora_cpu(self):
 
         self.assertTrue(
             check_if_lora_correctly_set(pipe.unet),
-            "Lora not correctly set in text encoder",
+            "Lora not correctly set in unet",
         )
 
         for name, param in pipe.unet.named_parameters():
@@ -222,6 +222,53 @@ def test_lora_set_adapters_scenarios(self, scenario):
             scenario=scenario, expected_atol=expected_atol, expected_rtol=expected_rtol
         )
 
+    @slow
+    @require_torch_accelerator
+    def test_integration_set_lora_device_different_target_layers(self):
+        # fixes a bug that occurred when calling set_lora_device with multiple adapters loaded that target different
+        # layers, see #11833
+        from peft import LoraConfig
+
+        path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+        pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
+        # configs partly target the same, partly different layers
+        config0 = LoraConfig(target_modules=["to_k", "to_v"])
+        config1 = LoraConfig(target_modules=["to_k", "to_q"])
+        pipe.unet.add_adapter(config0, adapter_name="adapter-0")
+        pipe.unet.add_adapter(config1, adapter_name="adapter-1")
+        pipe = pipe.to(torch_device)
+
+        self.assertTrue(
+            check_if_lora_correctly_set(pipe.unet),
+            "Lora not correctly set in unet",
+        )
+
+        # sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix
+        modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")}
+        modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")}
+        self.assertNotEqual(modules_adapter_0, modules_adapter_1)
+        self.assertTrue(modules_adapter_0 - modules_adapter_1)
+        self.assertTrue(modules_adapter_1 - modules_adapter_0)
+
+        # setting both separately works
+        pipe.set_lora_device(["adapter-0"], "cpu")
+        pipe.set_lora_device(["adapter-1"], "cpu")
+
+        for name, module in pipe.unet.named_modules():
+            if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device == torch.device("cpu"))
+            elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device == torch.device("cpu"))
+
+        # setting both at once also works
+        pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device)
+
+        for name, module in pipe.unet.named_modules():
+            if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device != torch.device("cpu"))
+            elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device != torch.device("cpu"))
+
     @slow
     @nightly
