@@ -56,6 +56,36 @@ class StateDictType(enum.Enum):
5656 ".to_out.lora_magnitude_vector" : ".to_out.0.lora_magnitude_vector" ,
5757}
5858
# Maps Stability AI "control-lora" key fragments to the diffusers/PEFT LoRA
# naming scheme: for every adapted module, ".down" is the LoRA A matrix and
# ".up" is the LoRA B matrix.
# NOTE: "conv_in" and "time_emb_proj" are intentionally unprefixed (no leading
# dot) — they match top-level module names rather than nested submodules.
CONTROL_LORA_TO_DIFFUSERS = {
    # attention projections
    ".to_q.down": ".to_q.lora_A.weight",
    ".to_q.up": ".to_q.lora_B.weight",
    ".to_k.down": ".to_k.lora_A.weight",
    ".to_k.up": ".to_k.lora_B.weight",
    ".to_v.down": ".to_v.lora_A.weight",
    ".to_v.up": ".to_v.lora_B.weight",
    ".to_out.0.down": ".to_out.0.lora_A.weight",
    ".to_out.0.up": ".to_out.0.lora_B.weight",
    # feed-forward layers
    ".ff.net.0.proj.down": ".ff.net.0.proj.lora_A.weight",
    ".ff.net.0.proj.up": ".ff.net.0.proj.lora_B.weight",
    ".ff.net.2.down": ".ff.net.2.lora_A.weight",
    ".ff.net.2.up": ".ff.net.2.lora_B.weight",
    # transformer block in/out projections
    ".proj_in.down": ".proj_in.lora_A.weight",
    ".proj_in.up": ".proj_in.lora_B.weight",
    ".proj_out.down": ".proj_out.lora_A.weight",
    ".proj_out.up": ".proj_out.lora_B.weight",
    # convolutions (".conv1" / ".conv2" expanded via the comprehensions below)
    ".conv.down": ".conv.lora_A.weight",
    ".conv.up": ".conv.lora_B.weight",
    **{f".conv{i}.down": f".conv{i}.lora_A.weight" for i in range(1, 3)},
    **{f".conv{i}.up": f".conv{i}.lora_B.weight" for i in range(1, 3)},
    "conv_in.down": "conv_in.lora_A.weight",
    "conv_in.up": "conv_in.lora_B.weight",
    ".conv_shortcut.down": ".conv_shortcut.lora_A.weight",
    ".conv_shortcut.up": ".conv_shortcut.lora_B.weight",
    # embedding linears
    **{f".linear_{i}.down": f".linear_{i}.lora_A.weight" for i in range(1, 3)},
    **{f".linear_{i}.up": f".linear_{i}.lora_B.weight" for i in range(1, 3)},
    "time_emb_proj.down": "time_emb_proj.lora_A.weight",
    "time_emb_proj.up": "time_emb_proj.lora_B.weight",
}
5989
6090DIFFUSERS_TO_PEFT = {
6191 ".q_proj.lora_linear_layer.up" : ".q_proj.lora_B" ,
@@ -259,6 +289,155 @@ def convert_unet_state_dict_to_peft(state_dict):
259289 return convert_state_dict (state_dict , mapping )
260290
261291
def convert_sai_sd_control_lora_state_dict_to_peft(state_dict):
    r"""
    Converts a Stability AI ("control-lora") ControlNet LoRA state dict to the PEFT naming
    scheme used by diffusers.

    The conversion runs in two passes: first the LDM-style block layout
    (``input_blocks`` / ``middle_block`` / ``zero_convs`` / ``input_hint_block``) is renamed
    to the diffusers ControlNet layout, then the ``.down`` / ``.up`` LoRA suffixes are mapped
    to ``.lora_A.weight`` / ``.lora_B.weight`` via ``CONTROL_LORA_TO_DIFFUSERS``.

    Args:
        state_dict (`dict`): ControlNet LoRA state dict in the original SAI layout.

    Returns:
        `dict`: The state dict with diffusers/PEFT-style keys.
    """

    def _convert_controlnet_to_diffusers(state_dict):
        # This extra input-block resnet weight is only present in SD1.5 checkpoints;
        # its absence marks an SDXL checkpoint.
        is_sdxl = "input_blocks.11.0.in_layers.0.weight" not in state_dict
        logger.info(f"Using ControlNet lora ({'SDXL' if is_sdxl else 'SD15'})")

        # Retrieves the keys for the input blocks only
        num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer})
        input_blocks = {
            layer_id: [key for key in state_dict if f"input_blocks.{layer_id}" in key]
            for layer_id in range(num_input_blocks)
        }
        layers_per_block = 2

        # op blocks (downsampler convolutions)
        op_blocks = [key for key in state_dict if "0.op" in key]

        converted_state_dict = {}
        # Conv in layers
        for key in input_blocks[0]:
            diffusers_key = key.replace("input_blocks.0.0", "conv_in")
            converted_state_dict[diffusers_key] = state_dict.get(key)

        # controlnet time embedding blocks
        time_embedding_blocks = [key for key in state_dict if "time_embed" in key]
        for key in time_embedding_blocks:
            diffusers_key = key.replace("time_embed.0", "time_embedding.linear_1").replace(
                "time_embed.2", "time_embedding.linear_2"
            )
            converted_state_dict[diffusers_key] = state_dict.get(key)

        # controlnet label embedding blocks
        label_embedding_blocks = [key for key in state_dict if "label_emb" in key]
        for key in label_embedding_blocks:
            diffusers_key = key.replace("label_emb.0.0", "add_embedding.linear_1").replace(
                "label_emb.0.2", "add_embedding.linear_2"
            )
            converted_state_dict[diffusers_key] = state_dict.get(key)

        # Down blocks: input block i maps to (block_id, layer_in_block_id) assuming
        # layers_per_block resnets plus one downsampler per down block.
        for i in range(1, num_input_blocks):
            block_id = (i - 1) // (layers_per_block + 1)
            layer_in_block_id = (i - 1) % (layers_per_block + 1)

            # Sub-block 0 holds the resnet; exclude its ".op" (downsampler) keys.
            resnets = [
                key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
            ]
            for key in resnets:
                diffusers_key = (
                    key.replace("in_layers.0", "norm1")
                    .replace("in_layers.2", "conv1")
                    .replace("out_layers.0", "norm2")
                    .replace("out_layers.3", "conv2")
                    .replace("emb_layers.1", "time_emb_proj")
                    .replace("skip_connection", "conv_shortcut")
                )
                diffusers_key = diffusers_key.replace(
                    f"input_blocks.{i}.0", f"down_blocks.{block_id}.resnets.{layer_in_block_id}"
                )
                converted_state_dict[diffusers_key] = state_dict.get(key)

            if f"input_blocks.{i}.0.op.bias" in state_dict:
                for key in [key for key in op_blocks if f"input_blocks.{i}.0.op" in key]:
                    diffusers_key = key.replace(
                        f"input_blocks.{i}.0.op", f"down_blocks.{block_id}.downsamplers.0.conv"
                    )
                    converted_state_dict[diffusers_key] = state_dict.get(key)

            # Sub-block 1 holds the attention, when present.
            attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
            if attentions:
                for key in attentions:
                    diffusers_key = key.replace(
                        f"input_blocks.{i}.1", f"down_blocks.{block_id}.attentions.{layer_in_block_id}"
                    )
                    converted_state_dict[diffusers_key] = state_dict.get(key)

        # controlnet down blocks (the per-input-block zero convs)
        for i in range(num_input_blocks):
            converted_state_dict[f"controlnet_down_blocks.{i}.weight"] = state_dict.get(f"zero_convs.{i}.0.weight")
            converted_state_dict[f"controlnet_down_blocks.{i}.bias"] = state_dict.get(f"zero_convs.{i}.0.bias")

        # Retrieves the keys for the middle blocks only
        num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in state_dict if "middle_block" in layer})
        middle_blocks = {
            layer_id: [key for key in state_dict if f"middle_block.{layer_id}" in key]
            for layer_id in range(num_middle_blocks)
        }

        # Mid blocks: even middle-block indices are resnets, odd ones are attentions;
        # max(key - 1, 0) collapses them onto the diffusers per-type index.
        for key in middle_blocks.keys():
            diffusers_key = max(key - 1, 0)
            if key % 2 == 0:
                for k in middle_blocks[key]:
                    diffusers_key_hf = (
                        k.replace("in_layers.0", "norm1")
                        .replace("in_layers.2", "conv1")
                        .replace("out_layers.0", "norm2")
                        .replace("out_layers.3", "conv2")
                        .replace("emb_layers.1", "time_emb_proj")
                        .replace("skip_connection", "conv_shortcut")
                    )
                    diffusers_key_hf = diffusers_key_hf.replace(
                        f"middle_block.{key}", f"mid_block.resnets.{diffusers_key}"
                    )
                    converted_state_dict[diffusers_key_hf] = state_dict.get(k)
            else:
                for k in middle_blocks[key]:
                    diffusers_key_hf = k.replace(f"middle_block.{key}", f"mid_block.attentions.{diffusers_key}")
                    converted_state_dict[diffusers_key_hf] = state_dict.get(k)

        # mid block output zero conv
        converted_state_dict["controlnet_mid_block.weight"] = state_dict.get("middle_block_out.0.weight")
        converted_state_dict["controlnet_mid_block.bias"] = state_dict.get("middle_block_out.0.bias")

        # controlnet cond embedding blocks; hint blocks 0 and 14 are handled
        # separately below as conv_in / conv_out.
        cond_embedding_blocks = {
            ".".join(layer.split(".")[:2])
            for layer in state_dict
            if "input_hint_block" in layer
            and ("input_hint_block.0" not in layer)
            and ("input_hint_block.14" not in layer)
        }
        num_cond_embedding_blocks = len(cond_embedding_blocks)

        for idx in range(1, num_cond_embedding_blocks + 1):
            diffusers_idx = idx - 1
            # Weights only live at even hint-block indices (odd ones carry no
            # parameters here) — presumably interleaved activations; TODO confirm.
            cond_block_id = 2 * idx

            converted_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = state_dict.get(
                f"input_hint_block.{cond_block_id}.weight"
            )
            converted_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = state_dict.get(
                f"input_hint_block.{cond_block_id}.bias"
            )

        for key in [key for key in state_dict if "input_hint_block.0" in key]:
            diffusers_key = key.replace("input_hint_block.0", "controlnet_cond_embedding.conv_in")
            converted_state_dict[diffusers_key] = state_dict.get(key)

        for key in [key for key in state_dict if "input_hint_block.14" in key]:
            diffusers_key = key.replace("input_hint_block.14", "controlnet_cond_embedding.conv_out")
            converted_state_dict[diffusers_key] = state_dict.get(key)

        return converted_state_dict

    state_dict = _convert_controlnet_to_diffusers(state_dict)
    mapping = CONTROL_LORA_TO_DIFFUSERS
    return convert_state_dict(state_dict, mapping)
439+
440+
262441def convert_all_state_dict_to_peft (state_dict ):
263442 r"""
264443 Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` for a valid
0 commit comments