@@ -18,17 +18,104 @@ import { ModelParams } from "../types";
18
18
import { GenerativeModel , GoogleGenerativeAI } from "./gen-ai" ;
19
19
import { expect } from "chai" ;
20
20
21
+ const fakeContents = [ { role : "user" , parts : [ { text : "hello" } ] } ] ;
22
+
23
+ const fakeCachedContent = {
24
+ model : "my-model" ,
25
+ name : "mycachename" ,
26
+ contents : fakeContents ,
27
+ } ;
28
+
21
29
describe ( "GoogleGenerativeAI" , ( ) => {
22
- it ( "genGenerativeInstance throws if no model is provided" , ( ) => {
30
+ it ( "getGenerativeModel throws if no model is provided" , ( ) => {
23
31
const genAI = new GoogleGenerativeAI ( "apikey" ) ;
24
32
expect ( ( ) => genAI . getGenerativeModel ( { } as ModelParams ) ) . to . throw (
25
33
"Must provide a model name" ,
26
34
) ;
27
35
} ) ;
28
- it ( "genGenerativeInstance gets a GenerativeModel" , ( ) => {
36
+ it ( "getGenerativeModel gets a GenerativeModel" , ( ) => {
29
37
const genAI = new GoogleGenerativeAI ( "apikey" ) ;
30
38
const genModel = genAI . getGenerativeModel ( { model : "my-model" } ) ;
31
39
expect ( genModel ) . to . be . an . instanceOf ( GenerativeModel ) ;
32
40
expect ( genModel . model ) . to . equal ( "models/my-model" ) ;
33
41
} ) ;
42
+ it ( "getGenerativeModelFromCachedContent gets a GenerativeModel" , ( ) => {
43
+ const genAI = new GoogleGenerativeAI ( "apikey" ) ;
44
+ const genModel =
45
+ genAI . getGenerativeModelFromCachedContent ( fakeCachedContent ) ;
46
+ expect ( genModel ) . to . be . an . instanceOf ( GenerativeModel ) ;
47
+ expect ( genModel . model ) . to . equal ( "models/my-model" ) ;
48
+ expect ( genModel . cachedContent ) . to . eql ( fakeCachedContent ) ;
49
+ } ) ;
50
+ it ( "getGenerativeModelFromCachedContent gets a GenerativeModel merged with modelParams" , ( ) => {
51
+ const genAI = new GoogleGenerativeAI ( "apikey" ) ;
52
+ const genModel = genAI . getGenerativeModelFromCachedContent (
53
+ fakeCachedContent ,
54
+ { generationConfig : { temperature : 0 } } ,
55
+ ) ;
56
+ expect ( genModel ) . to . be . an . instanceOf ( GenerativeModel ) ;
57
+ expect ( genModel . model ) . to . equal ( "models/my-model" ) ;
58
+ expect ( genModel . generationConfig . temperature ) . to . equal ( 0 ) ;
59
+ expect ( genModel . cachedContent ) . to . eql ( fakeCachedContent ) ;
60
+ } ) ;
61
+ it ( "getGenerativeModelFromCachedContent gets a GenerativeModel merged with modelParams with overlapping keys" , ( ) => {
62
+ const genAI = new GoogleGenerativeAI ( "apikey" ) ;
63
+ const genModel = genAI . getGenerativeModelFromCachedContent (
64
+ fakeCachedContent ,
65
+ { model : "my-model" , generationConfig : { temperature : 0 } } ,
66
+ ) ;
67
+ expect ( genModel ) . to . be . an . instanceOf ( GenerativeModel ) ;
68
+ expect ( genModel . model ) . to . equal ( "models/my-model" ) ;
69
+ expect ( genModel . generationConfig . temperature ) . to . equal ( 0 ) ;
70
+ expect ( genModel . cachedContent ) . to . eql ( fakeCachedContent ) ;
71
+ } ) ;
72
+ it ( "getGenerativeModelFromCachedContent throws if no name" , ( ) => {
73
+ const genAI = new GoogleGenerativeAI ( "apikey" ) ;
74
+ expect ( ( ) =>
75
+ genAI . getGenerativeModelFromCachedContent ( {
76
+ model : "my-model" ,
77
+ contents : fakeContents ,
78
+ } ) ,
79
+ ) . to . throw ( "Cached content must contain a `name` field." ) ;
80
+ } ) ;
81
+ it ( "getGenerativeModelFromCachedContent throws if no model" , ( ) => {
82
+ const genAI = new GoogleGenerativeAI ( "apikey" ) ;
83
+ expect ( ( ) =>
84
+ genAI . getGenerativeModelFromCachedContent ( {
85
+ name : "cachename" ,
86
+ contents : fakeContents ,
87
+ } ) ,
88
+ ) . to . throw ( "Cached content must contain a `model` field." ) ;
89
+ } ) ;
90
+ it ( "getGenerativeModelFromCachedContent throws if mismatched model" , ( ) => {
91
+ const genAI = new GoogleGenerativeAI ( "apikey" ) ;
92
+ expect ( ( ) =>
93
+ genAI . getGenerativeModelFromCachedContent (
94
+ {
95
+ name : "cachename" ,
96
+ model : "my-model" ,
97
+ contents : fakeContents ,
98
+ } ,
99
+ { model : "your-model" } ,
100
+ ) ,
101
+ ) . to . throw (
102
+ `Different value for "model" specified in modelParams (your-model) and cachedContent (my-model)` ,
103
+ ) ;
104
+ } ) ;
105
+ it ( "getGenerativeModelFromCachedContent throws if mismatched systemInstruction" , ( ) => {
106
+ const genAI = new GoogleGenerativeAI ( "apikey" ) ;
107
+ expect ( ( ) =>
108
+ genAI . getGenerativeModelFromCachedContent (
109
+ {
110
+ name : "cachename" ,
111
+ model : "my-model" ,
112
+ contents : fakeContents ,
113
+ systemInstruction : "hi" ,
114
+ } ,
115
+ { model : "models/my-model" , systemInstruction : "yo" } ,
116
+ ) ,
117
+ ) . to . throw (
118
+ `Different value for "systemInstruction" specified in modelParams (yo) and cachedContent (hi)` ,
119
+ ) ;
120
+ } ) ;
34
121
} ) ;
0 commit comments