Skip to content

Commit dbd6ff8

Browse files
valtzutryvin
andauthored
feat: add Google Gemini tool support (#331)
Add tool support to Google Gemini. Extracted from #320 with updates after #326 Co-authored-by: Vin Souza <[email protected]>
1 parent 5ee30df commit dbd6ff8

File tree

10 files changed

+496
-10
lines changed

10 files changed

+496
-10
lines changed

examples/google/toolcall.php

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
<?php
2+
3+
use PhpLlm\LlmChain\Chain\Chain;
4+
use PhpLlm\LlmChain\Chain\Toolbox\ChainProcessor;
5+
use PhpLlm\LlmChain\Chain\Toolbox\Tool\Clock;
6+
use PhpLlm\LlmChain\Chain\Toolbox\Toolbox;
7+
use PhpLlm\LlmChain\Platform\Bridge\Google\Gemini;
8+
use PhpLlm\LlmChain\Platform\Bridge\Google\PlatformFactory;
9+
use PhpLlm\LlmChain\Platform\Message\Message;
10+
use PhpLlm\LlmChain\Platform\Message\MessageBag;
11+
use Symfony\Component\Dotenv\Dotenv;
12+
13+
require_once dirname(__DIR__, 2).'/vendor/autoload.php';
14+
(new Dotenv())->loadEnv(dirname(__DIR__, 2).'/.env');
15+
16+
if (empty($_ENV['GOOGLE_API_KEY'])) {
17+
echo 'Please set the GOOGLE_API_KEY environment variable.'.\PHP_EOL;
18+
exit(1);
19+
}
20+
21+
$platform = PlatformFactory::create($_ENV['GOOGLE_API_KEY']);
22+
$llm = new Gemini(Gemini::GEMINI_2_FLASH);
23+
24+
$toolbox = Toolbox::create(new Clock());
25+
$processor = new ChainProcessor($toolbox);
26+
$chain = new Chain($platform, $llm, [$processor], [$processor]);
27+
28+
$messages = new MessageBag(Message::ofUser('What time is it?'));
29+
$response = $chain->call($messages);
30+
31+
echo $response->getContent().\PHP_EOL;

src/Platform/Bridge/Google/Contract/AssistantMessageNormalizer.php

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,23 @@ protected function supportsModel(Model $model): bool
3535
*/
3636
public function normalize(mixed $data, ?string $format = null, array $context = []): array
3737
{
38-
return [
39-
['text' => $data->content],
40-
];
38+
$normalized = [];
39+
40+
if (isset($data->content)) {
41+
$normalized['text'] = $data->content;
42+
}
43+
44+
if (isset($data->toolCalls[0])) {
45+
$normalized['functionCall'] = [
46+
'id' => $data->toolCalls[0]->id,
47+
'name' => $data->toolCalls[0]->name,
48+
];
49+
50+
if ($data->toolCalls[0]->arguments) {
51+
$normalized['functionCall']['args'] = $data->toolCalls[0]->arguments;
52+
}
53+
}
54+
55+
return [$normalized];
4156
}
4257
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace PhpLlm\LlmChain\Platform\Bridge\Google\Contract;
6+
7+
use PhpLlm\LlmChain\Platform\Bridge\Google\Gemini;
8+
use PhpLlm\LlmChain\Platform\Contract\Normalizer\ModelContractNormalizer;
9+
use PhpLlm\LlmChain\Platform\Message\ToolCallMessage;
10+
use PhpLlm\LlmChain\Platform\Model;
11+
use Symfony\Component\Serializer\Normalizer\NormalizerAwareInterface;
12+
use Symfony\Component\Serializer\Normalizer\NormalizerAwareTrait;
13+
14+
/**
15+
* @author Valtteri R <[email protected]>
16+
*/
17+
final class ToolCallMessageNormalizer extends ModelContractNormalizer implements NormalizerAwareInterface
18+
{
19+
use NormalizerAwareTrait;
20+
21+
protected function supportedDataClass(): string
22+
{
23+
return ToolCallMessage::class;
24+
}
25+
26+
protected function supportsModel(Model $model): bool
27+
{
28+
return $model instanceof Gemini;
29+
}
30+
31+
/**
32+
* @param ToolCallMessage $data
33+
*
34+
* @return array{
35+
* functionResponse: array{
36+
* id: string,
37+
* name: string,
38+
* response: array<int|string, mixed>
39+
* }
40+
* }[]
41+
*/
42+
public function normalize(mixed $data, ?string $format = null, array $context = []): array
43+
{
44+
$responseContent = json_validate($data->content) ? json_decode($data->content, true) : $data->content;
45+
46+
return [[
47+
'functionResponse' => array_filter([
48+
'id' => $data->toolCall->id,
49+
'name' => $data->toolCall->name,
50+
'response' => \is_array($responseContent) ? $responseContent : [
51+
'rawResponse' => $responseContent, // Gemini expects the response to be an object, but not everyone uses objects as their responses.
52+
],
53+
]),
54+
]];
55+
}
56+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
<?php
2+
3+
namespace PhpLlm\LlmChain\Platform\Bridge\Google\Contract;
4+
5+
use PhpLlm\LlmChain\Platform\Bridge\Google\Gemini;
6+
use PhpLlm\LlmChain\Platform\Contract\JsonSchema\Factory;
7+
use PhpLlm\LlmChain\Platform\Contract\Normalizer\ModelContractNormalizer;
8+
use PhpLlm\LlmChain\Platform\Model;
9+
use PhpLlm\LlmChain\Platform\Tool\Tool;
10+
11+
/**
12+
* @author Valtteri R <[email protected]>
13+
*
14+
* @phpstan-import-type JsonSchema from Factory
15+
*/
16+
final class ToolNormalizer extends ModelContractNormalizer
17+
{
18+
protected function supportedDataClass(): string
19+
{
20+
return Tool::class;
21+
}
22+
23+
protected function supportsModel(Model $model): bool
24+
{
25+
return $model instanceof Gemini;
26+
}
27+
28+
/**
29+
* @param Tool $data
30+
*
31+
* @return array{
32+
* functionDeclarations: array{
33+
* name: string,
34+
* description: string,
35+
* parameters: JsonSchema|array{type: 'object'}
36+
* }[]
37+
* }
38+
*/
39+
public function normalize(mixed $data, ?string $format = null, array $context = []): array
40+
{
41+
$parameters = $data->parameters;
42+
unset($parameters['additionalProperties']);
43+
44+
return [
45+
'functionDeclarations' => [
46+
[
47+
'description' => $data->description,
48+
'name' => $data->name,
49+
'parameters' => $parameters,
50+
],
51+
],
52+
];
53+
}
54+
}

src/Platform/Bridge/Google/Gemini.php

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ public function __construct(string $name = self::GEMINI_2_PRO, array $options =
2727
Capability::INPUT_MESSAGES,
2828
Capability::INPUT_IMAGE,
2929
Capability::OUTPUT_STREAMING,
30+
Capability::TOOL_CALLING,
3031
];
3132

3233
parent::__construct($name, $capabilities, $options);

src/Platform/Bridge/Google/ModelHandler.php

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@
77
use PhpLlm\LlmChain\Platform\Exception\RuntimeException;
88
use PhpLlm\LlmChain\Platform\Model;
99
use PhpLlm\LlmChain\Platform\ModelClientInterface;
10+
use PhpLlm\LlmChain\Platform\Response\Choice;
11+
use PhpLlm\LlmChain\Platform\Response\ChoiceResponse;
1012
use PhpLlm\LlmChain\Platform\Response\ResponseInterface as LlmResponse;
1113
use PhpLlm\LlmChain\Platform\Response\StreamResponse;
1214
use PhpLlm\LlmChain\Platform\Response\TextResponse;
15+
use PhpLlm\LlmChain\Platform\Response\ToolCall;
16+
use PhpLlm\LlmChain\Platform\Response\ToolCallResponse;
1317
use PhpLlm\LlmChain\Platform\ResponseConverterInterface;
1418
use Symfony\Component\HttpClient\EventSourceHttpClient;
1519
use Symfony\Contracts\HttpClient\Exception\ClientExceptionInterface;
@@ -52,6 +56,12 @@ public function request(Model $model, array|string $payload, array $options = []
5256

5357
$generationConfig = ['generationConfig' => $options];
5458
unset($generationConfig['generationConfig']['stream']);
59+
unset($generationConfig['generationConfig']['tools']);
60+
61+
if (isset($options['tools'])) {
62+
$generationConfig['tools'] = $options['tools'];
63+
unset($options['tools']);
64+
}
5565

5666
return $this->httpClient->request('POST', $url, [
5767
'headers' => [
@@ -76,11 +86,22 @@ public function convert(ResponseInterface $response, array $options = []): LlmRe
7686

7787
$data = $response->toArray();
7888

79-
if (!isset($data['candidates'][0]['content']['parts'][0]['text'])) {
89+
if (!isset($data['candidates'][0]['content']['parts'][0])) {
8090
throw new RuntimeException('Response does not contain any content');
8191
}
8292

83-
return new TextResponse($data['candidates'][0]['content']['parts'][0]['text']);
93+
/** @var Choice[] $choices */
94+
$choices = array_map($this->convertChoice(...), $data['candidates']);
95+
96+
if (1 !== \count($choices)) {
97+
return new ChoiceResponse(...$choices);
98+
}
99+
100+
if ($choices[0]->hasToolCall()) {
101+
return new ToolCallResponse(...$choices[0]->getToolCalls());
102+
}
103+
104+
return new TextResponse($choices[0]->getContent());
84105
}
85106

86107
private function convertStream(ResponseInterface $response): \Generator
@@ -114,12 +135,68 @@ private function convertStream(ResponseInterface $response): \Generator
114135
throw new RuntimeException('Failed to decode JSON response', 0, $e);
115136
}
116137

117-
if (!isset($data['candidates'][0]['content']['parts'][0]['text'])) {
138+
/** @var Choice[] $choices */
139+
$choices = array_map($this->convertChoice(...), $data['candidates'] ?? []);
140+
141+
if (!$choices) {
118142
continue;
119143
}
120144

121-
yield $data['candidates'][0]['content']['parts'][0]['text'];
145+
if (1 !== \count($choices)) {
146+
yield new ChoiceResponse(...$choices);
147+
continue;
148+
}
149+
150+
if ($choices[0]->hasToolCall()) {
151+
yield new ToolCallResponse(...$choices[0]->getToolCalls());
152+
}
153+
154+
if ($choices[0]->hasContent()) {
155+
yield $choices[0]->getContent();
156+
}
122157
}
123158
}
124159
}
160+
161+
/**
162+
* @param array{
163+
* finishReason?: string,
164+
* content: array{
165+
* parts: array{
166+
* functionCall?: array{
167+
* id: string,
168+
* name: string,
169+
* args: mixed[]
170+
* },
171+
* text?: string
172+
* }[]
173+
* }
174+
* } $choice
175+
*/
176+
private function convertChoice(array $choice): Choice
177+
{
178+
$contentPart = $choice['content']['parts'][0] ?? [];
179+
180+
if (isset($contentPart['functionCall'])) {
181+
return new Choice(toolCalls: [$this->convertToolCall($contentPart['functionCall'])]);
182+
}
183+
184+
if (isset($contentPart['text'])) {
185+
return new Choice($contentPart['text']);
186+
}
187+
188+
throw new RuntimeException(\sprintf('Unsupported finish reason "%s".', $choice['finishReason']));
189+
}
190+
191+
/**
192+
* @param array{
193+
* id: string,
194+
* name: string,
195+
* args: mixed[]
196+
* } $toolCall
197+
*/
198+
private function convertToolCall(array $toolCall): ToolCall
199+
{
200+
return new ToolCall($toolCall['id'] ?? '', $toolCall['name'], $toolCall['args']);
201+
}
125202
}

src/Platform/Bridge/Google/PlatformFactory.php

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
use PhpLlm\LlmChain\Platform\Bridge\Google\Contract\AssistantMessageNormalizer;
88
use PhpLlm\LlmChain\Platform\Bridge\Google\Contract\MessageBagNormalizer;
9+
use PhpLlm\LlmChain\Platform\Bridge\Google\Contract\ToolCallMessageNormalizer;
10+
use PhpLlm\LlmChain\Platform\Bridge\Google\Contract\ToolNormalizer;
911
use PhpLlm\LlmChain\Platform\Bridge\Google\Contract\UserMessageNormalizer;
1012
use PhpLlm\LlmChain\Platform\Contract;
1113
use PhpLlm\LlmChain\Platform\Platform;
@@ -28,6 +30,8 @@ public static function create(
2830
return new Platform([$responseHandler], [$responseHandler], Contract::create(
2931
new AssistantMessageNormalizer(),
3032
new MessageBagNormalizer(),
33+
new ToolNormalizer(),
34+
new ToolCallMessageNormalizer(),
3135
new UserMessageNormalizer(),
3236
));
3337
}

tests/Platform/Bridge/Google/Contract/AssistantMessageNormalizerTest.php

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
use PhpLlm\LlmChain\Platform\Contract;
1010
use PhpLlm\LlmChain\Platform\Message\AssistantMessage;
1111
use PhpLlm\LlmChain\Platform\Model;
12+
use PhpLlm\LlmChain\Platform\Response\ToolCall;
1213
use PHPUnit\Framework\Attributes\CoversClass;
14+
use PHPUnit\Framework\Attributes\DataProvider;
1315
use PHPUnit\Framework\Attributes\Small;
1416
use PHPUnit\Framework\Attributes\Test;
1517
use PHPUnit\Framework\Attributes\UsesClass;
@@ -20,6 +22,7 @@
2022
#[UsesClass(Gemini::class)]
2123
#[UsesClass(AssistantMessage::class)]
2224
#[UsesClass(Model::class)]
25+
#[UsesClass(ToolCall::class)]
2326
final class AssistantMessageNormalizerTest extends TestCase
2427
{
2528
#[Test]
@@ -41,14 +44,33 @@ public function getSupportedTypes(): void
4144
self::assertSame([AssistantMessage::class => true], $normalizer->getSupportedTypes(null));
4245
}
4346

47+
#[DataProvider('normalizeDataProvider')]
4448
#[Test]
45-
public function normalize(): void
49+
public function normalize(AssistantMessage $message, array $expectedOutput): void
4650
{
4751
$normalizer = new AssistantMessageNormalizer();
48-
$message = new AssistantMessage('Great to meet you. What would you like to know?');
4952

5053
$normalized = $normalizer->normalize($message);
5154

52-
self::assertSame([['text' => 'Great to meet you. What would you like to know?']], $normalized);
55+
self::assertSame($expectedOutput, $normalized);
56+
}
57+
58+
/**
59+
* @return iterable<string, array{AssistantMessage, array{text?: string, functionCall?: array{id: string, name: string, args?: mixed}}[]}>
60+
*/
61+
public static function normalizeDataProvider(): iterable
62+
{
63+
yield 'assistant message' => [
64+
new AssistantMessage('Great to meet you. What would you like to know?'),
65+
[['text' => 'Great to meet you. What would you like to know?']],
66+
];
67+
yield 'function call' => [
68+
new AssistantMessage(toolCalls: [new ToolCall('id1', 'name1', ['arg1' => '123'])]),
69+
[['functionCall' => ['id' => 'id1', 'name' => 'name1', 'args' => ['arg1' => '123']]]],
70+
];
71+
yield 'function call without parameters' => [
72+
new AssistantMessage(toolCalls: [new ToolCall('id1', 'name1')]),
73+
[['functionCall' => ['id' => 'id1', 'name' => 'name1']]],
74+
];
5375
}
5476
}

0 commit comments

Comments
 (0)