Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit c79fefc

Browse files
authored Jul 29, 2021
Updates to prepare_data function (#29)
* Update documentation links to point to the website
* Fix encoding
* Add rough time estimator based on historical stats
* Fix train_test split naming logic; add quiet mode for running inside scripts
* Add a finetuning step-by-step example for a classification use case
* Add classification params if train and valid set; add length_validator
1 parent 7afc3b5 commit c79fefc

File tree

5 files changed

+901
-68
lines changed

5 files changed

+901
-68
lines changed
 

‎README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ openai api completions.create -e ada -p "Hello world"
6868

6969
## Requirements
7070

71-
- Python 3.6+
71+
- Python 3.7+
7272

7373
In general we want to support the versions of Python that our
7474
customers are using, so if you run into issues with any version

‎examples/finetuning/finetuning-classification.ipynb

Lines changed: 731 additions & 0 deletions
Large diffs are not rendered by default.

‎openai/cli.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,6 @@ def create(cls, args):
322322
def get(cls, args):
323323
resp = openai.FineTune.retrieve(id=args.id)
324324
print(resp)
325-
print(resp["result_files"][0])
326325

327326
@classmethod
328327
def results(cls, args):
@@ -417,6 +416,7 @@ def prepare_data(cls, args):
417416

418417
sys.stdout.write("Analyzing...\n")
419418
fname = args.file
419+
auto_accept = args.quiet
420420
df, remediation = read_any_format(fname)
421421
apply_necessary_remediation(None, remediation)
422422

@@ -439,18 +439,32 @@ def prepare_data(cls, args):
439439
or remediation.necessary_msg is not None
440440
]
441441
)
442+
any_necessary_applied = any(
443+
[
444+
remediation
445+
for remediation in optional_remediations
446+
if remediation.necessary_msg is not None
447+
]
448+
)
449+
any_optional_applied = False
442450

443451
if any_optional_or_necessary_remediations:
444452
sys.stdout.write(
445453
"\n\nBased on the analysis we will perform the following actions:\n"
446454
)
447-
448455
for remediation in optional_remediations:
449-
df = apply_optional_remediation(df, remediation)
456+
df, optional_applied = apply_optional_remediation(
457+
df, remediation, auto_accept
458+
)
459+
any_optional_applied = any_optional_applied or optional_applied
450460
else:
451461
sys.stdout.write("\n\nNo remediations found.\n")
452462

453-
write_out_file(df, fname, any_optional_or_necessary_remediations)
463+
any_optional_or_necessary_applied = (
464+
any_optional_applied or any_necessary_applied
465+
)
466+
467+
write_out_file(df, fname, any_optional_or_necessary_applied, auto_accept)
454468

455469

456470
def tools_register(parser):
@@ -471,6 +485,13 @@ def help(args):
471485
help="JSONL, JSON, CSV, TSV, TXT or XLSX file containing prompt-completion examples to be analyzed."
472486
"This should be the local file path.",
473487
)
488+
sub.add_argument(
489+
"-q",
490+
"--quiet",
491+
required=False,
492+
action="store_true",
493+
help="Auto accepts all suggestions, without asking for user input. To be used within scripts.",
494+
)
474495
sub.set_defaults(func=FineTune.prepare_data)
475496

476497

‎openai/validators.py

Lines changed: 143 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,36 @@ def optional_fn(x):
153153
)
154154

155155

156+
def long_examples_validator(df):
157+
"""
158+
This validator will suggest to the user to remove examples that are too long.
159+
"""
160+
immediate_msg = None
161+
optional_msg = None
162+
optional_fn = None
163+
164+
ft_type = infer_task_type(df)
165+
if ft_type != "open-ended generation":
166+
long_examples = df.apply(
167+
lambda x: len(x.prompt) + len(x.completion) > 10000, axis=1
168+
)
169+
long_indexes = df.reset_index().index[long_examples].tolist()
170+
171+
if len(long_indexes) > 0:
172+
immediate_msg = f"\n- There are {len(long_indexes)} examples that are very long. These are rows: {long_indexes}\nFor conditional generation, and for classification the examples shouldn't be longer than 2048 tokens."
173+
optional_msg = f"Remove {len(long_indexes)} long examples"
174+
175+
def optional_fn(x):
176+
return x.drop(long_indexes)
177+
178+
return Remediation(
179+
name="long_examples",
180+
immediate_msg=immediate_msg,
181+
optional_msg=optional_msg,
182+
optional_fn=optional_fn,
183+
)
184+
185+
156186
def common_prompt_suffix_validator(df):
157187
"""
158188
This validator will suggest to add a common suffix to the prompt if one doesn't already exist in case of classification or conditional generation.
@@ -210,7 +240,7 @@ def add_suffix(x, suffix):
210240
immediate_msg += f"\n WARNING: Some of your prompts contain the suffix `{common_suffix}` more than once. We strongly suggest that you review your prompts and add a unique suffix"
211241

212242
else:
213-
immediate_msg = "\n- Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin. See `Fine Tuning How to Guide` for more detail and examples. If you intend to do open-ended generation, then you should leave the prompts empty"
243+
immediate_msg = "\n- Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin. See https://beta.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples. If you intend to do open-ended generation, then you should leave the prompts empty"
214244

215245
if common_suffix == "":
216246
optional_msg = (
@@ -361,7 +391,7 @@ def add_suffix(x, suffix):
361391
immediate_msg += f"\n WARNING: Some of your completions contain the suffix `{common_suffix}` more than once. We suggest that you review your completions and add a unique ending"
362392

363393
else:
364-
immediate_msg = "\n- Your data does not contain a common ending at the end of your completions. Having a common ending string appended to the end of the completion makes it clearer to the fine-tuned model where the completion should end. See `Fine Tuning How to Guide` for more detail and examples."
394+
immediate_msg = "\n- Your data does not contain a common ending at the end of your completions. Having a common ending string appended to the end of the completion makes it clearer to the fine-tuned model where the completion should end. See https://beta.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples."
365395

366396
if common_suffix == "":
367397
optional_msg = (
@@ -396,7 +426,7 @@ def add_space_start(x):
396426
immediate_msg = None
397427

398428
if df.completion.str[:1].nunique() != 1 or df.completion.values[0][0] != " ":
399-
immediate_msg = "\n- The completion should start with a whitespace character (` `). This tends to produce better results due to the tokenization we use. See `Fine Tuning How to Guide` for more details"
429+
immediate_msg = "\n- The completion should start with a whitespace character (` `). This tends to produce better results due to the tokenization we use. See https://beta.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details"
400430
optional_msg = "Add a whitespace character to the beginning of the completion"
401431
optional_fn = add_space_start
402432
return Remediation(
@@ -430,7 +460,7 @@ def lower_case(x):
430460
if count_upper * 2 > count_lower:
431461
return Remediation(
432462
name="lower_case",
433-
immediate_msg=f"\n- More than a third of your `{column}` column/key is uppercase. Uppercase {column}s tends to perform worse than a mixture of case encountered in normal language. We recommend to lower case the data if that makes sense in your domain. See `Fine Tuning How to Guide` for more details",
463+
immediate_msg=f"\n- More than a third of your `{column}` column/key is uppercase. Uppercase {column}s tends to perform worse than a mixture of case encountered in normal language. We recommend to lower case the data if that makes sense in your domain. See https://beta.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details",
434464
optional_msg=f"Lowercase all your data in column/key `{column}`",
435465
optional_fn=lower_case,
436466
)
@@ -534,19 +564,81 @@ def apply_necessary_remediation(df, remediation):
534564
return df
535565

536566

537-
def apply_optional_remediation(df, remediation):
567+
def accept_suggestion(input_text, auto_accept):
568+
sys.stdout.write(input_text)
569+
if auto_accept:
570+
sys.stdout.write("Y")
571+
return True
572+
return input().lower() != "n"
573+
574+
575+
def apply_optional_remediation(df, remediation, auto_accept):
538576
"""
539577
This function will apply an optional remediation to a dataframe, based on the user input.
540578
"""
579+
optional_applied = False
580+
input_text = f"- [Recommended] {remediation.optional_msg} [Y/n]: "
541581
if remediation.optional_msg is not None:
542-
if input(f"- [Recommended] {remediation.optional_msg} [Y/n]: ").lower() != "n":
582+
if accept_suggestion(input_text, auto_accept):
543583
df = remediation.optional_fn(df)
584+
optional_applied = True
544585
if remediation.necessary_msg is not None:
545586
sys.stdout.write(f"- [Necessary] {remediation.necessary_msg}\n")
546-
return df
587+
return df, optional_applied
588+
589+
590+
def estimate_fine_tuning_time(df):
591+
"""
592+
Estimate the time it'll take to fine-tune the dataset
593+
"""
594+
ft_format = infer_task_type(df)
595+
expected_time = 1.0
596+
if ft_format == "classification":
597+
num_examples = len(df)
598+
expected_time = num_examples * 1.44
599+
else:
600+
size = df.memory_usage(index=True).sum()
601+
expected_time = size * 0.0515
602+
603+
def format_time(time):
604+
if time < 60:
605+
return f"{round(time, 2)} seconds"
606+
elif time < 3600:
607+
return f"{round(time / 60, 2)} minutes"
608+
elif time < 86400:
609+
return f"{round(time / 3600, 2)} hours"
610+
else:
611+
return f"{round(time / 86400, 2)} days"
612+
613+
time_string = format_time(expected_time + 140)
614+
sys.stdout.write(
615+
f"Once your model starts training, it'll approximately take {time_string} to train a `curie` model, and less for `ada` and `babbage`. Queue will approximately take half an hour per job ahead of you.\n"
616+
)
617+
618+
619+
def get_outfnames(fname, split):
620+
suffixes = ["_train", "_valid"] if split else [""]
621+
i = 0
622+
while True:
623+
index_suffix = f" ({i})" if i > 0 else ""
624+
candidate_fnames = [
625+
fname.split(".")[0] + "_prepared" + suffix + index_suffix + ".jsonl"
626+
for suffix in suffixes
627+
]
628+
if not any(os.path.isfile(f) for f in candidate_fnames):
629+
return candidate_fnames
630+
i += 1
631+
632+
633+
def get_classification_hyperparams(df):
634+
n_classes = df.completion.nunique()
635+
pos_class = None
636+
if n_classes == 2:
637+
pos_class = df.completion.value_counts().index[0]
638+
return n_classes, pos_class
547639

548640

549-
def write_out_file(df, fname, any_remediations):
641+
def write_out_file(df, fname, any_remediations, auto_accept):
550642
"""
551643
This function will write out a dataframe to a file, if the user would like to proceed, and also offer a fine-tuning command with the newly created file.
552644
For classification it will optionally ask the user if they would like to split the data into train/valid files, and modify the suggested command to include the valid set.
@@ -556,16 +648,16 @@ def write_out_file(df, fname, any_remediations):
556648
common_completion_suffix = get_common_xfix(df.completion, xfix="suffix")
557649

558650
split = False
651+
input_text = "- [Recommended] Would you like to split into training and validation set? [Y/n]: "
559652
if ft_format == "classification":
560-
if (
561-
input(
562-
"- [Recommended] Would you like to split into training and validation set? [Y/n]: "
563-
)
564-
!= "n"
565-
):
653+
if accept_suggestion(input_text, auto_accept):
566654
split = True
567655

568-
packing_param = " --no_packing" if ft_format == "classification" else ""
656+
classification_params = ""
657+
if ft_format == "classification" or (
658+
ft_format == "conditional generation" and len(df) < 1000
659+
):
660+
classification_params = " --no_packing"
569661
common_prompt_suffix_new_line_handled = common_prompt_suffix.replace("\n", "\\n")
570662
common_completion_suffix_new_line_handled = common_completion_suffix.replace(
571663
"\n", "\\n"
@@ -576,67 +668,55 @@ def write_out_file(df, fname, any_remediations):
576668
else ""
577669
)
578670

671+
input_text = "\n\nYour data will be written to a new JSONL file. Proceed [Y/n]: "
672+
579673
if not any_remediations:
580674
sys.stdout.write(
581-
f'\nYou can use your file for fine-tuning:\n> openai api fine_tunes.create -t "{fname}"{packing_param}\n\nAfter you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt.{optional_ending_string}\n'
675+
f'\nYou can use your file for fine-tuning:\n> openai api fine_tunes.create -t "{fname}"{classification_params}\n\nAfter you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt.{optional_ending_string}\n'
582676
)
677+
estimate_fine_tuning_time(df)
678+
679+
elif accept_suggestion(input_text, auto_accept):
680+
fnames = get_outfnames(fname, split)
681+
if split:
682+
assert len(fnames) == 2 and "train" in fnames[0] and "valid" in fnames[1]
683+
MAX_VALID_EXAMPLES = 1000
684+
n_train = max(len(df) - MAX_VALID_EXAMPLES, int(len(df) * 0.8))
685+
df_train = df.sample(n=n_train, random_state=42)
686+
df_valid = df.drop(df_train.index)
687+
df_train[["prompt", "completion"]].to_json(
688+
fnames[0], lines=True, orient="records", force_ascii=False
689+
)
690+
df_valid[["prompt", "completion"]].to_json(
691+
fnames[1], lines=True, orient="records", force_ascii=False
692+
)
583693

584-
elif (
585-
input(
586-
"\n\nYour data will be written to a new JSONL file. Proceed [Y/n]: "
587-
).lower()
588-
!= "n"
589-
):
590-
591-
suffixes = ["_train", "_valid"] if split else [""]
592-
outfnames = []
593-
indices = None
594-
for suffix in suffixes:
595-
out_fname = fname.split(".")[0] + "_prepared" + suffix + ".jsonl"
596-
597-
# check if file already exists, and if it does, add a number to the end
598-
i = 0
599-
while True:
600-
to_continue = False
601-
# in case of train and test, make sure that the numbers will match
602-
for suf in suffixes:
603-
out_fname = (
604-
fname.split(".")[0] + "_prepared" + suf + f" ({i})" + ".jsonl"
605-
)
606-
if i == 0:
607-
out_fname = fname.split(".")[0] + "_prepared" + suf + ".jsonl"
608-
i += 1
609-
if os.path.isfile(out_fname):
610-
to_continue = True
611-
if to_continue:
612-
continue
613-
break
614-
615-
outfnames.append(out_fname)
616-
if suffix == "_train":
617-
MAX_VALID_EXAMPLES = 1000
618-
n = max(len(df) - MAX_VALID_EXAMPLES, int(len(df) * 0.8))
619-
df_out = df.sample(n=n, random_state=42)
620-
indices = df_out.index
621-
if suffix == "_valid":
622-
df_out = df.drop(indices)
623-
if suffix == "":
624-
df_out = df
625-
df_out[["prompt", "completion"]].to_json(
626-
out_fname, lines=True, orient="records"
694+
n_classes, pos_class = get_classification_hyperparams(df)
695+
classification_params += " --compute_classification_metrics"
696+
if n_classes == 2:
697+
classification_params += (
698+
f' --classification_positive_class "{pos_class}"'
699+
)
700+
else:
701+
classification_params += f" --classification_n_classes {n_classes}"
702+
else:
703+
assert len(fnames) == 1
704+
df[["prompt", "completion"]].to_json(
705+
fnames[0], lines=True, orient="records", force_ascii=False
627706
)
628707

629708
# Add -v VALID_FILE if we split the file into train / valid
630-
files_string = ("s" if split else "") + " to `" + ("` and `".join(outfnames))
631-
valid_string = f' -v "{outfnames[1]}"' if split else ""
709+
files_string = ("s" if split else "") + " to `" + ("` and `".join(fnames))
710+
valid_string = f' -v "{fnames[1]}"' if split else ""
632711
separator_reminder = (
633712
""
634713
if len(common_prompt_suffix_new_line_handled) == 0
635714
else f"After you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt."
636715
)
637716
sys.stdout.write(
638-
f'\nWrote modified file{files_string}`\nFeel free to take a look!\n\nNow use that file when fine-tuning:\n> openai api fine_tunes.create -t "{outfnames[0]}"{valid_string}{packing_param}\n\n{separator_reminder}{optional_ending_string}\n'
717+
f'\nWrote modified file{files_string}`\nFeel free to take a look!\n\nNow use that file when fine-tuning:\n> openai api fine_tunes.create -t "{fnames[0]}"{valid_string}{classification_params}\n\n{separator_reminder}{optional_ending_string}\n'
639718
)
719+
estimate_fine_tuning_time(df)
640720
else:
641721
sys.stdout.write("Aborting... did not write the file\n")
642722

@@ -688,6 +768,7 @@ def get_validators():
688768
non_empty_completion_validator,
689769
format_inferrer_validator,
690770
duplicated_rows_validator,
771+
long_examples_validator,
691772
lambda x: lower_case_validator(x, "prompt"),
692773
lambda x: lower_case_validator(x, "completion"),
693774
common_prompt_suffix_validator,

‎openai/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
VERSION = "0.10.1"
1+
VERSION = "0.10.2"

0 commit comments

Comments (0)
Please sign in to comment.