9
9
10
10
# Base URL of the running Gradio app; override via GRADIO_URL to target a
# remote or non-default instance.
url = os.environ.get("GRADIO_URL", "http://localhost:7860")

# Shared Gradio client used by the API smoke tests below.
client = Client(url)

# Shared fixtures reused across the test cases: the latest user message and a
# two-turn prior conversation in Gradio "messages" format.
latest_message = "Why don't humans drink horse milk?"
history = [
    {
        "role": "user",
        "metadata": None,
        "content": "Hi!",
        "options": None,
    },
    {
        "role": "assistant",
        "metadata": None,
        "content": "Hello! How can I help you?",
        "options": None,
    },
]
27
+
28
class TestAPI(unittest.TestCase):
    """Smoke tests exercised against the live Gradio HTTP API."""

    def test_gradio_api(self):
        """The /chat endpoint should yield a non-empty reply."""
        reply = client.predict("Hi", api_name="/chat")
        self.assertGreater(len(reply), 0)
18
32
19
- # build_chat_context function tests
33
class TestBuildChatContext(unittest.TestCase):
    """Unit tests for app.build_chat_context covering both system-prompt modes."""

    @patch("app.settings")
    @patch("app.INCLUDE_SYSTEM_PROMPT", True)
    def test_chat_context_system_prompt(self, mock_settings):
        """With INCLUDE_SYSTEM_PROMPT, the instruction becomes a SystemMessage."""
        mock_settings.model_instruction = "You are a helpful assistant."

        context = build_chat_context(latest_message, history)

        # system prompt + 2 history turns + latest message
        self.assertEqual(len(context), 4)
        self.assertIsInstance(context[0], SystemMessage)
        self.assertEqual(context[0].content, "You are a helpful assistant.")
        self.assertIsInstance(context[1], HumanMessage)
        self.assertEqual(context[1].content, history[0]["content"])
        self.assertIsInstance(context[2], AIMessage)
        self.assertEqual(context[2].content, history[1]["content"])
        self.assertIsInstance(context[3], HumanMessage)
        self.assertEqual(context[3].content, latest_message)

    @patch("app.settings")
    @patch("app.INCLUDE_SYSTEM_PROMPT", False)
    def test_chat_context_human_prompt(self, mock_settings):
        """Without INCLUDE_SYSTEM_PROMPT, the instruction is folded into the
        first human turn instead of getting its own message."""
        mock_settings.model_instruction = "You are a very helpful assistant."

        context = build_chat_context(latest_message, history)

        # 2 history turns (instruction merged into the first) + latest message
        self.assertEqual(len(context), 3)
        self.assertIsInstance(context[0], HumanMessage)
        # Derive the expectation from the shared fixtures rather than
        # hard-coding the combined string, so edits to the fixtures or the
        # mocked instruction cannot silently desynchronise this assertion.
        self.assertEqual(
            context[0].content,
            f"{mock_settings.model_instruction}\n\n{history[0]['content']}",
        )
        self.assertIsInstance(context[1], AIMessage)
        self.assertEqual(context[1].content, history[1]["content"])
        self.assertIsInstance(context[2], HumanMessage)
        self.assertEqual(context[2].content, latest_message)
61
65
62
- # inference function tests
66
+ class TestInference ( unittest . TestCase ):
63
67
@patch ("app.settings" )
64
68
@patch ("app.llm" )
65
69
@patch ("app.log" )
66
70
def test_inference_success (self , mock_logger , mock_llm , mock_settings ):
67
71
mock_llm .stream .return_value = [MagicMock (content = "response_chunk" )]
68
72
69
73
mock_settings .model_instruction = "You are a very helpful assistant."
70
- latest_message = "Why don't we drink horse milk?"
71
- history = [
72
- {"role" : "user" , 'metadata' : None , "content" : "Hi there!" , 'options' : None },
73
- {"role" : "assistant" , 'metadata' : None , "content" : "Hi! How can I help you?" , 'options' : None },
74
- ]
75
74
76
75
responses = list (inference (latest_message , history ))
77
76
@@ -88,8 +87,6 @@ def test_inference_thinking_tags(self, mock_build_chat_context, mock_llm):
88
87
MagicMock (content = "</think>" ),
89
88
MagicMock (content = "final response" ),
90
89
]
91
- latest_message = "Hello"
92
- history = []
93
90
94
91
responses = list (inference (latest_message , history ))
95
92
@@ -98,7 +95,8 @@ def test_inference_thinking_tags(self, mock_build_chat_context, mock_llm):
98
95
@patch ("app.llm" )
99
96
@patch ("app.INCLUDE_SYSTEM_PROMPT" , True )
100
97
@patch ("app.build_chat_context" )
101
- def test_inference_PossibleSystemPromptException (self , mock_build_chat_context , mock_llm ):
98
+ @patch ("app.log" )
99
+ def test_inference_PossibleSystemPromptException (self , mock_logger , mock_build_chat_context , mock_llm ):
102
100
mock_build_chat_context .return_value = ["mock_context" ]
103
101
mock_response = Mock ()
104
102
mock_response .json .return_value = {"message" : "Bad request" }
@@ -109,16 +107,15 @@ def test_inference_PossibleSystemPromptException(self, mock_build_chat_context,
109
107
body = None
110
108
)
111
109
112
- latest_message = "Hello"
113
- history = []
114
-
115
110
with self .assertRaises (PossibleSystemPromptException ):
116
111
list (inference (latest_message , history ))
112
+ mock_logger .error .assert_called_once_with ("Received BadRequestError from backend API: %s" , mock_llm .stream .side_effect )
117
113
118
114
@patch ("app.llm" )
119
115
@patch ("app.INCLUDE_SYSTEM_PROMPT" , False )
120
116
@patch ("app.build_chat_context" )
121
- def test_inference_general_error (self , mock_build_chat_context , mock_llm ):
117
+ @patch ("app.log" )
118
+ def test_inference_general_error (self , mock_logger , mock_build_chat_context , mock_llm ):
122
119
mock_build_chat_context .return_value = ["mock_context" ]
123
120
mock_response = Mock ()
124
121
mock_response .json .return_value = {"message" : "Bad request" }
@@ -129,13 +126,12 @@ def test_inference_general_error(self, mock_build_chat_context, mock_llm):
129
126
body = None
130
127
)
131
128
132
- latest_message = "Hello"
133
- history = []
134
129
exception_message = "\' API Error received. This usually means the chosen LLM uses an incompatible prompt format. Error message was: Bad request\' "
135
130
136
131
with self .assertRaises (gr .Error ) as gradio_error :
137
132
list (inference (latest_message , history ))
138
133
self .assertEqual (str (gradio_error .exception ), exception_message )
134
+ mock_logger .error .assert_called_once_with ("Received BadRequestError from backend API: %s" , mock_llm .stream .side_effect )
139
135
140
136
@patch ("app.llm" )
141
137
@patch ("app.build_chat_context" )
@@ -152,9 +148,6 @@ def test_inference_APIConnectionError(self, mock_gr, mock_logger, mock_build_cha
152
148
request = mock_request ,
153
149
)
154
150
155
- latest_message = "Hello"
156
- history = []
157
-
158
151
list (inference (latest_message , history ))
159
152
mock_logger .info .assert_any_call ("Backend API not yet ready" )
160
153
mock_gr .Info .assert_any_call ("Backend not ready - model may still be initialising - please try again later." )
@@ -174,9 +167,6 @@ def test_inference_APIConnectionError_initialised(self, mock_gr, mock_logger, mo
174
167
request = mock_request ,
175
168
)
176
169
177
- latest_message = "Hello"
178
- history = []
179
-
180
170
list (inference (latest_message , history ))
181
171
mock_logger .error .assert_called_once_with ("Failed to connect to backend API: %s" , mock_llm .stream .side_effect )
182
172
mock_gr .Warning .assert_any_call ("Failed to connect to backend API." )
@@ -195,11 +185,8 @@ def test_inference_InternalServerError(self, mock_gr, mock_build_chat_context, m
195
185
body = None
196
186
)
197
187
198
- latest_message = "Hello"
199
- history = []
200
-
201
188
list (inference (latest_message , history ))
202
189
mock_gr .Warning .assert_any_call ("Internal server error encountered in backend API - see API logs for details." )
203
190
204
191
if __name__ == "__main__":
    # verbosity=2 prints each test's name and result as it runs,
    # which makes CI logs easier to scan.
    unittest.main(verbosity=2)
0 commit comments