Skip to content

Commit fc008a1

Browse files
authored
Add ability to set modelParams on getGenerativeModelFromCachedContent() (#254)
1 parent ce49f34 commit fc008a1

File tree

6 files changed

+132
-5
lines changed

6 files changed

+132
-5
lines changed

.changeset/tame-lizards-kiss.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@google/generative-ai": minor
3+
---
4+
5+
Add ability to set modelParams (generationConfig, safetySettings) on getGenerativeModelFromCachedContent().

common/api-review/generative-ai.api.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ export class GoogleGenerativeAI {
481481
// (undocumented)
482482
apiKey: string;
483483
getGenerativeModel(modelParams: ModelParams, requestOptions?: RequestOptions): GenerativeModel;
484-
getGenerativeModelFromCachedContent(cachedContent: CachedContent, requestOptions?: RequestOptions): GenerativeModel;
484+
getGenerativeModelFromCachedContent(cachedContent: CachedContent, modelParams?: Partial<ModelParams>, requestOptions?: RequestOptions): GenerativeModel;
485485
}
486486

487487
// @public

docs/reference/main/generative-ai.googlegenerativeai.getgenerativemodelfromcachedcontent.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@ Creates a [GenerativeModel](./generative-ai.generativemodel.md) instance from pr
99
**Signature:**
1010

1111
```typescript
12-
getGenerativeModelFromCachedContent(cachedContent: CachedContent, requestOptions?: RequestOptions): GenerativeModel;
12+
getGenerativeModelFromCachedContent(cachedContent: CachedContent, modelParams?: Partial<ModelParams>, requestOptions?: RequestOptions): GenerativeModel;
1313
```
1414

1515
## Parameters
1616

1717
| Parameter | Type | Description |
1818
| --- | --- | --- |
1919
| cachedContent | [CachedContent](./generative-ai.cachedcontent.md) | |
20+
| modelParams | Partial&lt;[ModelParams](./generative-ai.modelparams.md)<!-- -->&gt; | _(Optional)_ |
2021
| requestOptions | [RequestOptions](./generative-ai.requestoptions.md) | _(Optional)_ |
2122

2223
**Returns:**

docs/reference/main/generative-ai.googlegenerativeai.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,5 @@ export declare class GoogleGenerativeAI
2929
| Method | Modifiers | Description |
3030
| --- | --- | --- |
3131
| [getGenerativeModel(modelParams, requestOptions)](./generative-ai.googlegenerativeai.getgenerativemodel.md) | | Gets a [GenerativeModel](./generative-ai.generativemodel.md) instance for the provided model name. |
32-
| [getGenerativeModelFromCachedContent(cachedContent, requestOptions)](./generative-ai.googlegenerativeai.getgenerativemodelfromcachedcontent.md) | | Creates a [GenerativeModel](./generative-ai.generativemodel.md) instance from provided content cache. |
32+
| [getGenerativeModelFromCachedContent(cachedContent, modelParams, requestOptions)](./generative-ai.googlegenerativeai.getgenerativemodelfromcachedcontent.md) | | Creates a [GenerativeModel](./generative-ai.generativemodel.md) instance from provided content cache. |
3333

src/gen-ai.test.ts

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,104 @@ import { ModelParams } from "../types";
1818
import { GenerativeModel, GoogleGenerativeAI } from "./gen-ai";
1919
import { expect } from "chai";
2020

21+
const fakeContents = [{ role: "user", parts: [{ text: "hello" }] }];
22+
23+
const fakeCachedContent = {
24+
model: "my-model",
25+
name: "mycachename",
26+
contents: fakeContents,
27+
};
28+
2129
describe("GoogleGenerativeAI", () => {
22-
it("genGenerativeInstance throws if no model is provided", () => {
30+
it("getGenerativeModel throws if no model is provided", () => {
2331
const genAI = new GoogleGenerativeAI("apikey");
2432
expect(() => genAI.getGenerativeModel({} as ModelParams)).to.throw(
2533
"Must provide a model name",
2634
);
2735
});
28-
it("genGenerativeInstance gets a GenerativeModel", () => {
36+
it("getGenerativeModel gets a GenerativeModel", () => {
2937
const genAI = new GoogleGenerativeAI("apikey");
3038
const genModel = genAI.getGenerativeModel({ model: "my-model" });
3139
expect(genModel).to.be.an.instanceOf(GenerativeModel);
3240
expect(genModel.model).to.equal("models/my-model");
3341
});
42+
it("getGenerativeModelFromCachedContent gets a GenerativeModel", () => {
43+
const genAI = new GoogleGenerativeAI("apikey");
44+
const genModel =
45+
genAI.getGenerativeModelFromCachedContent(fakeCachedContent);
46+
expect(genModel).to.be.an.instanceOf(GenerativeModel);
47+
expect(genModel.model).to.equal("models/my-model");
48+
expect(genModel.cachedContent).to.eql(fakeCachedContent);
49+
});
50+
it("getGenerativeModelFromCachedContent gets a GenerativeModel merged with modelParams", () => {
51+
const genAI = new GoogleGenerativeAI("apikey");
52+
const genModel = genAI.getGenerativeModelFromCachedContent(
53+
fakeCachedContent,
54+
{ generationConfig: { temperature: 0 } },
55+
);
56+
expect(genModel).to.be.an.instanceOf(GenerativeModel);
57+
expect(genModel.model).to.equal("models/my-model");
58+
expect(genModel.generationConfig.temperature).to.equal(0);
59+
expect(genModel.cachedContent).to.eql(fakeCachedContent);
60+
});
61+
it("getGenerativeModelFromCachedContent gets a GenerativeModel merged with modelParams with overlapping keys", () => {
62+
const genAI = new GoogleGenerativeAI("apikey");
63+
const genModel = genAI.getGenerativeModelFromCachedContent(
64+
fakeCachedContent,
65+
{ model: "my-model", generationConfig: { temperature: 0 } },
66+
);
67+
expect(genModel).to.be.an.instanceOf(GenerativeModel);
68+
expect(genModel.model).to.equal("models/my-model");
69+
expect(genModel.generationConfig.temperature).to.equal(0);
70+
expect(genModel.cachedContent).to.eql(fakeCachedContent);
71+
});
72+
it("getGenerativeModelFromCachedContent throws if no name", () => {
73+
const genAI = new GoogleGenerativeAI("apikey");
74+
expect(() =>
75+
genAI.getGenerativeModelFromCachedContent({
76+
model: "my-model",
77+
contents: fakeContents,
78+
}),
79+
).to.throw("Cached content must contain a `name` field.");
80+
});
81+
it("getGenerativeModelFromCachedContent throws if no model", () => {
82+
const genAI = new GoogleGenerativeAI("apikey");
83+
expect(() =>
84+
genAI.getGenerativeModelFromCachedContent({
85+
name: "cachename",
86+
contents: fakeContents,
87+
}),
88+
).to.throw("Cached content must contain a `model` field.");
89+
});
90+
it("getGenerativeModelFromCachedContent throws if mismatched model", () => {
91+
const genAI = new GoogleGenerativeAI("apikey");
92+
expect(() =>
93+
genAI.getGenerativeModelFromCachedContent(
94+
{
95+
name: "cachename",
96+
model: "my-model",
97+
contents: fakeContents,
98+
},
99+
{ model: "your-model" },
100+
),
101+
).to.throw(
102+
`Different value for "model" specified in modelParams (your-model) and cachedContent (my-model)`,
103+
);
104+
});
105+
it("getGenerativeModelFromCachedContent throws if mismatched systemInstruction", () => {
106+
const genAI = new GoogleGenerativeAI("apikey");
107+
expect(() =>
108+
genAI.getGenerativeModelFromCachedContent(
109+
{
110+
name: "cachename",
111+
model: "my-model",
112+
contents: fakeContents,
113+
systemInstruction: "hi",
114+
},
115+
{ model: "models/my-model", systemInstruction: "yo" },
116+
),
117+
).to.throw(
118+
`Different value for "systemInstruction" specified in modelParams (yo) and cachedContent (hi)`,
119+
);
120+
});
34121
});

src/gen-ai.ts

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ export class GoogleGenerativeAI {
5353
*/
5454
getGenerativeModelFromCachedContent(
5555
cachedContent: CachedContent,
56+
modelParams?: Partial<ModelParams>,
5657
requestOptions?: RequestOptions,
5758
): GenerativeModel {
5859
if (!cachedContent.name) {
@@ -65,7 +66,40 @@ export class GoogleGenerativeAI {
6566
"Cached content must contain a `model` field.",
6667
);
6768
}
69+
70+
/**
71+
* Not checking tools and toolConfig for now as it would require a deep
72+
* equality comparison and isn't likely to be a common case.
73+
*/
74+
const disallowedDuplicates: Array<keyof ModelParams & keyof CachedContent> =
75+
["model", "systemInstruction"];
76+
77+
for (const key of disallowedDuplicates) {
78+
if (
79+
modelParams?.[key] &&
80+
cachedContent[key] &&
81+
modelParams?.[key] !== cachedContent[key]
82+
) {
83+
if (key === "model") {
84+
const modelParamsComp = modelParams.model.startsWith("models/")
85+
? modelParams.model.replace("models/", "")
86+
: modelParams.model;
87+
const cachedContentComp = cachedContent.model.startsWith("models/")
88+
? cachedContent.model.replace("models/", "")
89+
: cachedContent.model;
90+
if (modelParamsComp === cachedContentComp) {
91+
continue;
92+
}
93+
}
94+
throw new GoogleGenerativeAIRequestInputError(
95+
`Different value for "${key}" specified in modelParams` +
96+
` (${modelParams[key]}) and cachedContent (${cachedContent[key]})`,
97+
);
98+
}
99+
}
100+
68101
const modelParamsFromCache: ModelParams = {
102+
...modelParams,
69103
model: cachedContent.model,
70104
tools: cachedContent.tools,
71105
toolConfig: cachedContent.toolConfig,

0 commit comments

Comments
 (0)