def load_fine_tune_model(base_model_id, saved_weights):
    # Load tokenizer and base model
+   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    base_model = AutoModelForCausalLM.from_pretrained(base_model_id)
-   base_model.to("cpu")
+   base_model.to(device)

    # Create LoRA config - make sure these parameters match your training configuration
    peft_config = LoraConfig(
@@ -15,39 +16,20 @@ def load_fine_tune_model(base_model_id, saved_weights):
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
-       target_modules=[
-           'q_proj',
-           'k_proj',
-           'v_proj',
-       ],
+       target_modules=['q_proj', 'k_proj', 'v_proj'],
    )
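    # Note: r, lora_alpha, and target_modules here must mirror the training
    # run; otherwise the saved LoRA weight shapes or key names will not match.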

    # Initialize PeftModel
    lora_model = PeftModel(base_model, peft_config)
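    # PeftModel keeps the wrapped network under a "base_model" attribute, so
    # its state-dict keys carry a "base_model.model." prefix; hence the key
    # remapping below.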

    # Load the saved weights
-   state_dict = torch.load(saved_weights, map_location=torch.device('cpu'))
-
-   # Check the keys in the state dict
-   print("Original state dict keys:", state_dict.keys())
-
+   state_dict = torch.load(saved_weights, map_location=device)
+
    # Create new state dict with correct prefixes and structure
    new_state_dict = {}
    for key, value in state_dict.items():
-       if 'model.model.' in key:
-           # If the key already has model.model., just add base_
-           new_key = f"base_{key}"
-       elif 'model.' in key:
-           # If the key has only model., add base_model.
-           new_key = f"base_{key}"
-       else:
-           # If the key has neither, add the full prefix
-           new_key = f"base_model.model.{key}"
-
-       # Print shape information for debugging
-       print(f"Converting {key} -> {new_key}")
-       print(f"Shape: {value.shape}")
-
+       # Keys start with "model.", so prepending "base_" yields the
+       # "base_model.model." prefix that PeftModel expects
+       new_key = f"base_{key}"
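+       # (illustrative, assuming the checkpoint was saved from the inner
+       # model: "model.model.layers.0.self_attn.q_proj.lora_A.weight"
+       # becomes "base_model.model.layers.0.self_attn.q_proj.lora_A.weight")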
        new_state_dict[new_key] = value

    # Load the weights with strict=False to allow partial loading
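    # (strict=False makes load_state_dict report missing/unexpected keys
    # instead of raising a RuntimeError on them)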
@@ -136,5 +118,5 @@ def generate_ft(model, prompt, tokenizer, max_new_tokens, context_size=256, temp
print(f"Trainable parameters: {trainable_params:,}")

# generate
-prompt = "Program a Flask API that will convert an image to grayscale."
-print(generate_ft(model_ft, prompt, tokenizer, max_new_tokens=256))
+prompt = "Program a function that reads through video frames and writes them into a new video."
+print(generate_ft(model_ft, prompt, tokenizer, max_new_tokens=512))