Skip to content

Commit 8852715

Browse files
author
TAI DO
committed
clean code
1 parent 6a8d2df commit 8852715

File tree

1 file changed

+9
-27
lines changed

1 file changed

+9
-27
lines changed

utils_for_app.py

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44

55
def load_fine_tune_model(base_model_id, saved_weights):
66
# Load tokenizer and base model
7+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
78
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
89
base_model = AutoModelForCausalLM.from_pretrained(base_model_id)
9-
base_model.to("cpu")
10+
base_model.to(device)
1011

1112
# Create LoRA config - make sure these parameters match your training configuration
1213
peft_config = LoraConfig(
@@ -15,39 +16,20 @@ def load_fine_tune_model(base_model_id, saved_weights):
1516
lora_dropout=0.05,
1617
bias="none",
1718
task_type="CAUSAL_LM",
18-
target_modules=[
19-
'q_proj',
20-
'k_proj',
21-
'v_proj',
22-
],
19+
target_modules=['q_proj','k_proj','v_proj'],
2320
)
2421

2522
# Initialize PeftModel
2623
lora_model = PeftModel(base_model, peft_config)
2724

2825
# Load the saved weights
29-
state_dict = torch.load(saved_weights,map_location=torch.device('cpu'))
30-
31-
# Check the keys in the state dict
32-
print("Original state dict keys:", state_dict.keys())
33-
26+
state_dict = torch.load(saved_weights,map_location=device)
27+
3428
# Create new state dict with correct prefixes and structure
3529
new_state_dict = {}
3630
for key, value in state_dict.items():
37-
if 'model.model.' in key:
38-
# If the key already has model.model., just add base_
39-
new_key = f"base_{key}"
40-
elif 'model.' in key:
41-
# If the key has only model., add base_model.
42-
new_key = f"base_{key}"
43-
else:
44-
# If the key has neither, add the full prefix
45-
new_key = f"base_model.model.{key}"
46-
47-
# Print shape information for debugging
48-
print(f"Converting {key} -> {new_key}")
49-
print(f"Shape: {value.shape}")
50-
31+
# Each key starts with "model", so prepend "base_" to give the new key the "base_model" prefix
32+
new_key = f"base_{key}"
5133
new_state_dict[new_key] = value
5234

5335
# Load the weights with strict=False to allow partial loading
@@ -136,5 +118,5 @@ def generate_ft(model, prompt, tokenizer, max_new_tokens, context_size=256, temp
136118
print(f"Trainable parameters: {trainable_params:,}")
137119

138120
# generate
139-
prompt = "Program a Flask API that will convert an image to grayscale."
140-
print(generate_ft(model_ft, prompt, tokenizer, max_new_tokens=256))
121+
prompt = "Program a function that read through video frames and write it into a new video."
122+
print(generate_ft(model_ft, prompt, tokenizer, max_new_tokens=512))

0 commit comments

Comments
 (0)