Skip to content

Commit 2f1ab85

Browse files
authored
Fix Mistral empty tool_call_end flipping state machine to normal (#1151)
1 parent f3bb10c commit 2f1ab85

2 files changed

Lines changed: 30 additions & 2 deletions

File tree

mlx_lm/server.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -674,10 +674,11 @@ def _make_state_machine(
674674
ts = tokenizer.tool_call_start_tokens
675675
te = tokenizer.tool_call_end_tokens
676676
transitions["normal"].append((ts, "tool"))
677-
transitions["tool"] = [(te, "normal")]
677+
transitions["tool"] = [(te, "normal")] if te else []
678678
transitions["tool"].extend(common_stops)
679679
sequences[ts] = tokenizer.tool_call_start
680-
sequences[te] = tokenizer.tool_call_end
680+
if te:
681+
sequences[te] = tokenizer.tool_call_end
681682

682683
sm = SequenceStateMachine(transitions, initial=initial_state)
683684
if len(self._state_machine_cache) > 100:

tests/test_server.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,33 @@ def test_handle_chat_completions_with_null_tool_content(self):
277277
self.assertIn("id", response_body)
278278
self.assertIn("choices", response_body)
279279

280+
def test_make_state_machine_empty_tool_call_end(self):
281+
class FakeTokenizer:
282+
has_thinking = False
283+
has_tool_calling = True
284+
tool_call_start = "[TOOL_CALLS]"
285+
tool_call_end = ""
286+
tool_call_start_tokens = (100,)
287+
tool_call_end_tokens = ()
288+
eos_token_ids = [2]
289+
290+
def convert_ids_to_tokens(self, t):
291+
return f"<eos{t}>"
292+
293+
sm, _ = self.response_generator._make_state_machine(
294+
("fake-empty-end", None, None),
295+
FakeTokenizer(),
296+
stop_words=[],
297+
)
298+
state = sm.make_state()
299+
state, _, s = sm.match(state, 100)
300+
self.assertEqual(s, "tool")
301+
for tok in [42, 43, 44]:
302+
state, _, s = sm.match(state, tok)
303+
self.assertEqual(s, "tool")
304+
state, _, s = sm.match(state, 2)
305+
self.assertIsNone(s)
306+
280307
def test_handle_models(self):
281308
url = f"http://localhost:{self.port}/v1/models"
282309
response = requests.get(url)

0 commit comments

Comments
 (0)