
Commit 36e001a

Add rev14 parameters and fixes. (#561)
* Add rev14 parameters (Change-Id: I16f2b1f5820a6cf867b9abb04ffd5c6e6d2d947b)
* Fix flakey repr test (Change-Id: I89bcf1494cf72c6ee28f2b52d0345cbb40859862)
* format (Change-Id: I81cff23e9ce20cc20b4a0632d557c71f536fd485)
* Use client preview (Change-Id: I2d8a4ee2e9e4b6e00a16a9dac1136a2fa18d7a28)
* Fix tests (Change-Id: If8fbbba1966aa42601adec877e60d851d4f03b72)
* Fix tuned model tests (Change-Id: I5ace9222954be7d903ebbdabab9efc663fa79174)
* Fix tests (Change-Id: Ifa610965c5d6c38123080a7e16416ac325418285)
* format (Change-Id: I15fd5701dd5c4200461a32c968fa19e375403a7e)
* pytype (Change-Id: I08f74d08c4e93bbfdf353370b5dd57d8bf86a637)
* pytype (Change-Id: If81b86c176008cd9a99e3b879fbd3af086ec2235)
* 3.9 tests (Change-Id: I13e66016327aae0b0f3274e941bc615f379e5669)
1 parent 4f42118 commit 36e001a

File tree

6 files changed: +82 -72 lines changed

google/generativeai/types/generation_types.py

Lines changed: 26 additions & 3 deletions
@@ -144,17 +144,27 @@ class GenerationConfig:
             Note: The default value varies by model, see the
             `Model.top_k` attribute of the `Model` returned the
             `genai.get_model` function.
-
+        seed:
+            Optional. Seed used in decoding. If not set, the request uses a randomly generated seed.
         response_mime_type:
             Optional. Output response mimetype of the generated candidate text.

             Supported mimetype:
                 `text/plain`: (default) Text output.
+                `text/x-enum`: for use with a string-enum in `response_schema`
                 `application/json`: JSON response in the candidates.

         response_schema:
             Optional. Specifies the format of the JSON requested if response_mime_type is
             `application/json`.
+        presence_penalty:
+            Optional.
+        frequency_penalty:
+            Optional.
+        response_logprobs:
+            Optional. If true, export the `logprobs` results in response.
+        logprobs:
+            Optional. Number of candidates of log probabilities to return at each step of decoding.
     """

     candidate_count: int | None = None
@@ -163,8 +173,13 @@ class GenerationConfig:
     temperature: float | None = None
     top_p: float | None = None
     top_k: int | None = None
+    seed: int | None = None
     response_mime_type: str | None = None
     response_schema: protos.Schema | Mapping[str, Any] | type | None = None
+    presence_penalty: float | None = None
+    frequency_penalty: float | None = None
+    response_logprobs: bool | None = None
+    logprobs: int | None = None


 GenerationConfigType = Union[protos.GenerationConfig, GenerationConfigDict, GenerationConfig]
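Taken together, the new dataclass fields expose the rev14 sampling controls directly on `GenerationConfig`. A minimal usage sketch (the model name, prompt, and parameter values are illustrative placeholders, and this assumes an API key is already configured):

import google.generativeai as genai

# Sketch: exercise the fields added in this commit.
config = genai.GenerationConfig(
    seed=42,                 # reproducible decoding across identical requests
    presence_penalty=0.5,    # penalize tokens that already appear in the text
    frequency_penalty=0.5,   # penalize tokens in proportion to their count
    response_logprobs=True,  # ask the service to return logprob results
    logprobs=3,              # top-3 candidate logprobs per decoding step
)

model = genai.GenerativeModel("gemini-1.5-flash")  # placeholder model name
response = model.generate_content("Tell me a story.", generation_config=config)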
@@ -306,6 +321,7 @@ def _join_code_execution_result(result_1, result_2):


 def _join_candidates(candidates: Iterable[protos.Candidate]):
+    """Joins stream chunks of a single candidate."""
     candidates = tuple(candidates)

     index = candidates[0].index  # These should all be the same.
@@ -321,6 +337,7 @@ def _join_candidates(candidates: Iterable[protos.Candidate]):


 def _join_candidate_lists(candidate_lists: Iterable[list[protos.Candidate]]):
+    """Joins stream chunks where each chunk is a list of candidate chunks."""
     # Assuming that is a candidate ends, it is no longer returned in the list of
     # candidates and that's why candidates have an index
     candidates = collections.defaultdict(list)
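The comment above states the invariant the join relies on: a finished candidate drops out of later chunks, so partial candidates are matched by their `index` field rather than by list position. A self-contained sketch of that grouping, with plain dicts standing in for `protos.Candidate`:

import collections

# Each streamed chunk is a list of partial candidates; `index` says which
# final candidate a partial belongs to.
chunk_lists = [
    [{"index": 0, "text": "Once"}, {"index": 1, "text": "The"}],
    [{"index": 0, "text": " upon"}],  # candidate 1 finished; it stops appearing
]

candidates = collections.defaultdict(list)
for chunk in chunk_lists:
    for candidate in chunk:
        candidates[candidate["index"]].append(candidate)

# Each candidate's partials are then merged, e.g. by concatenating text.
merged = {i: "".join(p["text"] for p in parts) for i, parts in candidates.items()}
print(merged)  # {0: 'Once upon', 1: 'The'}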
@@ -344,10 +361,15 @@ def _join_prompt_feedbacks(

 def _join_chunks(chunks: Iterable[protos.GenerateContentResponse]):
     chunks = tuple(chunks)
+    if "usage_metadata" in chunks[-1]:
+        usage_metadata = chunks[-1].usage_metadata
+    else:
+        usage_metadata = None
+
     return protos.GenerateContentResponse(
         candidates=_join_candidate_lists(c.candidates for c in chunks),
         prompt_feedback=_join_prompt_feedbacks(c.prompt_feedback for c in chunks),
-        usage_metadata=chunks[-1].usage_metadata,
+        usage_metadata=usage_metadata,
     )


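The guard works because proto-plus messages support `in` for testing whether a field was explicitly set, so a final chunk without usage data no longer forces an empty `usage_metadata` onto the joined response. A small sketch of the presence check:

from google.generativeai import protos

chunk = protos.GenerateContentResponse()
print("usage_metadata" in chunk)  # False: the field was never set

chunk.usage_metadata = protos.GenerateContentResponse.UsageMetadata(prompt_token_count=5)
print("usage_metadata" in chunk)  # True: explicitly set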
@@ -541,7 +563,8 @@ def __str__(self) -> str:
         _result = _result.replace("\n", "\n    ")

         if self._error:
-            _error = f",\nerror=<{self._error.__class__.__name__}> {self._error}"
+
+            _error = f",\nerror={repr(self._error)}"
         else:
             _error = ""

google/generativeai/types/model_types.py

Lines changed: 4 additions & 1 deletion
@@ -143,7 +143,9 @@ def idecode_time(parent: dict["str", Any], name: str):

 def decode_tuned_model(tuned_model: protos.TunedModel | dict["str", Any]) -> TunedModel:
     if isinstance(tuned_model, protos.TunedModel):
-        tuned_model = type(tuned_model).to_dict(tuned_model)  # pytype: disable=attribute-error
+        tuned_model = type(tuned_model).to_dict(
+            tuned_model, including_default_value_fields=False
+        )  # pytype: disable=attribute-error
     tuned_model["state"] = to_tuned_model_state(tuned_model.pop("state", None))

     base_model = tuned_model.pop("base_model", None)
@@ -195,6 +197,7 @@ class TunedModel:
     create_time: datetime.datetime | None = None
     update_time: datetime.datetime | None = None
     tuning_task: TuningTask | None = None
+    reader_project_numbers: list[int] | None = None

     @property
     def permissions(self) -> permission_types.Permissions:
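The `including_default_value_fields=False` switch changes what `to_dict` emits: unset scalar fields no longer materialize as zeros and empty lists, so the decoded `TunedModel` only carries fields the service actually returned. A rough sketch of the difference (the exact output shape is approximate):

from google.generativeai import protos

tuned = protos.TunedModel(display_name="my-model")

# Default behavior: unset scalars appear with zero/empty values.
full = type(tuned).to_dict(tuned)

# With the flag off, only explicitly set fields survive.
sparse = type(tuned).to_dict(tuned, including_default_value_fields=False)
print(sparse)  # roughly {'display_name': 'my-model'}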

setup.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def get_version():
 release_status = "Development Status :: 5 - Production/Stable"

 dependencies = [
-    "google-ai-generativelanguage==0.6.9",
+    "google-ai-generativelanguage@https://storage.googleapis.com/generativeai-downloads/preview/ai-generativelanguage-v1beta-py.tar.gz",
     "google-api-core",
     "google-api-python-client",
     "google-auth>=2.15.0",  # 2.15 adds API key auth support

tests/test_files.py

Lines changed: 7 additions & 5 deletions
@@ -12,13 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations

 from google.generativeai.types import file_types

 import collections
 import datetime
 import os
-from typing import Iterable, Union
+from typing import Iterable, Sequence
 import pathlib

 import google
@@ -37,12 +38,13 @@ def __init__(self, test):

     def create_file(
         self,
-        path: Union[str, pathlib.Path, os.PathLike],
+        path: str | pathlib.Path | os.PathLike,
         *,
-        mime_type: Union[str, None] = None,
-        name: Union[str, None] = None,
-        display_name: Union[str, None] = None,
+        mime_type: str | None = None,
+        name: str | None = None,
+        display_name: str | None = None,
         resumable: bool = True,
+        metadata: Sequence[tuple[str, str]] = (),
     ) -> protos.File:
         self.observed_requests.append(
             dict(
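The switch from `typing.Union` to PEP 604 `X | Y` annotations only works on Python 3.9 because of the added `from __future__ import annotations`, which keeps annotations as unevaluated strings. A minimal illustration (the function body is a placeholder):

from __future__ import annotations  # must precede all other imports

import pathlib

# Without the future import, `str | None` in an annotation raises TypeError at
# definition time on 3.9, since builtin types gain `|` support only in 3.10.
def create_file(path: str | pathlib.Path, mime_type: str | None = None) -> None:
    ...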

tests/test_generation.py

Lines changed: 37 additions & 16 deletions
@@ -1,4 +1,20 @@
+# -*- coding: utf-8 -*-
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import inspect
+import json
 import string
 import textwrap
 from typing_extensions import TypedDict
@@ -22,6 +38,8 @@ class Person(TypedDict):


 class UnitTests(parameterized.TestCase):
+    maxDiff = None
+
     @parameterized.named_parameters(
         [
             "protos.GenerationConfig",
@@ -416,24 +434,16 @@ def test_join_prompt_feedbacks(self):
                 ],
                 "role": "assistant",
             },
-            "citation_metadata": {"citation_sources": []},
             "index": 0,
-            "finish_reason": 0,
-            "safety_ratings": [],
-            "token_count": 0,
-            "grounding_attributions": [],
+            "citation_metadata": {},
         },
         {
             "content": {
                 "parts": [{"text": "Tell me a story about a magic backpack"}],
                 "role": "assistant",
             },
             "index": 1,
-            "citation_metadata": {"citation_sources": []},
-            "finish_reason": 0,
-            "safety_ratings": [],
-            "token_count": 0,
-            "grounding_attributions": [],
+            "citation_metadata": {},
         },
         {
             "content": {
@@ -458,17 +468,16 @@ def test_join_prompt_feedbacks(self):
                 },
             ]
         },
-            "finish_reason": 0,
-            "safety_ratings": [],
-            "token_count": 0,
-            "grounding_attributions": [],
         },
     ]

     def test_join_candidates(self):
         candidate_lists = [[protos.Candidate(c) for c in cl] for cl in self.CANDIDATE_LISTS]
         result = generation_types._join_candidate_lists(candidate_lists)
-        self.assertEqual(self.MERGED_CANDIDATES, [type(r).to_dict(r) for r in result])
+        self.assertEqual(
+            self.MERGED_CANDIDATES,
+            [type(r).to_dict(r, including_default_value_fields=False) for r in result],
+        )

     def test_join_chunks(self):
         chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS]
@@ -480,6 +489,10 @@ def test_join_chunks(self):
             ],
         )

+        chunks[-1].usage_metadata = protos.GenerateContentResponse.UsageMetadata(
+            prompt_token_count=5
+        )
+
         result = generation_types._join_chunks(chunks)

         expected = protos.GenerateContentResponse(
@@ -495,10 +508,18 @@ def test_join_chunks(self):
                         }
                     ],
                 },
+                "usage_metadata": {"prompt_token_count": 5},
             },
         )

-        self.assertEqual(type(expected).to_dict(expected), type(result).to_dict(expected))
+        expected = json.dumps(
+            type(expected).to_dict(expected, including_default_value_fields=False), indent=4
+        )
+        result = json.dumps(
+            type(result).to_dict(result, including_default_value_fields=False), indent=4
+        )
+
+        self.assertEqual(expected, result)

     def test_generate_content_response_iterator_end_to_end(self):
         chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS]
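Serializing both sides with `json.dumps(..., indent=4)` before comparing, combined with `maxDiff = None`, turns an assertion failure into a readable line-by-line diff instead of one truncated dict repr. The idiom in isolation (plain dicts stand in for the proto `to_dict` output):

import json
import unittest

class ExampleTest(unittest.TestCase):
    maxDiff = None  # never truncate assertion diffs

    def test_dicts_match(self):
        expected = {"candidates": [{"index": 0}]}
        result = {"candidates": [{"index": 0}]}
        # Indented JSON gives unittest many short lines to diff on failure.
        self.assertEqual(json.dumps(expected, indent=4), json.dumps(result, indent=4))

if __name__ == "__main__":
    unittest.main()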

tests/test_generative_models.py

Lines changed: 7 additions & 46 deletions
@@ -935,8 +935,7 @@ def test_repr_for_streaming_start_to_finish(self):
                       "citation_metadata": {}
                     }
                   ],
-                  "prompt_feedback": {},
-                  "usage_metadata": {}
+                  "prompt_feedback": {}
                 }),
             )"""
         )
@@ -964,8 +963,7 @@ def test_repr_for_streaming_start_to_finish(self):
                       "citation_metadata": {}
                     }
                   ],
-                  "prompt_feedback": {},
-                  "usage_metadata": {}
+                  "prompt_feedback": {}
                 }),
             )"""
         )
@@ -998,10 +996,10 @@ def test_repr_error_info_for_stream_prompt_feedback_blocked(self):
                   }
                 }),
             ),
-            error=<BlockedPromptException> prompt_feedback {
+            error=BlockedPromptException(prompt_feedback {
              block_reason: SAFETY
            }
-            """
+            )"""
        )
        self.assertEqual(expected, result)

@@ -1056,11 +1054,10 @@ def no_throw():
                       "citation_metadata": {}
                     }
                   ],
-                  "prompt_feedback": {},
-                  "usage_metadata": {}
+                  "prompt_feedback": {}
                 }),
             ),
-            error=<ValueError> """
+            error=ValueError()"""
         )
         self.assertEqual(expected, result)

@@ -1095,43 +1092,7 @@ def test_repr_error_info_for_chat_streaming_unexpected_stop(self):
         response = chat.send_message("hello2", stream=True)

         result = repr(response)
-        expected = textwrap.dedent(
-            """\
-            response:
-            GenerateContentResponse(
-                done=True,
-                iterator=None,
-                result=protos.GenerateContentResponse({
-                  "candidates": [
-                    {
-                      "content": {
-                        "parts": [
-                          {
-                            "text": "abc"
-                          }
-                        ]
-                      },
-                      "finish_reason": "SAFETY",
-                      "index": 0,
-                      "citation_metadata": {}
-                    }
-                  ],
-                  "prompt_feedback": {},
-                  "usage_metadata": {}
-                }),
-            ),
-            error=<StopCandidateException> content {
-              parts {
-                text: "abc"
-              }
-            }
-            finish_reason: SAFETY
-            index: 0
-            citation_metadata {
-            }
-            """
-        )
-        self.assertEqual(expected, result)
+        self.assertIn("StopCandidateException", result)

     def test_repr_for_multi_turn_chat(self):
         # Multi turn chat
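The repr expectations above shrink because `__str__` now formats errors with `repr(self._error)`: str() of an exception omits the type and is empty for a bare exception, while repr() is compact and self-describing, which is what makes the remaining assertions stable. For example:

err = ValueError()
print(f"error=<{err.__class__.__name__}> {err}")  # old format: 'error=<ValueError> '
print(f"error={repr(err)}")                       # new format: 'error=ValueError()'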
