Skip to content

Commit 2eb5e82

Browse files
authored
Update dataset construction code to handle more input types. (#79)
* Allow dicts and pd.DataFrames as tuning datasets * fix type hints. * Support csv and json files, csv URLs. * docs * add TODO * Allow json-urls and streaming decoding of CSVs.
1 parent 923d372 commit 2eb5e82

File tree

7 files changed

+207
-22
lines changed

7 files changed

+207
-22
lines changed

google/generativeai/models.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -244,16 +244,18 @@ def create_tuned_model(
244244
epoch_count: int | None = None,
245245
batch_size: int | None = None,
246246
learning_rate: float | None = None,
247+
input_key: str = "text_input",
248+
output_key: str = "output",
247249
client: glm.ModelServiceClient | None = None,
248250
) -> operations.CreateTunedModelOperation:
249251
"""Launches a tuning job to create a TunedModel.
250252
251253
Since tuning a model can take significant time, this API doesn't wait for the tuning to complete.
252-
Instead, it returns a `google.api_core.operation.Operation` object that lets you check on the status
253-
of the tuning job, or wait for it to complete, and check the result.
254+
Instead, it returns a `google.api_core.operation.Operation` object that lets you check on the
255+
status of the tuning job, or wait for it to complete, and check the result.
254256
255-
After the job completes you can either find the resulting `TunedModel` object in `Operation.result()`
256-
or `palm.list_tuned_models` or `palm.get_tuned_model(model_id)`.
257+
After the job completes you can either find the resulting `TunedModel` object in
258+
`Operation.result()` or `palm.list_tuned_models` or `palm.get_tuned_model(model_id)`.
257259
258260
```
259261
my_id = "my-tuned-model-id"
@@ -275,6 +277,16 @@ def create_tuned_model(
275277
* `glm.TuningExample`,
276278
* {'text_input': text_input, 'output': output} dicts, or
277279
* `(text_input, output)` tuples.
280+
* A `Mapping` of `Iterable[str]` - use `input_key` and `output_key` to choose which
281+
columns to use as the input/output
282+
* A csv file (will be read with `csv.DictReader` and handled as an `Iterable`
283+
above). This can be:
284+
* A local path as a `str` or `pathlib.Path`.
285+
* A url for a csv file.
286+
* The url of a Google Sheets file.
287+
* A JSON file - Its contents will be handled either as an `Iterable` or `Mapping`
288+
above. This can be:
289+
* A local path as a `str` or `pathlib.Path`.
278290
id: The model identifier, used to refer to the model in the API
279291
`tunedModels/{id}`. Must be unique.
280292
display_name: A human-readable name for display.
@@ -308,7 +320,9 @@ def create_tuned_model(
308320
else:
309321
ValueError(f"Not understood: `{source_model=}`")
310322

311-
training_data = model_types.encode_tuning_data(training_data)
323+
training_data = model_types.encode_tuning_data(
324+
training_data, input_key=input_key, output_key=output_key
325+
)
312326

313327
hyperparameters = glm.Hyperparameters(
314328
epoch_count=epoch_count,

google/generativeai/types/model_types.py

Lines changed: 94 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,15 @@
1515
"""Type definitions for the models service."""
1616
from __future__ import annotations
1717

18+
from collections.abc import Mapping
19+
import csv
1820
import dataclasses
1921
import datetime
22+
import json
23+
import pathlib
2024
import re
2125
from typing import Any, Iterable, TypedDict, Union
26+
import urllib.request
2227

2328
import google.ai.generativelanguage as glm
2429
from google.generativeai import string_utils
@@ -72,17 +77,18 @@ class Model:
7277
"""A dataclass representation of a `glm.Model`.
7378
7479
Attributes:
75-
name: The resource name of the `Model`. Format: `models/{model}` with a `{model}` naming convention of:
76-
"{base_model_id}-{version}". For example: `models/chat-bison-001`.
80+
name: The resource name of the `Model`. Format: `models/{model}` with a `{model}` naming
81+
convention of: "{base_model_id}-{version}". For example: `models/chat-bison-001`.
7782
base_model_id: The base name of the model. For example: `chat-bison`.
7883
version: The major version number of the model. For example: `001`.
79-
display_name: The human-readable name of the model. E.g. `"Chat Bison"`. The name can be up to 128 characters
80-
long and can consist of any UTF-8 characters.
84+
display_name: The human-readable name of the model. E.g. `"Chat Bison"`. The name can be up
85+
to 128 characters long and can consist of any UTF-8 characters.
8186
description: A short description of the model.
8287
input_token_limit: Maximum number of input tokens allowed for this model.
8388
output_token_limit: Maximum number of output tokens available for this model.
84-
supported_generation_methods: lists which methods are supported by the model. The method names are defined as
85-
Pascal case strings, such as `generateMessage` which correspond to API methods.
89+
supported_generation_methods: lists which methods are supported by the model. The method
90+
names are defined as Pascal case strings, such as `generateMessage` which correspond to
91+
API methods.
8692
"""
8793

8894
name: str
@@ -187,28 +193,100 @@ class TuningExampleDict(TypedDict):
187193
output: str
188194

189195

190-
TuningExampleOptions = Union[TuningExampleDict, glm.TuningExample, tuple[str, str]]
196+
TuningExampleOptions = Union[TuningExampleDict, glm.TuningExample, tuple[str, str], list[str]]
197+
198+
# TODO(markdaoust): gs:// URLS? File-type argument for files without extension?
191199
TuningDataOptions = Union[
192-
glm.Dataset, Iterable[TuningExampleOptions]
193-
] # TODO(markdaoust): csv, json, pandas, np
200+
pathlib.Path,
201+
str,
202+
glm.Dataset,
203+
Mapping[str, Iterable[str]],
204+
Iterable[TuningExampleOptions],
205+
]
194206

195207

196-
def encode_tuning_data(data: TuningDataOptions) -> glm.Dataset:
208+
def encode_tuning_data(
    data: TuningDataOptions, input_key="text_input", output_key="output"
) -> glm.Dataset:
    """Converts any supported tuning-data source into a `glm.Dataset`.

    Args:
        data: The training data. One of:
            * A `glm.Dataset`, which is returned unchanged.
            * A URL string (anything shaped like `scheme://...`). Google Sheets
              URLs are rewritten by `_normalize_url` before being fetched.
            * A local path, as a `str` or a `pathlib.Path`.
            * Any object with a `keys` attribute (e.g. a `dict`), treated as a
              mapping from column name to a column of values.
            * An iterable of individual examples.
            For paths and URLs, names ending in `.json` (case-insensitive) are
            parsed as JSON and then handled as a mapping or an iterable;
            everything else is parsed as csv with a header row.
        input_key: Column/key that holds the model input.
        output_key: Column/key that holds the expected output.

    Returns:
        The `glm.Dataset` built from `data`.
    """
    if isinstance(data, glm.Dataset):
        return data

    if isinstance(data, str):
        # Strings are either URLs or system paths.
        # Raw string: "\w" and "\S" are not valid escapes in a plain literal.
        if re.match(r"^\w+://\S+$", data):
            data = _normalize_url(data)
        else:
            # Normalize system paths to use pathlib
            data = pathlib.Path(data)

    if isinstance(data, (str, pathlib.Path)):
        is_url = isinstance(data, str)
        is_json = str(data).lower().endswith(".json")

        if is_url:
            f = urllib.request.urlopen(data)
        else:
            # Decode local files as UTF-8, the same way the URL branch does,
            # instead of depending on the locale. `newline=""` is what the
            # csv module expects so that quoted newlines survive intact.
            f = data.open("r", encoding="utf-8", newline="")

        with f:
            if is_json:
                # json.load accepts both the binary URL stream and text files.
                data = json.load(f)
            else:
                # csv needs strings, so decode the binary URL stream lazily.
                lines = (line.decode("utf-8") for line in f) if is_url else f
                # Convert while the stream is still open: DictReader is lazy.
                return _convert_iterable(csv.DictReader(lines), input_key, output_key)

    if hasattr(data, "keys"):
        return _convert_dict(data, input_key, output_key)
    else:
        return _convert_iterable(data, input_key, output_key)
243+
244+
245+
def _normalize_url(url: str) -> str:
246+
sheet_base = "https://docs.google.com/spreadsheets"
247+
if url.startswith(sheet_base):
248+
# Normalize google-sheets URLs to download the csv.
249+
match = re.match(f"{sheet_base}/d/[^/]+", url)
250+
if match is None:
251+
raise ValueError("Incomplete Google Sheets URL: {data}")
252+
url = f"{match.group(0)}/export?format=csv"
253+
return url
254+
255+
256+
def _convert_dict(data, input_key, output_key):
257+
new_data = list()
258+
259+
try:
260+
inputs = data[input_key]
261+
except KeyError as e:
262+
raise KeyError(f'input_key is "{input_key}", but data has keys: {sorted(data.keys())}')
263+
264+
try:
265+
outputs = data[output_key]
266+
except KeyError as e:
267+
raise KeyError(f'output_key is "{output_key}", but data has keys: {sorted(data.keys())}')
268+
269+
for i, o in zip(inputs, outputs):
270+
new_data.append(glm.TuningExample({"text_input": str(i), "output": str(o)}))
271+
return glm.Dataset(examples=glm.TuningExamples(examples=new_data))
272+
273+
274+
def _convert_iterable(data, input_key, output_key):
    """Builds a `glm.Dataset` from an iterable of individual examples.

    Args:
        data: Iterable of examples; each one is passed to
            `encode_tuning_example`.
        input_key: Key forwarded to `encode_tuning_example`.
        output_key: Key forwarded to `encode_tuning_example`.

    Returns:
        A `glm.Dataset` holding the encoded examples in iteration order.
    """
    examples = [encode_tuning_example(item, input_key, output_key) for item in data]
    return glm.Dataset(examples=glm.TuningExamples(examples=examples))
205280

206281

207-
def encode_tuning_example(example: TuningExampleOptions):
208-
if isinstance(example, tuple):
209-
example = glm.TuningExample(text_input=example[0], output=example[1])
210-
else: # dict or glm.TuningExample
211-
example = glm.TuningExample(example)
282+
def encode_tuning_example(example: TuningExampleOptions, input_key, output_key):
    """Converts a single example into a `glm.TuningExample`.

    Args:
        example: A `glm.TuningExample` (returned as-is), a two-item tuple or
            list unpacked as `(text_input, output)`, or anything else that can
            be indexed with `input_key` and `output_key` (e.g. a dict).
        input_key: Key of the input text; used only for the indexable case.
        output_key: Key of the output text; used only for the indexable case.

    Returns:
        The resulting `glm.TuningExample`.
    """
    if isinstance(example, glm.TuningExample):
        return example

    if isinstance(example, (tuple, list)):
        # Unpacking enforces exactly two items.
        text_input, output = example
    else:  # dict
        text_input, output = example[input_key], example[output_key]

    return glm.TuningExample(text_input=text_input, output=output)
213291

214292

tests/test.csv

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
text_input,output
2+
a,1
3+
b,2
4+
c,3

tests/test1.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[
2+
{"text_input": "a", "output": "1"},
3+
{"text_input": "b", "output": "2"},
4+
{"text_input": "c", "output": "3"}
5+
]

tests/test2.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"text_input": ["a", "b", "c"], "output": ["1", "2", "3"]}

tests/test3.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[
2+
["a","1"],
3+
["b","2"],
4+
["c","3"]
5+
]

tests/test_models.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import copy
1616
import datetime
1717
import dataclasses
18+
import pathlib
1819
import pytz
1920
from typing import Any, Union
2021
import unittest
@@ -29,7 +30,10 @@
2930
from google.generativeai import models
3031
from google.generativeai import client
3132
from google.generativeai.types import model_types
32-
from google.protobuf import field_mask_pb2
33+
34+
import pandas as pd
35+
36+
HERE = pathlib.Path(__file__).parent
3337

3438

3539
class UnitTests(parameterized.TestCase):
@@ -385,6 +389,80 @@ def test_create_tuned_model_on_tuned_model(self, tuned_source):
385389
"models/swim-fish-000",
386390
)
387391

392+
@parameterized.named_parameters(
    (
        "glm",
        glm.Dataset(
            examples=glm.TuningExamples(
                examples=[
                    {"text_input": "a", "output": "1"},
                    {"text_input": "b", "output": "2"},
                    {"text_input": "c", "output": "3"},
                ]
            )
        ),
    ),
    (
        "list",
        [
            ("a", "1"),
            {"text_input": "b", "output": "2"},
            glm.TuningExample({"text_input": "c", "output": "3"}),
        ],
    ),
    ("dict", {"text_input": ["a", "b", "c"], "output": ["1", "2", "3"]}),
    (
        "dict_custom_keys",
        {"my_inputs": ["a", "b", "c"], "my_outputs": ["1", "2", "3"]},
        "my_inputs",
        "my_outputs",
    ),
    (
        "pd.DataFrame",
        pd.DataFrame(
            [
                {"text_input": "a", "output": "1"},
                {"text_input": "b", "output": "2"},
                {"text_input": "c", "output": "3"},
            ]
        ),
    ),
    ("csv-path-string", str(HERE / "test.csv")),
    ("csv-path", HERE / "test.csv"),
    ("json-file-1", HERE / "test1.json"),
    ("json-file-2", HERE / "test2.json"),
    ("json-file-3", HERE / "test3.json"),
    (
        "json-url",
        "https://storage.googleapis.com/generativeai-downloads/data/test1.json",
    ),
    (
        "csv-url",
        "https://storage.googleapis.com/generativeai-downloads/data/test.csv",
    ),
    (
        "sheet-share",
        "https://docs.google.com/spreadsheets/d/1OffcVSqN6X-RYdWLGccDF3KtnKoIpS7O_9cZbicKK4A/edit?usp=sharing",
    ),
    (
        "sheet-export-csv",
        "https://docs.google.com/spreadsheets/d/1OffcVSqN6X-RYdWLGccDF3KtnKoIpS7O_9cZbicKK4A/export?format=csv",
    ),
)
def test_create_dataset(self, data, ik="text_input", ok="output"):
    """Checks that every supported input form encodes to the same dataset.

    Each case supplies the same three examples (a->1, b->2, c->3) in a
    different container, file, or URL form; `ik`/`ok` override the default
    column keys for the custom-key case.
    """
    expected = glm.Dataset(
        examples=glm.TuningExamples(
            examples=[
                {"text_input": "a", "output": "1"},
                {"text_input": "b", "output": "2"},
                {"text_input": "c", "output": "3"},
            ]
        )
    )

    actual = model_types.encode_tuning_data(data, input_key=ik, output_key=ok)

    self.assertEqual(expected, actual)
465+
388466

389467
if __name__ == "__main__":
390468
absltest.main()

0 commit comments

Comments
 (0)