Add unit tests for inference function #91

Draft · wants to merge 7 commits into main

2 changes: 2 additions & 0 deletions web-apps/chat/Dockerfile
@@ -4,6 +4,8 @@ ARG DIR=chat
 
 COPY $DIR/requirements.txt requirements.txt
 COPY utils utils
+RUN pip install --no-cache-dir --upgrade pip
+RUN pip install --no-cache-dir --upgrade setuptools
 RUN pip install --no-cache-dir -r requirements.txt
 
 COPY purge-google-fonts.sh purge-google-fonts.sh

43 changes: 25 additions & 18 deletions web-apps/chat/app.py
@@ -61,30 +61,37 @@ class PossibleSystemPromptException(Exception):
     streaming=True,
 )
 
+def build_chat_context(latest_message, history):
+    """
+    Build the chat context from the latest message and history.
+    """
+    context = []
+    if INCLUDE_SYSTEM_PROMPT:
+        context.append(SystemMessage(content=settings.model_instruction))
+    elif history and len(history) > 0:
+        # Mimic system prompt by prepending it to first human message
+        history[0]['content'] = f"{settings.model_instruction}\n\n{history[0]['content']}"
+
+    for message in history:
+        role = message['role']
+        content = message['content']
+        if role == "user":
+            context.append(HumanMessage(content=content))
+        else:
+            if role != "assistant":
+                log.warn(f"Message role {role} converted to 'assistant'")
+            context.append(AIMessage(content=(content or "")))
+    context.append(HumanMessage(content=latest_message))
+    return context
+
+
 def inference(latest_message, history):
     # Allow mutating global variable
     global BACKEND_INITIALISED
     log.debug("Inference request received with history: %s", history)
 
     try:
-        context = []
-        if INCLUDE_SYSTEM_PROMPT:
-            context.append(SystemMessage(content=settings.model_instruction))
-        elif history and len(history) > 0:
-            # Mimic system prompt by prepending it to first human message
-            history[0]['content'] = f"{settings.model_instruction}\n\n{history[0]['content']}"
-
-        for message in history:
-            role = message['role']
-            content = message['content']
-            if role == "user":
-                context.append(HumanMessage(content=content))
-            else:
-                if role != "assistant":
-                    log.warn(f"Message role {role} converted to 'assistant'")
-                context.append(AIMessage(content=(content or "")))
-        context.append(HumanMessage(content=latest_message))
-
+        context = build_chat_context(latest_message, history)
         log.debug("Chat context: %s", context)
 
         response = ""

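Pulling the context-building logic out of inference into build_chat_context is what makes it unit-testable: the helper can now be exercised directly once the module-level configuration is patched. A minimal sketch of such a call, assuming the same app module layout that the tests below rely on:

from unittest.mock import patch

from langchain.schema import HumanMessage, SystemMessage
from app import build_chat_context

# Patch the module-level flag and settings object, as the tests below do.
with patch("app.INCLUDE_SYSTEM_PROMPT", True), patch("app.settings") as mock_settings:
    mock_settings.model_instruction = "You are a helpful assistant."
    context = build_chat_context("Hello", history=[])

# Expected shape: [SystemMessage(<instruction>), HumanMessage("Hello")]
assert isinstance(context[0], SystemMessage)
assert isinstance(context[1], HumanMessage) and context[1].content == "Hello"
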
187 changes: 178 additions & 9 deletions web-apps/chat/test.py
@@ -1,23 +1,192 @@
+import openai
 import os
 import unittest
 
-from unittest import mock
+# from unittest import mock
 from gradio_client import Client
+from unittest.mock import patch, MagicMock, Mock
+from langchain.schema import HumanMessage, AIMessage, SystemMessage
+from app import build_chat_context, inference, PossibleSystemPromptException, gr
 
 url = os.environ.get("GRADIO_URL", "http://localhost:7860")
 client = Client(url)
+latest_message = "Why don't humans drink horse milk?"
+history = [
+    {
+        "role": "user",
+        "metadata": None,
+        "content": "Hi!",
+        "options": None,
+    },
+    {
+        "role": "assistant",
+        "metadata": None,
+        "content": "Hello! How can I help you?",
+        "options": None,
+    },
+]
 
-class TestSuite(unittest.TestCase):
-
+class TestAPI(unittest.TestCase):
     def test_gradio_api(self):
         result = client.predict("Hi", api_name="/chat")
         self.assertGreater(len(result), 0)
 
-    # def test_mock_response(self):
-    #     with mock.patch('app.client.stream_response', return_value=(char for char in "Mocked")) as mock_response:
-    #         result = client.predict("Hi", api_name="/chat")
-    #         # mock_response.assert_called_once_with("Hi", [])
-    #         self.assertEqual(result, "Mocked")
+class TestBuildChatContext(unittest.TestCase):
+    @patch("app.settings")
+    @patch("app.INCLUDE_SYSTEM_PROMPT", True)
+    def test_chat_context_system_prompt(self, mock_settings):
+        mock_settings.model_instruction = "You are a helpful assistant."
+
+        context = build_chat_context(latest_message, history)
+
+        self.assertEqual(len(context), 4)
+        self.assertIsInstance(context[0], SystemMessage)
+        self.assertEqual(context[0].content, "You are a helpful assistant.")
+        self.assertIsInstance(context[1], HumanMessage)
+        self.assertEqual(context[1].content, history[0]["content"])
+        self.assertIsInstance(context[2], AIMessage)
+        self.assertEqual(context[2].content, history[1]["content"])
+        self.assertIsInstance(context[3], HumanMessage)
+        self.assertEqual(context[3].content, latest_message)
+
+    @patch("app.settings")
+    @patch("app.INCLUDE_SYSTEM_PROMPT", False)
+    def test_chat_context_human_prompt(self, mock_settings):
+        mock_settings.model_instruction = "You are a very helpful assistant."
+
+        context = build_chat_context(latest_message, history)
+
+        self.assertEqual(len(context), 3)
+        self.assertIsInstance(context[0], HumanMessage)
+        self.assertEqual(context[0].content, "You are a very helpful assistant.\n\nHi!")
+        self.assertIsInstance(context[1], AIMessage)
+        self.assertEqual(context[1].content, history[1]["content"])
+        self.assertIsInstance(context[2], HumanMessage)
+        self.assertEqual(context[2].content, latest_message)
+
+class TestInference(unittest.TestCase):
+    @patch("app.settings")
+    @patch("app.llm")
+    @patch("app.log")
+    def test_inference_success(self, mock_logger, mock_llm, mock_settings):
+        mock_llm.stream.return_value = [MagicMock(content="response_chunk")]
+
+        mock_settings.model_instruction = "You are a very helpful assistant."
+
+        responses = list(inference(latest_message, history))
+
+        self.assertEqual(responses, ["response_chunk"])
+        mock_logger.debug.assert_any_call("Inference request received with history: %s", history)
+
+    @patch("app.llm")
+    @patch("app.build_chat_context")
+    def test_inference_thinking_tags(self, mock_build_chat_context, mock_llm):
+        mock_build_chat_context.return_value = ["mock_context"]
+        mock_llm.stream.return_value = [
+            MagicMock(content="<think>"),
+            MagicMock(content="processing"),
+            MagicMock(content="</think>"),
+            MagicMock(content="final response"),
+        ]
+
+        responses = list(inference(latest_message, history))
+
+        self.assertEqual(responses, ["Thinking...", "Thinking...", "", "final response"])
+
+    @patch("app.llm")
+    @patch("app.INCLUDE_SYSTEM_PROMPT", True)
+    @patch("app.build_chat_context")
+    @patch("app.log")
+    def test_inference_PossibleSystemPromptException(self, mock_logger, mock_build_chat_context, mock_llm):
+        mock_build_chat_context.return_value = ["mock_context"]
+        mock_response = Mock()
+        mock_response.json.return_value = {"message": "Bad request"}
+
+        mock_llm.stream.side_effect = openai.BadRequestError(
+            message="Bad request",
+            response=mock_response,
+            body=None
+        )
+
+        with self.assertRaises(PossibleSystemPromptException):
+            list(inference(latest_message, history))
+        mock_logger.error.assert_called_once_with("Received BadRequestError from backend API: %s", mock_llm.stream.side_effect)
+
+    @patch("app.llm")
+    @patch("app.INCLUDE_SYSTEM_PROMPT", False)
+    @patch("app.build_chat_context")
+    @patch("app.log")
+    def test_inference_general_error(self, mock_logger, mock_build_chat_context, mock_llm):
+        mock_build_chat_context.return_value = ["mock_context"]
+        mock_response = Mock()
+        mock_response.json.return_value = {"message": "Bad request"}
+
+        mock_llm.stream.side_effect = openai.BadRequestError(
+            message="Bad request",
+            response=mock_response,
+            body=None
+        )
+
+        exception_message = "\'API Error received. This usually means the chosen LLM uses an incompatible prompt format. Error message was: Bad request\'"
+
+        with self.assertRaises(gr.Error) as gradio_error:
+            list(inference(latest_message, history))
+        self.assertEqual(str(gradio_error.exception), exception_message)
+        mock_logger.error.assert_called_once_with("Received BadRequestError from backend API: %s", mock_llm.stream.side_effect)
+
+    @patch("app.llm")
+    @patch("app.build_chat_context")
+    @patch("app.log")
+    @patch("app.gr")
+    @patch("app.BACKEND_INITIALISED", False)
+    def test_inference_APIConnectionError(self, mock_gr, mock_logger, mock_build_chat_context, mock_llm):
+        mock_build_chat_context.return_value = ["mock_context"]
+        mock_request = Mock()
+        mock_request.json.return_value = {"message": "Foo"}
+
+        mock_llm.stream.side_effect = openai.APIConnectionError(
+            message="Foo",
+            request=mock_request,
+        )
+
+        list(inference(latest_message, history))
+        mock_logger.info.assert_any_call("Backend API not yet ready")
+        mock_gr.Info.assert_any_call("Backend not ready - model may still be initialising - please try again later.")
+
+    @patch("app.llm")
+    @patch("app.build_chat_context")
+    @patch("app.log")
+    @patch("app.gr")
+    @patch("app.BACKEND_INITIALISED", True)
+    def test_inference_APIConnectionError_initialised(self, mock_gr, mock_logger, mock_build_chat_context, mock_llm):
+        mock_build_chat_context.return_value = ["mock_context"]
+        mock_request = Mock()
+        mock_request.json.return_value = {"message": "Foo"}
+
+        mock_llm.stream.side_effect = openai.APIConnectionError(
+            message="Foo",
+            request=mock_request,
+        )
+
+        list(inference(latest_message, history))
+        mock_logger.error.assert_called_once_with("Failed to connect to backend API: %s", mock_llm.stream.side_effect)
+        mock_gr.Warning.assert_any_call("Failed to connect to backend API.")
+
+    @patch("app.llm")
+    @patch("app.build_chat_context")
+    @patch("app.gr")
+    def test_inference_InternalServerError(self, mock_gr, mock_build_chat_context, mock_llm):
+        mock_build_chat_context.return_value = ["mock_context"]
+        mock_request = Mock()
+        mock_request.json.return_value = {"message": "Foo"}
+
+        mock_llm.stream.side_effect = openai.InternalServerError(
+            message="Foo",
+            response=mock_request,
+            body=None
+        )
+
+        list(inference(latest_message, history))
+        mock_gr.Warning.assert_any_call("Internal server error encountered in backend API - see API logs for details.")
 
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main(verbosity=2)

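One detail worth noting in the stacked decorators above: unittest.mock.patch applies bottom-up, so the innermost decorator supplies the first mock argument, and a patch given an explicit replacement value (such as @patch("app.INCLUDE_SYSTEM_PROMPT", True)) injects no argument at all. A minimal sketch of the convention, using a hypothetical demo function against the same app targets:

from unittest.mock import patch

@patch("app.llm")                          # outermost: bound to the last mock argument
@patch("app.INCLUDE_SYSTEM_PROMPT", True)  # explicit replacement: no argument injected
@patch("app.log")                          # innermost: bound to the first mock argument
def demo(mock_log, mock_llm):
    # While demo runs, app.log and app.llm are MagicMocks
    # and app.INCLUDE_SYSTEM_PROMPT is True.
    mock_llm.stream.return_value = []
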
2 changes: 2 additions & 0 deletions web-apps/flux-image-gen/Dockerfile
@@ -7,6 +7,8 @@ RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
 ARG DIR=flux-image-gen
 
 COPY $DIR/requirements.txt requirements.txt
+RUN pip install --no-cache-dir --upgrade pip
+RUN pip install --no-cache-dir --upgrade setuptools
 RUN pip install --no-cache-dir -r requirements.txt
 
 COPY purge-google-fonts.sh .

2 changes: 2 additions & 0 deletions web-apps/image-analysis/Dockerfile
@@ -4,6 +4,8 @@ ARG DIR=image-analysis
 
 COPY $DIR/requirements.txt requirements.txt
 COPY utils utils
+RUN pip install --no-cache-dir --upgrade pip
+RUN pip install --no-cache-dir --upgrade setuptools
 RUN pip install --no-cache-dir -r requirements.txt
 
 COPY purge-google-fonts.sh purge-google-fonts.sh

2 changes: 1 addition & 1 deletion web-apps/test-images.sh
@@ -76,7 +76,7 @@ test() {
         --name $1-test-suite \
         -e GRADIO_URL=http://$1-app:7860 --entrypoint python \
         $IMAGE \
-        test.py
+        test.py -v
 
     log "Removing containers:"
    docker rm -f ollama $1-app