Add embeddings and search file support #35

Closed
wants to merge 2 commits
Changes from all commits
3 changes: 3 additions & 0 deletions openai/api_resources/engine.py
@@ -30,3 +30,6 @@ def generate(self, timeout=None, **params):

def search(self, **params):
return self.request("post", self.instance_url() + "/search", params)

def embeddings(self, **params):
return self.request("post", self.instance_url() + "/embeddings", params)
112 changes: 69 additions & 43 deletions openai/cli.py
@@ -3,14 +3,17 @@
import signal
import sys
import warnings
from functools import partial

import openai
from openai.validators import (
apply_necessary_remediation,
apply_optional_remediation,
apply_validators,
get_search_validators,
get_validators,
read_any_format,
write_out_file,
write_out_search_file,
)


@@ -224,6 +227,41 @@ def list(cls, args):


class Search:
@classmethod
def prepare_data(cls, args):

sys.stdout.write("Analyzing...\n")
fname = args.file
auto_accept = args.quiet
purpose = args.purpose

optional_fields = ["metadata"]

if purpose == "classifications":
required_fields = ["text", "labels"]
else:
required_fields = ["text"]

df, remediation = read_any_format(
fname, fields=required_fields + optional_fields
)

if "metadata" not in df:
df["metadata"] = None

apply_necessary_remediation(None, remediation)
validators = get_search_validators(required_fields, optional_fields)

write_out_file_func = partial(
write_out_search_file,
purpose=purpose,
fields=required_fields + optional_fields,
)

apply_validators(
df, fname, remediation, validators, auto_accept, write_out_file_func
)
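
For reference, write_out_search_file (defined in validators.py below) emits one JSON object per line; a prepared classifications row would carry the required and optional fields along these hypothetical lines:

{"text": "A sample document", "labels": "sports", "metadata": null}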

@classmethod
def create_alpha(cls, args):
resp = openai.Search.create_alpha(
@@ -436,49 +474,14 @@ def prepare_data(cls, args):

validators = get_validators()

-optional_remediations = []
-if remediation is not None:
-optional_remediations.append(remediation)
-for validator in validators:
-remediation = validator(df)
-if remediation is not None:
-optional_remediations.append(remediation)
-df = apply_necessary_remediation(df, remediation)
-
-any_optional_or_necessary_remediations = any(
-[
-remediation
-for remediation in optional_remediations
-if remediation.optional_msg is not None
-or remediation.necessary_msg is not None
-]
-)
-any_necessary_applied = any(
-[
-remediation
-for remediation in optional_remediations
-if remediation.necessary_msg is not None
-]
-)
-any_optional_applied = False
-
-if any_optional_or_necessary_remediations:
-sys.stdout.write(
-"\n\nBased on the analysis we will perform the following actions:\n"
-)
-for remediation in optional_remediations:
-df, optional_applied = apply_optional_remediation(
-df, remediation, auto_accept
-)
-any_optional_applied = any_optional_applied or optional_applied
-else:
-sys.stdout.write("\n\nNo remediations found.\n")
-
-any_optional_or_necessary_applied = (
-any_optional_applied or any_necessary_applied
-)
-
-write_out_file(df, fname, any_optional_or_necessary_applied, auto_accept)
+apply_validators(
+df,
+fname,
+remediation,
+validators,
+auto_accept,
+write_out_file_func=write_out_file,
+)


def tools_register(parser):
@@ -508,6 +511,29 @@ def help(args):
)
sub.set_defaults(func=FineTune.prepare_data)

sub = subparsers.add_parser("search.prepare_data")
sub.add_argument(
"-f",
"--file",
required=True,
help="JSONL, JSON, CSV, TSV, TXT or XLSX file containing prompt-completion examples to be analyzed."
[Review comment, Collaborator] Copy pasta of prompt-completion from fine tuning?

"This should be the local file path.",
)
sub.add_argument(
"-p",
"--purpose",
help="Why are you uploading this file? (see https://beta.openai.com/docs/api-reference/ for purposes)",
[Review comment, Collaborator] It seems unintuitive to me that the Search purpose (which is what this is ostensibly under) is further split into different purposes. Maybe each of them should be a different tool instead? (Or do they all fall under the search umbrella, and how? We can chat outside this PR if I'm missing some context)

required=True,
)
sub.add_argument(
"-q",
"--quiet",
required=False,
action="store_true",
help="Auto accepts all suggestions, without asking for user input. To be used within scripts.",
)
sub.set_defaults(func=Search.prepare_data)
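
Assuming this subcommand is registered under the existing tools parser, as the surrounding hunk suggests, a hypothetical invocation would look like:

> openai tools search.prepare_data -f input.csv -p classifications -q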


def api_register(parser):
# Engine management
139 changes: 113 additions & 26 deletions openai/validators.py
@@ -1,9 +1,9 @@
import os
import sys
+import pandas as pd
+import numpy as np
+from typing import Any, Callable, NamedTuple, Optional
-
-from typing import NamedTuple, Optional, Callable, Any
-import numpy as np
-import pandas as pd


class Remediation(NamedTuple):
@@ -70,7 +70,7 @@ def lower_case_column_creator(df):
)


-def additional_column_validator(df):
+def additional_column_validator(df, fields=["prompt", "completion"]):
"""
This validator will remove additional columns from the dataframe.
"""
Expand All @@ -79,9 +79,7 @@ def additional_column_validator(df):
immediate_msg = None
necessary_fn = None
if len(df.columns) > 2:
-additional_columns = [
-c for c in df.columns if c not in ["prompt", "completion"]
-]
+additional_columns = [c for c in df.columns if c not in fields]
warn_message = ""
for ac in additional_columns:
dups = [c for c in additional_columns if ac in c]
@@ -91,7 +89,7 @@
necessary_msg = f"Remove additional columns/keys: {additional_columns}"

def necessary_fn(x):
-return x[["prompt", "completion"]]
+return x[fields]

return Remediation(
name="additional_column",
@@ -101,50 +99,47 @@ def necessary_fn(x):
)


-def non_empty_completion_validator(df):
+def non_empty_field_validator(df, field="completion"):
"""
This validator will ensure that no completion is empty.
"""
necessary_msg = None
necessary_fn = None
immediate_msg = None

-if (
-df["completion"].apply(lambda x: x == "").any()
-or df["completion"].isnull().any()
-):
-empty_rows = (df["completion"] == "") | (df["completion"].isnull())
+if df[field].apply(lambda x: x == "").any() or df[field].isnull().any():
+empty_rows = (df[field] == "") | (df[field].isnull())
empty_indexes = df.reset_index().index[empty_rows].tolist()
-immediate_msg = f"\n- `completion` column/key should not contain empty strings. These are rows: {empty_indexes}"
+immediate_msg = f"\n- `{field}` column/key should not contain empty strings. These are rows: {empty_indexes}"

def necessary_fn(x):
-return x[x["completion"] != ""].dropna(subset=["completion"])
+return x[x[field] != ""].dropna(subset=[field])

-necessary_msg = f"Remove {len(empty_indexes)} rows with empty completions"
+necessary_msg = f"Remove {len(empty_indexes)} rows with empty {field}s"
return Remediation(
name="empty_completion",
name=f"empty_{field}",
immediate_msg=immediate_msg,
necessary_msg=necessary_msg,
necessary_fn=necessary_fn,
)


-def duplicated_rows_validator(df):
+def duplicated_rows_validator(df, fields=["prompt", "completion"]):
"""
This validator will suggest to the user to remove duplicate rows if they exist.
"""
-duplicated_rows = df.duplicated(subset=["prompt", "completion"])
+duplicated_rows = df.duplicated(subset=fields)
duplicated_indexes = df.reset_index().index[duplicated_rows].tolist()
immediate_msg = None
optional_msg = None
optional_fn = None

if len(duplicated_indexes) > 0:
-immediate_msg = f"\n- There are {len(duplicated_indexes)} duplicated prompt-completion pairs. These are rows: {duplicated_indexes}"
+immediate_msg = f"\n- There are {len(duplicated_indexes)} duplicated {'-'.join(fields)} sets. These are rows: {duplicated_indexes}"
optional_msg = f"Remove {len(duplicated_indexes)} duplicate rows"

def optional_fn(x):
-return x.drop_duplicates(subset=["prompt", "completion"])
+return x.drop_duplicates(subset=fields)

return Remediation(
name="duplicated_rows",
@@ -467,7 +462,7 @@ def lower_case(x):
)


-def read_any_format(fname):
+def read_any_format(fname, fields=["prompt", "completion"]):
"""
This function will read a file saved in .csv, .json, .txt, .xlsx or .tsv format using pandas.
- for .xlsx it will read the first sheet
@@ -502,7 +497,7 @@ def read_any_format(fname):
content = f.read()
df = pd.DataFrame(
[["", line] for line in content.split("\n")],
-columns=["prompt", "completion"],
+columns=fields,
dtype=str,
)
if fname.lower().endswith("jsonl") or fname.lower().endswith("json"):
@@ -623,7 +618,7 @@ def get_outfnames(fname, split):
while True:
index_suffix = f" ({i})" if i > 0 else ""
candidate_fnames = [
-fname.split(".")[0] + "_prepared" + suffix + index_suffix + ".jsonl"
+os.path.splitext(fname)[0] + "_prepared" + suffix + index_suffix + ".jsonl"
for suffix in suffixes
]
if not any(os.path.isfile(f) for f in candidate_fnames):
@@ -743,6 +738,30 @@ def write_out_file(df, fname, any_remediations, auto_accept):
sys.stdout.write("Aborting... did not write the file\n")


def write_out_search_file(df, fname, any_remediations, auto_accept, fields, purpose):
"""
This function will write out a dataframe to a file, if the user would like to proceed.
"""
input_text = "\n\nYour data will be written to a new JSONL file. Proceed [Y/n]: "

if not any_remediations:
sys.stdout.write(
f'\nYou can upload your file:\n> openai api files.create -f "{fname}" -p {purpose}'
)

elif accept_suggestion(input_text, auto_accept):
fnames = get_outfnames(fname, split=False)

assert len(fnames) == 1
df[fields].to_json(fnames[0], lines=True, orient="records", force_ascii=False)

sys.stdout.write(
f'\nWrote modified file to {fnames[0]}\nFeel free to take a look!\n\nNow upload that file:\n> openai api files.create -f "{fnames[0]}" -p {purpose}'
)
else:
sys.stdout.write("Aborting... did not write the file\n")


def infer_task_type(df):
"""
Infer the likely fine-tuning task type from the data
@@ -787,7 +806,7 @@ def get_validators():
lambda x: necessary_column_validator(x, "prompt"),
lambda x: necessary_column_validator(x, "completion"),
additional_column_validator,
-non_empty_completion_validator,
+non_empty_field_validator,
format_inferrer_validator,
duplicated_rows_validator,
long_examples_validator,
@@ -799,3 +818,71 @@
common_completion_suffix_validator,
completions_space_start_validator,
]


def get_search_validators(required_fields, optional_fields):
validators = [
# Bind field at definition time: a bare closure here would capture the loop
# variable late, so every validator would check only the last required field.
lambda x, field=field: necessary_column_validator(x, field)
for field in required_fields
]
validators += [
lambda x, field=field: non_empty_field_validator(x, field)
for field in required_fields
]
validators += [lambda x: duplicated_rows_validator(x, required_fields)]
validators += [
lambda x: additional_column_validator(
x, fields=required_fields + optional_fields
),
]

return validators
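
A short sketch of how these validators compose; the DataFrame contents are made up and the field names mirror the CLI defaults above:

import pandas as pd

validators = get_search_validators(["text", "labels"], ["metadata"])
df = pd.DataFrame({"text": ["a"], "labels": ["x"], "metadata": [None]})
for validator in validators:
    remediation = validator(df)  # a Remediation describing any suggested fix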


def apply_validators(
df,
fname,
remediation,
validators,
auto_accept,
write_out_file_func,
):
optional_remediations = []
if remediation is not None:
optional_remediations.append(remediation)
for validator in validators:
remediation = validator(df)
if remediation is not None:
optional_remediations.append(remediation)
df = apply_necessary_remediation(df, remediation)

any_optional_or_necessary_remediations = any(
[
remediation
for remediation in optional_remediations
if remediation.optional_msg is not None
or remediation.necessary_msg is not None
]
)
any_necessary_applied = any(
[
remediation
for remediation in optional_remediations
if remediation.necessary_msg is not None
]
)
any_optional_applied = False

if any_optional_or_necessary_remediations:
sys.stdout.write(
"\n\nBased on the analysis we will perform the following actions:\n"
)
for remediation in optional_remediations:
df, optional_applied = apply_optional_remediation(
df, remediation, auto_accept
)
any_optional_applied = any_optional_applied or optional_applied
else:
sys.stdout.write("\n\nNo remediations found.\n")

any_optional_or_necessary_applied = any_optional_applied or any_necessary_applied

write_out_file_func(df, fname, any_optional_or_necessary_applied, auto_accept)
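
Taken together, a hedged end-to-end sketch of the existing fine-tuning path through this new helper (the file name is hypothetical):

df, remediation = read_any_format("examples.jsonl")
apply_validators(
    df,
    "examples.jsonl",
    remediation,
    get_validators(),
    auto_accept=False,
    write_out_file_func=write_out_file,
)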