diff --git a/openai/api_resources/engine.py b/openai/api_resources/engine.py
index bc0b244f55..d63bf2159d 100644
--- a/openai/api_resources/engine.py
+++ b/openai/api_resources/engine.py
@@ -30,3 +30,6 @@ def generate(self, timeout=None, **params):
 
     def search(self, **params):
         return self.request("post", self.instance_url() + "/search", params)
+
+    def embeddings(self, **params):
+        return self.request("post", self.instance_url() + "/embeddings", params)
diff --git a/openai/cli.py b/openai/cli.py
index e625d1bdf6..12f22662a8 100644
--- a/openai/cli.py
+++ b/openai/cli.py
@@ -3,14 +3,17 @@
 import signal
 import sys
 import warnings
+from functools import partial
 
 import openai
 from openai.validators import (
     apply_necessary_remediation,
-    apply_optional_remediation,
+    apply_validators,
+    get_search_validators,
     get_validators,
     read_any_format,
     write_out_file,
+    write_out_search_file,
 )
 
 
@@ -224,6 +227,41 @@ def list(cls, args):
 
 
 class Search:
+    @classmethod
+    def prepare_data(cls, args):
+
+        sys.stdout.write("Analyzing...\n")
+        fname = args.file
+        auto_accept = args.quiet
+        purpose = args.purpose
+
+        optional_fields = ["metadata"]
+
+        if purpose == "classifications":
+            required_fields = ["text", "labels"]
+        else:
+            required_fields = ["text"]
+
+        df, remediation = read_any_format(
+            fname, fields=required_fields + optional_fields
+        )
+
+        if "metadata" not in df:
+            df["metadata"] = None
+
+        apply_necessary_remediation(None, remediation)
+        validators = get_search_validators(required_fields, optional_fields)
+
+        write_out_file_func = partial(
+            write_out_search_file,
+            purpose=purpose,
+            fields=required_fields + optional_fields,
+        )
+
+        apply_validators(
+            df, fname, remediation, validators, auto_accept, write_out_file_func
+        )
+
     @classmethod
     def create_alpha(cls, args):
         resp = openai.Search.create_alpha(
@@ -436,49 +474,14 @@ def prepare_data(cls, args):
 
         validators = get_validators()
 
-        optional_remediations = []
-        if remediation is not None:
-            optional_remediations.append(remediation)
-        for validator in validators:
-            remediation = validator(df)
-            if remediation is not None:
-                optional_remediations.append(remediation)
-                df = apply_necessary_remediation(df, remediation)
-
-        any_optional_or_necessary_remediations = any(
-            [
-                remediation
-                for remediation in optional_remediations
-                if remediation.optional_msg is not None
-                or remediation.necessary_msg is not None
-            ]
+        apply_validators(
+            df,
+            fname,
+            remediation,
+            validators,
+            auto_accept,
+            write_out_file_func=write_out_file,
         )
-        any_necessary_applied = any(
-            [
-                remediation
-                for remediation in optional_remediations
-                if remediation.necessary_msg is not None
-            ]
-        )
-        any_optional_applied = False
-
-        if any_optional_or_necessary_remediations:
-            sys.stdout.write(
-                "\n\nBased on the analysis we will perform the following actions:\n"
-            )
-            for remediation in optional_remediations:
-                df, optional_applied = apply_optional_remediation(
-                    df, remediation, auto_accept
-                )
-                any_optional_applied = any_optional_applied or optional_applied
-        else:
-            sys.stdout.write("\n\nNo remediations found.\n")
-
-        any_optional_or_necessary_applied = (
-            any_optional_applied or any_necessary_applied
-        )
-
-        write_out_file(df, fname, any_optional_or_necessary_applied, auto_accept)
 
 
 def tools_register(parser):
@@ -508,6 +511,29 @@ def help(args):
     )
     sub.set_defaults(func=FineTune.prepare_data)
 
+    sub = subparsers.add_parser("search.prepare_data")
+    sub.add_argument(
+        "-f",
+        "--file",
+        required=True,
+        help="JSONL, JSON, CSV, TSV, TXT or XLSX file containing prompt-completion examples to be analyzed. "
+        "This should be the local file path.",
+    )
+    sub.add_argument(
+        "-p",
+        "--purpose",
+        help="Why are you uploading this file? (see https://beta.openai.com/docs/api-reference/ for purposes)",
+        required=True,
+    )
+    sub.add_argument(
+        "-q",
+        "--quiet",
+        required=False,
+        action="store_true",
+        help="Auto accepts all suggestions, without asking for user input. To be used within scripts.",
+    )
+    sub.set_defaults(func=Search.prepare_data)
+
 
 def api_register(parser):
     # Engine management
diff --git a/openai/validators.py b/openai/validators.py
index 181aacb7dd..3cefd5a24c 100644
--- a/openai/validators.py
+++ b/openai/validators.py
@@ -1,9 +1,9 @@
 import os
 import sys
-import pandas as pd
-import numpy as np
+from typing import Any, Callable, NamedTuple, Optional
 
-from typing import NamedTuple, Optional, Callable, Any
+import numpy as np
+import pandas as pd
 
 
 class Remediation(NamedTuple):
@@ -70,7 +70,7 @@ def lower_case_column_creator(df):
     )
 
 
-def additional_column_validator(df):
+def additional_column_validator(df, fields=["prompt", "completion"]):
     """
     This validator will remove additional columns from the dataframe.
     """
@@ -79,9 +79,7 @@ def additional_column_validator(df):
     immediate_msg = None
     necessary_fn = None
     if len(df.columns) > 2:
-        additional_columns = [
-            c for c in df.columns if c not in ["prompt", "completion"]
-        ]
+        additional_columns = [c for c in df.columns if c not in fields]
         warn_message = ""
         for ac in additional_columns:
             dups = [c for c in additional_columns if ac in c]
@@ -91,7 +89,7 @@ def necessary_fn(x):
         necessary_msg = f"Remove additional columns/keys: {additional_columns}"
 
         def necessary_fn(x):
-            return x[["prompt", "completion"]]
+            return x[fields]
 
     return Remediation(
         name="additional_column",
@@ -101,7 +99,7 @@ def necessary_fn(x):
     )
 
 
-def non_empty_completion_validator(df):
+def non_empty_field_validator(df, field="completion"):
     """
     This validator will ensure that no completion is empty.
     """
@@ -109,42 +107,39 @@ def non_empty_completion_validator(df):
     necessary_fn = None
     immediate_msg = None
 
-    if (
-        df["completion"].apply(lambda x: x == "").any()
-        or df["completion"].isnull().any()
-    ):
-        empty_rows = (df["completion"] == "") | (df["completion"].isnull())
+    if df[field].apply(lambda x: x == "").any() or df[field].isnull().any():
+        empty_rows = (df[field] == "") | (df[field].isnull())
         empty_indexes = df.reset_index().index[empty_rows].tolist()
-        immediate_msg = f"\n- `completion` column/key should not contain empty strings. These are rows: {empty_indexes}"
+        immediate_msg = f"\n- `{field}` column/key should not contain empty strings. These are rows: {empty_indexes}"
 
         def necessary_fn(x):
-            return x[x["completion"] != ""].dropna(subset=["completion"])
+            return x[x[field] != ""].dropna(subset=[field])
 
-        necessary_msg = f"Remove {len(empty_indexes)} rows with empty completions"
+        necessary_msg = f"Remove {len(empty_indexes)} rows with empty {field}s"
     return Remediation(
-        name="empty_completion",
+        name=f"empty_{field}",
         immediate_msg=immediate_msg,
         necessary_msg=necessary_msg,
         necessary_fn=necessary_fn,
     )
 
 
-def duplicated_rows_validator(df):
+def duplicated_rows_validator(df, fields=["prompt", "completion"]):
     """
     This validator will suggest to the user to remove duplicate rows if they exist.
""" - duplicated_rows = df.duplicated(subset=["prompt", "completion"]) + duplicated_rows = df.duplicated(subset=fields) duplicated_indexes = df.reset_index().index[duplicated_rows].tolist() immediate_msg = None optional_msg = None optional_fn = None if len(duplicated_indexes) > 0: - immediate_msg = f"\n- There are {len(duplicated_indexes)} duplicated prompt-completion pairs. These are rows: {duplicated_indexes}" + immediate_msg = f"\n- There are {len(duplicated_indexes)} duplicated {'-'.join(fields)} sets. These are rows: {duplicated_indexes}" optional_msg = f"Remove {len(duplicated_indexes)} duplicate rows" def optional_fn(x): - return x.drop_duplicates(subset=["prompt", "completion"]) + return x.drop_duplicates(subset=fields) return Remediation( name="duplicated_rows", @@ -467,7 +462,7 @@ def lower_case(x): ) -def read_any_format(fname): +def read_any_format(fname, fields=["prompt", "completion"]): """ This function will read a file saved in .csv, .json, .txt, .xlsx or .tsv format using pandas. - for .xlsx it will read the first sheet @@ -502,7 +497,7 @@ def read_any_format(fname): content = f.read() df = pd.DataFrame( [["", line] for line in content.split("\n")], - columns=["prompt", "completion"], + columns=fields, dtype=str, ) if fname.lower().endswith("jsonl") or fname.lower().endswith("json"): @@ -623,7 +618,7 @@ def get_outfnames(fname, split): while True: index_suffix = f" ({i})" if i > 0 else "" candidate_fnames = [ - fname.split(".")[0] + "_prepared" + suffix + index_suffix + ".jsonl" + os.path.splitext(fname)[0] + "_prepared" + suffix + index_suffix + ".jsonl" for suffix in suffixes ] if not any(os.path.isfile(f) for f in candidate_fnames): @@ -743,6 +738,30 @@ def write_out_file(df, fname, any_remediations, auto_accept): sys.stdout.write("Aborting... did not write the file\n") +def write_out_search_file(df, fname, any_remediations, auto_accept, fields, purpose): + """ + This function will write out a dataframe to a file, if the user would like to proceed. + """ + input_text = "\n\nYour data will be written to a new JSONL file. Proceed [Y/n]: " + + if not any_remediations: + sys.stdout.write( + f'\nYou can upload your file:\n> openai api files.create -f "{fname}" -p {purpose}' + ) + + elif accept_suggestion(input_text, auto_accept): + fnames = get_outfnames(fname, split=False) + + assert len(fnames) == 1 + df[fields].to_json(fnames[0], lines=True, orient="records", force_ascii=False) + + sys.stdout.write( + f'\nWrote modified file to {fnames[0]}`\nFeel free to take a look!\n\nNow upload that file:\n> openai api files.create -f "{fnames[0]}" -p {purpose}' + ) + else: + sys.stdout.write("Aborting... 
did not write the file\n") + + def infer_task_type(df): """ Infer the likely fine-tuning task type from the data @@ -787,7 +806,7 @@ def get_validators(): lambda x: necessary_column_validator(x, "prompt"), lambda x: necessary_column_validator(x, "completion"), additional_column_validator, - non_empty_completion_validator, + non_empty_field_validator, format_inferrer_validator, duplicated_rows_validator, long_examples_validator, @@ -799,3 +818,71 @@ def get_validators(): common_completion_suffix_validator, completions_space_start_validator, ] + + +def get_search_validators(required_fields, optional_fields): + validators = [ + lambda x: necessary_column_validator(x, field) for field in required_fields + ] + validators += [ + lambda x: non_empty_field_validator(x, field) for field in required_fields + ] + validators += [lambda x: duplicated_rows_validator(x, required_fields)] + validators += [ + lambda x: additional_column_validator( + x, fields=required_fields + optional_fields + ), + ] + + return validators + + +def apply_validators( + df, + fname, + remediation, + validators, + auto_accept, + write_out_file_func, +): + optional_remediations = [] + if remediation is not None: + optional_remediations.append(remediation) + for validator in validators: + remediation = validator(df) + if remediation is not None: + optional_remediations.append(remediation) + df = apply_necessary_remediation(df, remediation) + + any_optional_or_necessary_remediations = any( + [ + remediation + for remediation in optional_remediations + if remediation.optional_msg is not None + or remediation.necessary_msg is not None + ] + ) + any_necessary_applied = any( + [ + remediation + for remediation in optional_remediations + if remediation.necessary_msg is not None + ] + ) + any_optional_applied = False + + if any_optional_or_necessary_remediations: + sys.stdout.write( + "\n\nBased on the analysis we will perform the following actions:\n" + ) + for remediation in optional_remediations: + df, optional_applied = apply_optional_remediation( + df, remediation, auto_accept + ) + any_optional_applied = any_optional_applied or optional_applied + else: + sys.stdout.write("\n\nNo remediations found.\n") + + any_optional_or_necessary_applied = any_optional_applied or any_necessary_applied + + write_out_file_func(df, fname, any_optional_or_necessary_applied, auto_accept)
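For reviewers who want to exercise the new helpers outside the CLI, here is a minimal sketch of how the pieces introduced in this diff compose. It mirrors `Search.prepare_data` above; the input file `cases.jsonl` and its contents are hypothetical, `purpose="search"` assumes the file is meant for the Search endpoint ("classifications" would also require a "labels" field), and `auto_accept=True` stands in for the `-q/--quiet` flag.

```python
# Hedged sketch, not part of the patch: drives the new validators the same way
# Search.prepare_data does. "cases.jsonl" is a hypothetical JSONL file with one
# {"text": ..., "metadata": ...} object per line.
from functools import partial

from openai.validators import (
    apply_necessary_remediation,
    apply_validators,
    get_search_validators,
    read_any_format,
    write_out_search_file,
)

fname = "cases.jsonl"
purpose = "search"
required_fields = ["text"]
optional_fields = ["metadata"]

# Read the file; any format-level remediation is reported, then applied.
df, remediation = read_any_format(fname, fields=required_fields + optional_fields)
if "metadata" not in df:
    df["metadata"] = None
apply_necessary_remediation(None, remediation)

# Build and run the search-specific validator chain. If anything had to change,
# the cleaned frame is written next to the input as cases_prepared.jsonl
# (or a numbered variant if that name is already taken).
validators = get_search_validators(required_fields, optional_fields)
write_out = partial(
    write_out_search_file,
    purpose=purpose,
    fields=required_fields + optional_fields,
)
apply_validators(
    df,
    fname,
    remediation,
    validators,
    auto_accept=True,  # equivalent of the -q/--quiet CLI flag
    write_out_file_func=write_out,
)
```

The CLI path added by this patch does the same work: since the new subparser is registered alongside `fine_tunes.prepare_data`, the invocation should be `openai tools search.prepare_data -f cases.jsonl -p search` (add `-q` to auto-accept), after which the prepared file can be uploaded with `openai api files.create`.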