 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import os
+import copy
 import unittest
 import unittest.mock as mock
 
 import google.ai.generativelanguage as glm
 
 from google.generativeai import text as text_service
 from google.generativeai import client
 from google.generativeai.types import safety_types
+from google.generativeai.types import model_types
 from absl.testing import absltest
 from absl.testing import parameterized
@@ -31,8 +31,9 @@ def setUp(self):
         self.client = unittest.mock.MagicMock()
 
         client._client_manager.text_client = self.client
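+        # Route the model service through the same mock so tuned-model lookups are captured.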
+        client._client_manager.model_client = self.client
 
-        self.observed_request = None
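+        # Record every request, not just the last one, so multi-RPC flows can be asserted in order.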
+        self.observed_requests = []
 
         self.responses = {}
 
@@ -45,23 +46,37 @@ def add_client_method(f):
         def generate_text(
             request: glm.GenerateTextRequest,
         ) -> glm.GenerateTextResponse:
-            self.observed_request = request
+            self.observed_requests.append(request)
             return self.responses["generate_text"]
 
         @add_client_method
         def embed_text(
             request: glm.EmbedTextRequest,
         ) -> glm.EmbedTextResponse:
-            self.observed_request = request
+            self.observed_requests.append(request)
             return self.responses["embed_text"]
 
         @add_client_method
         def batch_embed_text(
             request: glm.EmbedTextRequest,
         ) -> glm.EmbedTextResponse:
-            self.observed_request = request
+            self.observed_requests.append(request)
             return self.responses["batch_embed_text"]
 
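+        # New stub for the token-counting RPC; records the request and returns a canned count.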
+        @add_client_method
+        def count_text_tokens(
+            request: glm.CountTextTokensRequest,
+        ) -> glm.CountTextTokensResponse:
+            self.observed_requests.append(request)
+            return self.responses["count_text_tokens"]
+
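+        # Unlike the other stubs, get_tuned_model receives a bare name, so the
+        # request proto is built here for later assertions; copy.copy keeps tests
+        # from mutating the shared canned response.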
+        @add_client_method
+        def get_tuned_model(name) -> glm.TunedModel:
+            request = glm.GetTunedModelRequest(name=name)
+            self.observed_requests.append(request)
+            response = copy.copy(self.responses["get_tuned_model"])
+            return response
+
     @parameterized.named_parameters(
         [
             dict(testcase_name="string", prompt="Hello how are"),
@@ -99,7 +114,7 @@ def test_generate_embeddings(self, model, text):
         emb = text_service.generate_embeddings(model=model, text=text)
 
         self.assertIsInstance(emb, dict)
-        self.assertEqual(self.observed_request, glm.EmbedTextRequest(model=model, text=text))
+        self.assertEqual(self.observed_requests[-1], glm.EmbedTextRequest(model=model, text=text))
         self.assertIsInstance(emb["embedding"][0], float)
 
     @parameterized.named_parameters(
@@ -123,8 +138,7 @@ def test_generate_embeddings_batch(self, model, text):
 
         self.assertIsInstance(emb, dict)
         self.assertEqual(
-            self.observed_request,
-            glm.BatchEmbedTextRequest(model=model, texts=text),
+            self.observed_requests[-1], glm.BatchEmbedTextRequest(model=model, texts=text)
         )
         self.assertIsInstance(emb["embedding"][0], list)
 
@@ -160,7 +174,7 @@ def test_generate_response(self, *, prompt, **kwargs):
         complete = text_service.generate_text(prompt=prompt, **kwargs)
 
         self.assertEqual(
-            self.observed_request,
+            self.observed_requests[-1],
             glm.GenerateTextRequest(
                 model="models/text-bison-001", prompt=glm.TextPrompt(text=prompt), **kwargs
             ),
@@ -188,15 +202,15 @@ def test_stop_string(self):
         complete = text_service.generate_text(prompt="Hello", stop_sequences="stop")
 
         self.assertEqual(
-            self.observed_request,
+            self.observed_requests[-1],
             glm.GenerateTextRequest(
                 model="models/text-bison-001",
                 prompt=glm.TextPrompt(text="Hello"),
                 stop_sequences=["stop"],
             ),
         )
         # Just make sure it made it into the request object.
-        self.assertEqual(self.observed_request.stop_sequences, ["stop"])
+        self.assertEqual(self.observed_requests[-1].stop_sequences, ["stop"])
 
     @parameterized.named_parameters(
         [
@@ -251,7 +265,7 @@ def test_safety_settings(self, safety_settings):
         )
 
         self.assertEqual(
-            self.observed_request.safety_settings[0].category,
+            self.observed_requests[-1].safety_settings[0].category,
             safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
         )
 
@@ -367,6 +381,72 @@ def test_candidate_citations(self):
             6,
         )
 
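+    # count_text_tokens should accept a base-model name, a tuned-model name, or a
+    # Model/TunedModel object, whether a library type or a raw glm proto.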
+    @parameterized.named_parameters(
+        [
+            dict(testcase_name="base-name", model="models/text-bison-001"),
+            dict(testcase_name="tuned-name", model="tunedModels/bipedal-pangolin-001"),
+            dict(
+                testcase_name="model",
+                model=model_types.Model(
+                    name="models/text-bison-001",
+                    base_model_id="text-bison-001",
+                    version="001",
+                    display_name="🦬",
+                    description="🦬🦬🦬🦬🦬🦬🦬🦬🦬🦬🦬",
+                    input_token_limit=8000,
+                    output_token_limit=4000,
+                    supported_generation_methods=["GenerateText"],
+                ),
+            ),
+            dict(
+                testcase_name="tuned_model",
+                model=model_types.TunedModel(
+                    name="tunedModels/bipedal-pangolin-001",
+                    base_model="models/text-bison-001",
+                ),
+            ),
+            dict(
+                testcase_name="glm_model",
+                model=glm.Model(
+                    name="models/text-bison-001",
+                ),
+            ),
+            dict(
+                testcase_name="glm_tuned_model",
+                model=glm.TunedModel(
+                    name="tunedModels/bipedal-pangolin-001",
+                    base_model="models/text-bison-001",
+                ),
+            ),
+            dict(
+                testcase_name="glm_tuned_model_nested",
+                model=glm.TunedModel(
+                    name="tunedModels/bipedal-pangolin-002",
+                    tuned_model_source={
+                        "tuned_model": "tunedModels/bipedal-pangolin-002",
+                        "base_model": "models/text-bison-001",
+                    },
+                ),
+            ),
+        ]
+    )
+    def test_count_message_tokens(self, model):
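+        # Stub both RPCs: tuned-model names trigger a get_tuned_model lookup
+        # before the count_text_tokens call.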
+        self.responses["get_tuned_model"] = glm.TunedModel(
+            name="tunedModels/bipedal-pangolin-001", base_model="models/text-bison-001"
+        )
+        self.responses["count_text_tokens"] = glm.CountTextTokensResponse(token_count=7)
+
+        response = text_service.count_text_tokens(model, "Tell me a story about a magic backpack.")
+        self.assertEqual({"token_count": 7}, response)
+
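+        # Only a tuned-model *name* requires resolving the base model first, so
+        # only that case should produce two observed requests.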
+        should_look_up_model = isinstance(model, str) and model.startswith("tunedModels/")
+        if should_look_up_model:
+            self.assertLen(self.observed_requests, 2)
+            self.assertEqual(
+                self.observed_requests[0],
+                glm.GetTunedModelRequest(name="tunedModels/bipedal-pangolin-001"),
+            )
+
 
 if __name__ == "__main__":
     absltest.main()