Skip to content

Commit 3639020

Browse files
committed
add unit tests for edge cases and errors in responseStreamer
1 parent 72dca78 commit 3639020

File tree

2 files changed

+87
-4
lines changed

2 files changed

+87
-4
lines changed

+llms/+stream/responseStreamer.m

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,16 @@
2525
end
2626
end
2727
end
28-
28+
2929
methods
3030
function [len,stop] = putData(this, data)
3131
[len,stop] = this.putData@matlab.net.http.io.BinaryConsumer(data);
32-
32+
stop = doPutData(this, data, stop);
33+
end
34+
end
35+
36+
methods (Access=?tresponseStreamer)
37+
function stop = doPutData(this, data, stop)
3338
% Extract out the response text from the message
3439
str = native2unicode(data','UTF-8');
3540
str = this.Incomplete + string(str);
@@ -88,8 +93,13 @@
8893
end
8994
else
9095
txt = json.message.content;
91-
this.StreamFun(txt);
92-
this.ResponseText = [this.ResponseText txt];
96+
if strlength(txt) > 0
97+
this.StreamFun(txt);
98+
this.ResponseText = [this.ResponseText txt];
99+
end
100+
if isfield(json,"done")
101+
stop = json.done;
102+
end
93103
end
94104
end
95105
end

tests/tresponseStreamer.m

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
classdef tresponseStreamer < matlab.unittest.TestCase
2+
% Tests for llms.stream.reponseStreamer
3+
%
4+
% This test file contains unit tests, with a specific focus on edge cases that
5+
% are hard to trigger in end-to-end tests.
6+
7+
% Copyright 2024 The MathWorks, Inc.
8+
9+
methods (Test)
10+
function singleResponse(testCase)
11+
s = tracingStreamer;
12+
inp = 'data: {"choices":[{"content_filter_results":{},"delta":{"content":"foo","role":"assistant"}}]}';
13+
inp = [inp newline 'data: [DONE]'];
14+
inp = unicode2native(inp,"UTF-8").';
15+
testCase.verifyTrue(s.doPutData(inp,false));
16+
testCase.verifyEqual(s.StreamFun(),"foo");
17+
end
18+
19+
function skipEmpty(testCase)
20+
s = tracingStreamer;
21+
inp = [...
22+
'data: {"choices":[{"content_filter_results":{},"delta":{"content":"foo","role":"assistant"}}]}' newline ...
23+
'data: {"choices":[]}' newline ...
24+
'data: [DONE]'];
25+
inp = unicode2native(inp,"UTF-8").';
26+
testCase.verifyTrue(s.doPutData(inp,false));
27+
testCase.verifyEqual(s.StreamFun(),"foo");
28+
end
29+
30+
function splitResponse(testCase)
31+
% it can happen that the server sends packets split in the
32+
% middle of a JSON object. Hard to trigger on purpose.
33+
s = tracingStreamer;
34+
inp = 'data: {"choices":[{"content_filter_results":{},"delta":{"content":"foo","role":"assistant"}}]}';
35+
inp = unicode2native(inp,"UTF-8").';
36+
testCase.verifyFalse(s.doPutData(inp(1:42),false));
37+
testCase.verifyFalse(s.doPutData(inp(43:end),false));
38+
testCase.verifyEqual(s.StreamFun(),"foo");
39+
end
40+
41+
function ollamaFormat(testCase)
42+
s = tracingStreamer;
43+
inp = '{"model":"mistral","created_at":"2024-06-07T07:43:30.658793Z","message":{"role":"assistant","content":" Hello"},"done":false}';
44+
inp = unicode2native(inp,"UTF-8").';
45+
testCase.verifyFalse(s.doPutData(inp,false));
46+
inp = '{"model":"mistral","created_at":"2024-06-07T07:43:30.658793Z","message":{"role":"assistant","content":" World"},"done":true}';
47+
inp = unicode2native(inp,"UTF-8").';
48+
testCase.verifyTrue(s.doPutData(inp,false));
49+
testCase.verifyEqual(s.StreamFun(),[" Hello"," World"]);
50+
end
51+
52+
function badJSON(testCase)
53+
s = tracingStreamer;
54+
inp = 'data: {"choices":[{"content_filter_results":{};"delta":{"content":"foo","role":"assistant"}}]}';
55+
inp = [inp newline inp];
56+
inp = unicode2native(inp,"UTF-8").';
57+
testCase.verifyError(@() s.doPutData(inp,false),'llms:stream:responseStreamer:InvalidInput');
58+
testCase.verifyEmpty(s.StreamFun());
59+
end
60+
end
61+
end
62+
63+
function s = tracingStreamer
64+
data = strings(1, 0);
65+
function seen = sf(str)
66+
% Append streamed text to an empty string array of length 1
67+
if nargin > 0
68+
data = [data, str];
69+
end
70+
seen = data;
71+
end
72+
s = llms.stream.responseStreamer(@sf);
73+
end

0 commit comments

Comments
 (0)