Skip to content

Commit c3bec04

Browse files
authored
Merge pull request #59 from matlab-deep-learning/ollama-images
Ollama images
2 parents c607360 + 7aaadb4 commit c3bec04

File tree

11 files changed

+183
-33
lines changed

11 files changed

+183
-33
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ jobs:
3030
- name: Pull models
3131
run: |
3232
ollama pull mistral
33+
ollama pull bakllava
3334
OLLAMA_HOST=127.0.0.1:11435 ollama pull qwen2:0.5b
3435
- name: Set up MATLAB
3536
uses: matlab-actions/setup-matlab@v2

azureChat.m

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@
191191
if isstring(messages) && isscalar(messages)
192192
messagesStruct = {struct("role", "user", "content", messages)};
193193
else
194-
messagesStruct = messages.Messages;
194+
messagesStruct = this.encodeImages(messages.Messages);
195195
end
196196

197197
if ~isempty(this.SystemPrompt)
@@ -251,6 +251,40 @@ function mustBeValidFunctionCall(this, functionCall)
251251
end
252252

253253
end
254+
255+
function messageStruct = encodeImages(~, messageStruct)
256+
for k=1:numel(messageStruct)
257+
if isfield(messageStruct{k},"images")
258+
images = messageStruct{k}.images;
259+
detail = messageStruct{k}.image_detail;
260+
messageStruct{k} = rmfield(messageStruct{k},["images","image_detail"]);
261+
messageStruct{k}.content = ...
262+
{struct("type","text","text",messageStruct{k}.content)};
263+
for img = images(:).'
264+
if startsWith(img,("https://"|"http://"))
265+
s = struct( ...
266+
"type","image_url", ...
267+
"image_url",struct("url",img));
268+
else
269+
[~,~,ext] = fileparts(img);
270+
MIMEType = "data:image/" + erase(ext,".") + ";base64,";
271+
% Base64 encode the image using the given MIME type
272+
fid = fopen(img);
273+
im = fread(fid,'*uint8');
274+
fclose(fid);
275+
b64 = matlab.net.base64encode(im);
276+
s = struct( ...
277+
"type","image_url", ...
278+
"image_url",struct("url",MIMEType + b64));
279+
end
280+
281+
s.image_url.detail = detail;
282+
283+
messageStruct{k}.content{end+1} = s;
284+
end
285+
end
286+
end
287+
end
254288
end
255289
end
256290

doc/Azure.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,19 @@ txt = generate(chat,"What is Model-Based Design and how is it related to Digital
115115
% Should stream the response token by token
116116
```
117117

118+
## Understanding the content of an image
119+
120+
You can use gpt-4o, gpt-4o-mini, or gpt-4-turbo to experiment with image understanding.
121+
```matlab
122+
chat = azureChat("You are an AI assistant.",Deployment="gpt-4o");
123+
image_path = "peppers.png";
124+
messages = messageHistory;
125+
messages = addUserMessageWithImages(messages,"What is in the image?",image_path);
126+
[txt,response] = generate(chat,messages,MaxNumTokens=4096);
127+
txt
128+
% outputs a description of the image
129+
```
130+
118131
## Calling MATLAB functions with the API
119132

120133
Optionally, `Tools=functions` can be used to provide function specifications to the API. The purpose of this is to enable models to generate function arguments which adhere to the provided specifications.

doc/Ollama.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,24 @@ txt = generate(chat,"What is Model-Based Design and how is it related to Digital
9696
% Should stream the response token by token
9797
```
9898

99+
## Understanding the content of an image
100+
101+
You can use multimodal models like `llava` to experiment with image understanding.
102+
103+
> [!TIP]
104+
> Many models available through Ollama accept images in the prompt even when the model itself does not support image inputs. In that case, the images are silently removed from the input, which can result in unexpected outputs.
105+
106+
107+
```matlab
108+
chat = ollamaChat("llava");
109+
image_path = "peppers.png";
110+
messages = messageHistory;
111+
messages = addUserMessageWithImages(messages,"What is in the image?",image_path);
112+
[txt,response] = generate(chat,messages,MaxNumTokens=4096);
113+
txt
114+
% outputs a description of the image
115+
```
116+
99117
## Establishing a connection to remote LLMs using Ollama
100118

101119
To connect to a remote Ollama server, use the `Endpoint` name-value pair. Include the server name and port number. Ollama starts on 11434 by default.

doc/OpenAI.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,14 +250,15 @@ You can extract the arguments and write the data to a table, for example.
250250

251251
## Understanding the content of an image
252252

253-
You can use gpt-4-turbo to experiment with image understanding.
253+
You can use gpt-4o, gpt-4o-mini, or gpt-4-turbo to experiment with image understanding.
254254
```matlab
255-
chat = openAIChat("You are an AI assistant.", ModelName="gpt-4-turbo");
255+
chat = openAIChat("You are an AI assistant.");
256256
image_path = "peppers.png";
257257
messages = messageHistory;
258258
messages = addUserMessageWithImages(messages,"What is in the image?",image_path);
259259
[txt,response] = generate(chat,messages,MaxNumTokens=4096);
260-
% Should output the description of the image
260+
txt
261+
% outputs a description of the image
261262
```
262263

263264
## Obtaining embeddings

messageHistory.m

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -111,32 +111,9 @@
111111
nvp.Detail string {mustBeMember(nvp.Detail,["low","high","auto"])} = "auto"
112112
end
113113

114-
newMessage = struct("role", "user", "content", []);
115-
newMessage.content = {struct("type","text","text",string(content))};
116-
for img = images(:).'
117-
if startsWith(img,("https://"|"http://"))
118-
s = struct( ...
119-
"type","image_url", ...
120-
"image_url",struct("url",img));
121-
else
122-
[~,~,ext] = fileparts(img);
123-
MIMEType = "data:image/" + erase(ext,".") + ";base64,";
124-
% Base64 encode the image using the given MIME type
125-
fid = fopen(img);
126-
im = fread(fid,'*uint8');
127-
fclose(fid);
128-
b64 = matlab.net.base64encode(im);
129-
s = struct( ...
130-
"type","image_url", ...
131-
"image_url",struct("url",MIMEType + b64));
132-
end
133-
134-
s.image_url.detail = nvp.Detail;
135-
136-
newMessage.content{end+1} = s;
137-
this.Messages{end+1} = newMessage;
138-
end
139-
114+
newMessage = struct("role", "user", "content", string(content), ...
115+
"images", images, "image_detail", nvp.Detail);
116+
this.Messages{end+1} = newMessage;
140117
end
141118

142119
function this = addToolMessage(this, id, name, content)

ollamaChat.m

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@
136136
if isstring(messages) && isscalar(messages)
137137
messagesStruct = {struct("role", "user", "content", messages)};
138138
else
139-
messagesStruct = messages.Messages;
139+
messagesStruct = this.encodeImages(messages.Messages);
140140
end
141141

142142
if ~isempty(this.SystemPrompt)
@@ -160,6 +160,28 @@
160160
end
161161
end
162162

163+
methods (Access=private)
164+
function messageStruct = encodeImages(~, messageStruct)
165+
for k=1:numel(messageStruct)
166+
if isfield(messageStruct{k},"images")
167+
images = messageStruct{k}.images;
168+
% detail = messageStruct{k}.image_detail;
169+
messageStruct{k} = rmfield(messageStruct{k},["images","image_detail"]);
170+
imgs = cell(size(images));
171+
for n = 1:numel(images)
172+
img = images(n);
173+
% Base64 encode the image
174+
fid = fopen(img);
175+
im = fread(fid,'*uint8');
176+
fclose(fid);
177+
imgs{n} = matlab.net.base64encode(im);
178+
end
179+
messageStruct{k}.images = imgs;
180+
end
181+
end
182+
end
183+
end
184+
163185
methods(Static)
164186
function mdls = models
165187
%ollamaChat.models - return models available on Ollama server

openAIChat.m

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@
181181
if isstring(messages) && isscalar(messages)
182182
messagesStruct = {struct("role", "user", "content", messages)};
183183
else
184-
messagesStruct = messages.Messages;
184+
messagesStruct = this.encodeImages(messages.Messages);
185185
end
186186

187187
llms.openai.validateMessageSupported(messagesStruct{end}, this.ModelName);
@@ -230,6 +230,40 @@ function mustBeValidFunctionCall(this, functionCall)
230230
end
231231

232232
end
233+
234+
function messageStruct = encodeImages(~, messageStruct)
235+
for k=1:numel(messageStruct)
236+
if isfield(messageStruct{k},"images")
237+
images = messageStruct{k}.images;
238+
detail = messageStruct{k}.image_detail;
239+
messageStruct{k} = rmfield(messageStruct{k},["images","image_detail"]);
240+
messageStruct{k}.content = ...
241+
{struct("type","text","text",messageStruct{k}.content)};
242+
for img = images(:).'
243+
if startsWith(img,("https://"|"http://"))
244+
s = struct( ...
245+
"type","image_url", ...
246+
"image_url",struct("url",img));
247+
else
248+
[~,~,ext] = fileparts(img);
249+
MIMEType = "data:image/" + erase(ext,".") + ";base64,";
250+
% Base64 encode the image using the given MIME type
251+
fid = fopen(img);
252+
im = fread(fid,'*uint8');
253+
fclose(fid);
254+
b64 = matlab.net.base64encode(im);
255+
s = struct( ...
256+
"type","image_url", ...
257+
"image_url",struct("url",MIMEType + b64));
258+
end
259+
260+
s.image_url.detail = detail;
261+
262+
messageStruct{k}.content{end+1} = s;
263+
end
264+
end
265+
end
266+
end
233267
end
234268
end
235269

tests/tazureChat.m

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,26 @@ function generateMultipleResponses(testCase)
5555
testCase.verifySize(response.Body.Data.choices,[3,1]);
5656
end
5757

58+
function generateWithImage(testCase)
59+
chat = azureChat(Deployment="gpt-4o");
60+
image_path = "peppers.png";
61+
emptyMessages = messageHistory;
62+
messages = addUserMessageWithImages(emptyMessages,"What is in the image?",image_path);
63+
64+
text = generate(chat,messages);
65+
testCase.verifyThat(text,matlab.unittest.constraints.ContainsSubstring("pepper"));
66+
end
67+
68+
function generateWithMultipleImages(testCase)
69+
import matlab.unittest.constraints.ContainsSubstring
70+
chat = azureChat(Deployment="gpt-4o");
71+
image_path = "peppers.png";
72+
emptyMessages = messageHistory;
73+
messages = addUserMessageWithImages(emptyMessages,"Compare these images.",[image_path,image_path]);
74+
75+
text = generate(chat,messages);
76+
testCase.verifyThat(text,ContainsSubstring("same") | ContainsSubstring("identical"));
77+
end
5878

5979
function doReturnErrors(testCase)
6080
testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
@@ -65,6 +85,15 @@ function doReturnErrors(testCase)
6585
testCase.verifyError(@() generate(chat,wayTooLong), "llms:apiReturnedError");
6686
end
6787

88+
function generateWithImageErrorsForGpt35(testCase)
89+
chat = azureChat;
90+
image_path = "peppers.png";
91+
emptyMessages = messageHistory;
92+
messages = addUserMessageWithImages(emptyMessages,"What is in the image?",image_path);
93+
94+
testCase.verifyError(@() generate(chat,messages), "llms:apiReturnedError");
95+
end
96+
6897
function seedFixesResult(testCase)
6998
testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
7099
chat = azureChat;

tests/tollamaChat.m

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,16 @@ function seedFixesResult(testCase)
9898
testCase.verifyEqual(response1,response2);
9999
end
100100

101+
function generateWithImages(testCase)
102+
chat = ollamaChat("bakllava");
103+
image_path = "peppers.png";
104+
emptyMessages = messageHistory;
105+
messages = addUserMessageWithImages(emptyMessages,"What is in the image?",image_path);
106+
107+
text = generate(chat,messages);
108+
testCase.verifyThat(text,matlab.unittest.constraints.ContainsSubstring("pepper"));
109+
end
110+
101111
function streamFunc(testCase)
102112
function seen = sf(str)
103113
persistent data;

tests/topenAIChat.m

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ function generateWithToolsAndStreamFunc(testCase)
173173
testCase.verifyThat(data,HasField("explanation"));
174174
end
175175

176-
function generateWithImages(testCase)
176+
function generateWithImage(testCase)
177177
chat = openAIChat;
178178
image_path = "peppers.png";
179179
emptyMessages = messageHistory;
@@ -183,6 +183,17 @@ function generateWithImages(testCase)
183183
testCase.verifyThat(text,matlab.unittest.constraints.ContainsSubstring("pepper"));
184184
end
185185

186+
function generateWithMultipleImages(testCase)
187+
import matlab.unittest.constraints.ContainsSubstring
188+
chat = openAIChat;
189+
image_path = "peppers.png";
190+
emptyMessages = messageHistory;
191+
messages = addUserMessageWithImages(emptyMessages,"Compare these images.",[image_path,image_path]);
192+
193+
text = generate(chat,messages);
194+
testCase.verifyThat(text,ContainsSubstring("same") | ContainsSubstring("identical"));
195+
end
196+
186197
function invalidInputsGenerate(testCase, InvalidGenerateInput)
187198
f = openAIFunction("validfunction");
188199
chat = openAIChat(Tools=f, APIKey="this-is-not-a-real-key");

0 commit comments

Comments
 (0)