Skip to content

Commit 8da8794

Browse files
committed
Adding image support to ollamaChat
* Refactor `messageHistory` to be agnostic of the image encoding. * Add backend encoding of images to `openAIChat`, `azureChat`, and `ollamaChat`. * Add image test points to the test files. Open question: Can we reliably detect which Ollama models support vision?
1 parent bdb84b8 commit 8da8794

File tree

7 files changed

+147
-30
lines changed

7 files changed

+147
-30
lines changed

azureChat.m

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@
191191
if isstring(messages) && isscalar(messages)
192192
messagesStruct = {struct("role", "user", "content", messages)};
193193
else
194-
messagesStruct = messages.Messages;
194+
messagesStruct = this.encodeImages(messages.Messages);
195195
end
196196

197197
if ~isempty(this.SystemPrompt)
@@ -251,6 +251,40 @@ function mustBeValidFunctionCall(this, functionCall)
251251
end
252252

253253
end
254+
255+
function messageStruct = encodeImages(~, messageStruct)
    %encodeImages  Convert backend-agnostic image fields into content parts.
    %   MESSAGESTRUCT is a cell array of message structs. For every message
    %   that has an "images" field, the "images" and "image_detail" fields are
    %   removed and "content" is replaced by a cell array holding one "text"
    %   part followed by one "image_url" part per image.
    %
    %   Images whose name starts with "http://" or "https://" are passed
    %   through as URLs. Anything else is read from disk and embedded as a
    %   base64 data URL. Messages without an "images" field are returned
    %   unchanged.
    %
    %   Throws an error naming the file if a local image cannot be opened.
    for k = 1:numel(messageStruct)
        if ~isfield(messageStruct{k},"images")
            continue
        end

        images = messageStruct{k}.images;
        detail = messageStruct{k}.image_detail;
        messageStruct{k} = rmfield(messageStruct{k},["images","image_detail"]);
        messageStruct{k}.content = ...
            {struct("type","text","text",messageStruct{k}.content)};

        for img = images(:).'
            if startsWith(img,["https://","http://"])
                url = img;
            else
                % Derive the MIME subtype from the extension. Normalize case
                % and map "jpg" to the registered subtype "jpeg".
                [~,~,ext] = fileparts(img);
                subtype = lower(erase(ext,"."));
                if subtype == "jpg"
                    subtype = "jpeg";
                end

                % Fail early with a readable message instead of letting
                % fread choke on an invalid file identifier.
                [fid,msg] = fopen(img);
                if fid < 0
                    error("Unable to open image file ""%s"": %s",img,msg);
                end
                % Guarantee the handle is released even if fread throws.
                closeFile = onCleanup(@() fclose(fid));
                im = fread(fid,'*uint8');
                clear closeFile

                url = "data:image/" + subtype + ";base64," + ...
                    matlab.net.base64encode(im);
            end

            messageStruct{k}.content{end+1} = struct( ...
                "type","image_url", ...
                "image_url",struct("url",url,"detail",detail));
        end
    end
end
254288
end
255289
end
256290

messageHistory.m

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -111,32 +111,9 @@
111111
nvp.Detail string {mustBeMember(nvp.Detail,["low","high","auto"])} = "auto"
112112
end
113113

114-
newMessage = struct("role", "user", "content", []);
115-
newMessage.content = {struct("type","text","text",string(content))};
116-
for img = images(:).'
117-
if startsWith(img,("https://"|"http://"))
118-
s = struct( ...
119-
"type","image_url", ...
120-
"image_url",struct("url",img));
121-
else
122-
[~,~,ext] = fileparts(img);
123-
MIMEType = "data:image/" + erase(ext,".") + ";base64,";
124-
% Base64 encode the image using the given MIME type
125-
fid = fopen(img);
126-
im = fread(fid,'*uint8');
127-
fclose(fid);
128-
b64 = matlab.net.base64encode(im);
129-
s = struct( ...
130-
"type","image_url", ...
131-
"image_url",struct("url",MIMEType + b64));
132-
end
133-
134-
s.image_url.detail = nvp.Detail;
135-
136-
newMessage.content{end+1} = s;
137-
this.Messages{end+1} = newMessage;
138-
end
139-
114+
newMessage = struct("role", "user", "content", string(content), ...
115+
"images", images, "image_detail", nvp.Detail);
116+
this.Messages{end+1} = newMessage;
140117
end
141118

142119
function this = addToolMessage(this, id, name, content)

ollamaChat.m

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@
136136
if isstring(messages) && isscalar(messages)
137137
messagesStruct = {struct("role", "user", "content", messages)};
138138
else
139-
messagesStruct = messages.Messages;
139+
messagesStruct = this.encodeImages(messages.Messages);
140140
end
141141

142142
if ~isempty(this.SystemPrompt)
@@ -160,6 +160,28 @@
160160
end
161161
end
162162

163+
methods (Access=private)
164+
function messageStruct = encodeImages(~, messageStruct)
    %encodeImages  Replace image file names by base64-encoded image data.
    %   MESSAGESTRUCT is a cell array of message structs. For every message
    %   that has an "images" field, each entry is read from disk and
    %   base64 encoded; the "images" field is replaced by a cell array of
    %   the encoded data and the "image_detail" field is dropped (it is not
    %   used here). Messages without an "images" field are returned
    %   unchanged.
    %
    %   Every entry is treated as a local file name. An entry that cannot
    %   be opened (including a URL) raises an error naming that entry.
    for k = 1:numel(messageStruct)
        if ~isfield(messageStruct{k},"images")
            continue
        end

        images = messageStruct{k}.images;
        messageStruct{k} = rmfield(messageStruct{k},["images","image_detail"]);

        imgs = cell(size(images));
        for n = 1:numel(images)
            % Fail early with a readable message instead of letting fread
            % choke on an invalid file identifier.
            [fid,msg] = fopen(images(n));
            if fid < 0
                error("Unable to open image file ""%s"": %s",images(n),msg);
            end
            % Guarantee the handle is released even if fread throws.
            closeFile = onCleanup(@() fclose(fid));
            im = fread(fid,'*uint8');
            clear closeFile

            imgs{n} = matlab.net.base64encode(im);
        end
        messageStruct{k}.images = imgs;
    end
end
183+
end
184+
163185
methods(Static)
164186
function mdls = models
165187
%ollamaChat.models - return models available on Ollama server

openAIChat.m

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@
181181
if isstring(messages) && isscalar(messages)
182182
messagesStruct = {struct("role", "user", "content", messages)};
183183
else
184-
messagesStruct = messages.Messages;
184+
messagesStruct = this.encodeImages(messages.Messages);
185185
end
186186

187187
llms.openai.validateMessageSupported(messagesStruct{end}, this.ModelName);
@@ -230,6 +230,40 @@ function mustBeValidFunctionCall(this, functionCall)
230230
end
231231

232232
end
233+
234+
function messageStruct = encodeImages(~, messageStruct)
    %encodeImages  Convert backend-agnostic image fields into content parts.
    %   MESSAGESTRUCT is a cell array of message structs. For every message
    %   that has an "images" field, the "images" and "image_detail" fields are
    %   removed and "content" is replaced by a cell array holding one "text"
    %   part followed by one "image_url" part per image.
    %
    %   Images whose name starts with "http://" or "https://" are passed
    %   through as URLs. Anything else is read from disk and embedded as a
    %   base64 data URL. Messages without an "images" field are returned
    %   unchanged.
    %
    %   Throws an error naming the file if a local image cannot be opened.
    for k = 1:numel(messageStruct)
        if ~isfield(messageStruct{k},"images")
            continue
        end

        images = messageStruct{k}.images;
        detail = messageStruct{k}.image_detail;
        messageStruct{k} = rmfield(messageStruct{k},["images","image_detail"]);
        messageStruct{k}.content = ...
            {struct("type","text","text",messageStruct{k}.content)};

        for img = images(:).'
            if startsWith(img,["https://","http://"])
                url = img;
            else
                % Derive the MIME subtype from the extension. Normalize case
                % and map "jpg" to the registered subtype "jpeg".
                [~,~,ext] = fileparts(img);
                subtype = lower(erase(ext,"."));
                if subtype == "jpg"
                    subtype = "jpeg";
                end

                % Fail early with a readable message instead of letting
                % fread choke on an invalid file identifier.
                [fid,msg] = fopen(img);
                if fid < 0
                    error("Unable to open image file ""%s"": %s",img,msg);
                end
                % Guarantee the handle is released even if fread throws.
                closeFile = onCleanup(@() fclose(fid));
                im = fread(fid,'*uint8');
                clear closeFile

                url = "data:image/" + subtype + ";base64," + ...
                    matlab.net.base64encode(im);
            end

            messageStruct{k}.content{end+1} = struct( ...
                "type","image_url", ...
                "image_url",struct("url",url,"detail",detail));
        end
    end
end
233267
end
234268
end
235269

tests/tazureChat.m

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,26 @@ function generateMultipleResponses(testCase)
5555
testCase.verifySize(response.Body.Data.choices,[3,1]);
5656
end
5757

58+
function generateWithImage(testCase)
    % End-to-end check that a single image attached to a user message is
    % sent to the gpt-4o deployment and described in the reply.
    % Filter (rather than fail) when Azure credentials are not configured,
    % matching the other end-to-end tests in this file.
    testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
    chat = azureChat(Deployment="gpt-4o");
    image_path = "peppers.png";
    emptyMessages = messageHistory;
    messages = addUserMessageWithImages(emptyMessages,"What is in the image?",image_path);

    text = generate(chat,messages);
    testCase.verifyThat(text,matlab.unittest.constraints.ContainsSubstring("pepper"));
end
67+
68+
function generateWithMultipleImages(testCase)
    % End-to-end check that several images can be attached to one user
    % message: the same picture is sent twice and the reply is expected to
    % call the images "same" or "identical".
    import matlab.unittest.constraints.ContainsSubstring
    % Filter (rather than fail) when Azure credentials are not configured,
    % matching the other end-to-end tests in this file.
    testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
    chat = azureChat(Deployment="gpt-4o");
    image_path = "peppers.png";
    emptyMessages = messageHistory;
    messages = addUserMessageWithImages(emptyMessages,"Compare these images.",[image_path,image_path]);

    text = generate(chat,messages);
    testCase.verifyThat(text,ContainsSubstring("same") | ContainsSubstring("identical"));
end
5878

5979
function doReturnErrors(testCase)
6080
testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
@@ -65,6 +85,15 @@ function doReturnErrors(testCase)
6585
testCase.verifyError(@() generate(chat,wayTooLong), "llms:apiReturnedError");
6686
end
6787

88+
function generateWithImageErrorsForGpt35(testCase)
    % End-to-end check that sending an image through the default
    % deployment is rejected with "llms:apiReturnedError".
    % NOTE(review): the test name implies the default deployment is a
    % GPT-3.5 model without vision support - confirm against the test
    % environment configuration.
    % Filter (rather than fail) when Azure credentials are not configured,
    % matching the other end-to-end tests in this file.
    testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
    chat = azureChat;
    image_path = "peppers.png";
    emptyMessages = messageHistory;
    messages = addUserMessageWithImages(emptyMessages,"What is in the image?",image_path);

    testCase.verifyError(@() generate(chat,messages), "llms:apiReturnedError");
end
96+
6897
function seedFixesResult(testCase)
6998
testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
7099
chat = azureChat;

tests/tollamaChat.m

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,16 @@ function seedFixesResult(testCase)
9898
testCase.verifyEqual(response1,response2);
9999
end
100100

101+
function generateWithImages(testCase)
    % Ask the "bakllava" model about a picture and verify that the reply
    % mentions "pepper".
    import matlab.unittest.constraints.ContainsSubstring

    model = ollamaChat("bakllava");
    history = addUserMessageWithImages(messageHistory, ...
        "What is in the image?","peppers.png");

    reply = generate(model,history);
    testCase.verifyThat(reply,ContainsSubstring("pepper"));
end
110+
101111
function streamFunc(testCase)
102112
function seen = sf(str)
103113
persistent data;

tests/topenAIChat.m

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ function generateWithToolsAndStreamFunc(testCase)
173173
testCase.verifyThat(data,HasField("explanation"));
174174
end
175175

176-
function generateWithImages(testCase)
176+
function generateWithImage(testCase)
177177
chat = openAIChat;
178178
image_path = "peppers.png";
179179
emptyMessages = messageHistory;
@@ -183,6 +183,17 @@ function generateWithImages(testCase)
183183
testCase.verifyThat(text,matlab.unittest.constraints.ContainsSubstring("pepper"));
184184
end
185185

186+
function generateWithMultipleImages(testCase)
    % Send the same picture twice in one user message and verify that the
    % reply calls the images "same" or "identical".
    import matlab.unittest.constraints.ContainsSubstring

    model = openAIChat;
    picture = "peppers.png";
    history = addUserMessageWithImages(messageHistory, ...
        "Compare these images.",[picture,picture]);

    reply = generate(model,history);
    saysAlike = ContainsSubstring("same") | ContainsSubstring("identical");
    testCase.verifyThat(reply,saysAlike);
end
196+
186197
function invalidInputsGenerate(testCase, InvalidGenerateInput)
187198
f = openAIFunction("validfunction");
188199
chat = openAIChat(Tools=f, APIKey="this-is-not-a-real-key");

0 commit comments

Comments
 (0)