|
15 | 15 | """Type definitions for the models service."""
|
16 | 16 | from __future__ import annotations
|
17 | 17 |
|
| 18 | +from collections.abc import Mapping |
| 19 | +import csv |
18 | 20 | import dataclasses
|
19 | 21 | import datetime
|
| 22 | +import json |
| 23 | +import pathlib |
20 | 24 | import re
|
21 | 25 | from typing import Any, Iterable, TypedDict, Union
|
| 26 | +import urllib.request |
22 | 27 |
|
23 | 28 | import google.ai.generativelanguage as glm
|
24 | 29 | from google.generativeai import string_utils
|
@@ -72,17 +77,18 @@ class Model:
|
72 | 77 | """A dataclass representation of a `glm.Model`.
|
73 | 78 |
|
74 | 79 | Attributes:
|
75 |
| - name: The resource name of the `Model`. Format: `models/{model}` with a `{model}` naming convention of: |
76 |
| - "{base_model_id}-{version}". For example: `models/chat-bison-001`. |
| 80 | + name: The resource name of the `Model`. Format: `models/{model}` with a `{model}` naming |
| 81 | + convention of: "{base_model_id}-{version}". For example: `models/chat-bison-001`. |
77 | 82 | base_model_id: The base name of the model. For example: `chat-bison`.
|
78 | 83 | version: The major version number of the model. For example: `001`.
|
79 |
| - display_name: The human-readable name of the model. E.g. `"Chat Bison"`. The name can be up to 128 characters |
80 |
| - long and can consist of any UTF-8 characters. |
| 84 | + display_name: The human-readable name of the model. E.g. `"Chat Bison"`. The name can be up |
| 85 | + to 128 characters long and can consist of any UTF-8 characters. |
81 | 86 | description: A short description of the model.
|
82 | 87 | input_token_limit: Maximum number of input tokens allowed for this model.
|
83 | 88 | output_token_limit: Maximum number of output tokens available for this model.
|
84 |
| - supported_generation_methods: lists which methods are supported by the model. The method names are defined as |
85 |
| - Pascal case strings, such as `generateMessage` which correspond to API methods. |
| 89 | + supported_generation_methods: lists which methods are supported by the model. The method |
| 90 | + names are defined as Pascal case strings, such as `generateMessage` which correspond to |
| 91 | + API methods. |
86 | 92 | """
|
87 | 93 |
|
88 | 94 | name: str
|
@@ -187,28 +193,100 @@ class TuningExampleDict(TypedDict):
|
187 | 193 | output: str
|
188 | 194 |
|
189 | 195 |
|
190 |
| -TuningExampleOptions = Union[TuningExampleDict, glm.TuningExample, tuple[str, str]] |
| 196 | +TuningExampleOptions = Union[TuningExampleDict, glm.TuningExample, tuple[str, str], list[str]] |
| 197 | + |
| 198 | +# TODO(markdaoust): gs:// URLS? File-type argument for files without extension? |
191 | 199 | TuningDataOptions = Union[
|
192 |
| - glm.Dataset, Iterable[TuningExampleOptions] |
193 |
| -] # TODO(markdaoust): csv, json, pandas, np |
| 200 | + pathlib.Path, |
| 201 | + str, |
| 202 | + glm.Dataset, |
| 203 | + Mapping[str, Iterable[str]], |
| 204 | + Iterable[TuningExampleOptions], |
| 205 | +] |
194 | 206 |
|
195 | 207 |
|
196 |
| -def encode_tuning_data(data: TuningDataOptions) -> glm.Dataset: |
| 208 | +def encode_tuning_data( |
| 209 | + data: TuningDataOptions, input_key="text_input", output_key="output" |
| 210 | +) -> glm.Dataset: |
197 | 211 | if isinstance(data, glm.Dataset):
|
198 | 212 | return data
|
199 | 213 |
|
| 214 | + if isinstance(data, str): |
| 215 | + # Strings are either URLs or system paths. |
| 216 | + if re.match("^\w+://\S+$", data): |
| 217 | + data = _normalize_url(data) |
| 218 | + else: |
| 219 | + # Normalize system paths to use pathlib |
| 220 | + data = pathlib.Path(data) |
| 221 | + |
| 222 | + if isinstance(data, (str, pathlib.Path)): |
| 223 | + if isinstance(data, str): |
| 224 | + f = urllib.request.urlopen(data) |
| 225 | + # csv needs strings, json does not. |
| 226 | + content = (line.decode("utf-8") for line in f) |
| 227 | + else: |
| 228 | + f = data.open("r") |
| 229 | + content = f |
| 230 | + |
| 231 | + if str(data).lower().endswith(".json"): |
| 232 | + with f: |
| 233 | + data = json.load(f) |
| 234 | + else: |
| 235 | + with f: |
| 236 | + data = csv.DictReader(content) |
| 237 | + return _convert_iterable(data, input_key, output_key) |
| 238 | + |
| 239 | + if hasattr(data, "keys"): |
| 240 | + return _convert_dict(data, input_key, output_key) |
| 241 | + else: |
| 242 | + return _convert_iterable(data, input_key, output_key) |
| 243 | + |
| 244 | + |
| 245 | +def _normalize_url(url: str) -> str: |
| 246 | + sheet_base = "https://docs.google.com/spreadsheets" |
| 247 | + if url.startswith(sheet_base): |
| 248 | + # Normalize google-sheets URLs to download the csv. |
| 249 | + match = re.match(f"{sheet_base}/d/[^/]+", url) |
| 250 | + if match is None: |
| 251 | + raise ValueError("Incomplete Google Sheets URL: {data}") |
| 252 | + url = f"{match.group(0)}/export?format=csv" |
| 253 | + return url |
| 254 | + |
| 255 | + |
| 256 | +def _convert_dict(data, input_key, output_key): |
| 257 | + new_data = list() |
| 258 | + |
| 259 | + try: |
| 260 | + inputs = data[input_key] |
| 261 | + except KeyError as e: |
| 262 | + raise KeyError(f'input_key is "{input_key}", but data has keys: {sorted(data.keys())}') |
| 263 | + |
| 264 | + try: |
| 265 | + outputs = data[output_key] |
| 266 | + except KeyError as e: |
| 267 | + raise KeyError(f'output_key is "{output_key}", but data has keys: {sorted(data.keys())}') |
| 268 | + |
| 269 | + for i, o in zip(inputs, outputs): |
| 270 | + new_data.append(glm.TuningExample({"text_input": str(i), "output": str(o)})) |
| 271 | + return glm.Dataset(examples=glm.TuningExamples(examples=new_data)) |
| 272 | + |
| 273 | + |
| 274 | +def _convert_iterable(data, input_key, output_key): |
200 | 275 | new_data = list()
|
201 | 276 | for example in data:
|
202 |
| - example = encode_tuning_example(example) |
| 277 | + example = encode_tuning_example(example, input_key, output_key) |
203 | 278 | new_data.append(example)
|
204 | 279 | return glm.Dataset(examples=glm.TuningExamples(examples=new_data))
|
205 | 280 |
|
206 | 281 |
|
207 |
| -def encode_tuning_example(example: TuningExampleOptions): |
208 |
| - if isinstance(example, tuple): |
209 |
| - example = glm.TuningExample(text_input=example[0], output=example[1]) |
210 |
| - else: # dict or glm.TuningExample |
211 |
| - example = glm.TuningExample(example) |
| 282 | +def encode_tuning_example(example: TuningExampleOptions, input_key, output_key): |
| 283 | + if isinstance(example, glm.TuningExample): |
| 284 | + return example |
| 285 | + elif isinstance(example, (tuple, list)): |
| 286 | + a, b = example |
| 287 | + example = glm.TuningExample(text_input=a, output=b) |
| 288 | + else: # dict |
| 289 | + example = glm.TuningExample(text_input=example[input_key], output=example[output_key]) |
212 | 290 | return example
|
213 | 291 |
|
214 | 292 |
|
|
0 commit comments