diff --git a/Libraries/MLXLMCommon/Evaluate.swift b/Libraries/MLXLMCommon/Evaluate.swift index 09eaa47e..5bde2620 100644 --- a/Libraries/MLXLMCommon/Evaluate.swift +++ b/Libraries/MLXLMCommon/Evaluate.swift @@ -1080,6 +1080,8 @@ private func generateLoopTask( } } + handler.onGenerationEnd(emit: continuation.yield) + let now = Date.timeIntervalSinceReferenceDate let generateTime = now - start @@ -1292,6 +1294,11 @@ private protocol TokenLoopHandler: Sendable { emit: (sending Output) -> AsyncStream.Continuation.YieldResult ) -> Bool + /// Called after the token loop finishes, before the info event. + mutating func onGenerationEnd( + emit: (sending Output) -> AsyncStream.Continuation.YieldResult + ) + func infoEvent(_ info: GenerateCompletionInfo) -> Output } @@ -1337,6 +1344,18 @@ private struct TextToolTokenLoopHandler: TokenLoopHandler, @unchecked Sendable { true } + mutating func onGenerationEnd( + emit: (sending Generation) -> AsyncStream.Continuation.YieldResult + ) { + toolCallProcessor.processEOS() + + for toolCall in toolCallProcessor.toolCalls { + if case .terminated = emit(.toolCall(toolCall)) { + break + } + } + } + func infoEvent(_ info: GenerateCompletionInfo) -> Generation { .info(info) } @@ -1365,6 +1384,10 @@ private struct RawTokenLoopHandler: TokenLoopHandler { return true } + mutating func onGenerationEnd( + emit: (sending TokenGeneration) -> AsyncStream.Continuation.YieldResult + ) {} + func infoEvent(_ info: GenerateCompletionInfo) -> TokenGeneration { .info(info) } diff --git a/Libraries/MLXLMCommon/Tool/Parsers/MistralToolCallParser.swift b/Libraries/MLXLMCommon/Tool/Parsers/MistralToolCallParser.swift new file mode 100644 index 00000000..b9aebcdc --- /dev/null +++ b/Libraries/MLXLMCommon/Tool/Parsers/MistralToolCallParser.swift @@ -0,0 +1,63 @@ +// Copyright © 2025 Apple Inc. + +import Foundation + +/// Parser for Mistral V13 tool call format: `[TOOL_CALLS]name[ARGS]{"json_args"}` +/// +/// This format is used by Mistral3/Ministral-3 2512 models and Devstral 2. +/// The special tokens `[TOOL_CALLS]` (token ID 9) and `[ARGS]` (ID 32) are used +/// as delimiters. Multiple tool calls use repeated `[TOOL_CALLS]` tokens. +/// +/// Also handles the older V11 format which includes an optional `[CALL_ID]` +/// between the function name and `[ARGS]` (V13 does not use `[CALL_ID]`). +/// +/// Examples: +/// - `[TOOL_CALLS]get_weather[ARGS]{"location": "Tokyo"}` +/// - `[TOOL_CALLS]fn1[ARGS]{...}[TOOL_CALLS]fn2[ARGS]{...}` (multiple calls) +/// +/// Mistral does not use an end tag — tool calls end at EOS. Since stop tokens +/// are intercepted at the token ID level before detokenization, tool calls are +/// extracted via `ToolCallProcessor.processEOS()` at generation end. +public struct MistralToolCallParser: ToolCallParser, Sendable { + public let startTag: String? = "[TOOL_CALLS]" + public let endTag: String? = nil + + public init() {} + + public func parse(content: String, tools: [[String: any Sendable]]?) -> ToolCall? { + var text = content.trimmingCharacters(in: .whitespacesAndNewlines) + + // Strip wrapper tags only when they appear at boundaries. + // This keeps literal tag strings inside argument values intact. + if let start = startTag, text.hasPrefix(start) { + text = String(text.dropFirst(start.count)) + text = text.trimmingCharacters(in: .whitespacesAndNewlines) + } + // Split on [ARGS] to get function name and arguments + guard let argsRange = text.range(of: "[ARGS]") else { + return nil + } + + var namePart = String(text[.. ToolCall? + + /// Parse remaining buffered content at end-of-sequence. + /// + /// Called when generation ends to extract any tool calls still in the buffer. + /// The default implementation splits on `startTag` (if present) and parses + /// each segment individually. + func parseEOS(_ toolCallBuffer: String, tools: [[String: any Sendable]]?) -> [ToolCall] +} + +extension ToolCallParser { + public func parseEOS(_ toolCallBuffer: String, tools: [[String: any Sendable]]?) -> [ToolCall] { + if let startTag { + return + toolCallBuffer + .components(separatedBy: startTag) + .filter { !$0.isEmpty } + .compactMap { parse(content: $0, tools: tools) } + } else { + guard let toolCall = parse(content: toolCallBuffer, tools: tools) else { + return [] + } + return [toolCall] + } + } } // MARK: - ToolCallFormat Enum @@ -66,6 +90,10 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable { /// Example: `v` case minimaxM2 = "minimax_m2" + /// Mistral V11+ format with [TOOL_CALLS] and [ARGS] delimiters. + /// Example: `[TOOL_CALLS]get_weather [ARGS]{"location": "Tokyo"}` + case mistral + // MARK: - Factory Methods /// Create the appropriate parser for this format. @@ -87,6 +115,8 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable { return KimiK2ToolCallParser() case .minimaxM2: return MiniMaxM2ToolCallParser() + case .mistral: + return MistralToolCallParser() } } @@ -115,6 +145,11 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable { return .gemma } + // Mistral3 family (mistral3, mistral3_text, etc.) + if type.hasPrefix("mistral3") { + return .mistral + } + return nil } } diff --git a/Libraries/MLXLMCommon/Tool/ToolCallProcessor.swift b/Libraries/MLXLMCommon/Tool/ToolCallProcessor.swift index dd5df7fe..6589952a 100644 --- a/Libraries/MLXLMCommon/Tool/ToolCallProcessor.swift +++ b/Libraries/MLXLMCommon/Tool/ToolCallProcessor.swift @@ -55,9 +55,9 @@ public class ToolCallProcessor { // MARK: - Computed Properties - /// Whether this processor uses inline format (no start/end tags). + /// Whether this processor uses inline format (no start tag). private var isInlineFormat: Bool { - parser.startTag == nil || parser.endTag == nil + parser.startTag == nil } /// The first character of the start tag for quick detection. @@ -77,6 +77,27 @@ public class ToolCallProcessor { return processTaggedChunk(chunk) } + /// Process end-of-sequence, parsing any buffered content as tool call(s). + /// + /// Call this when generation ends (e.g., on EOS token) to handle formats + /// whose end tag is never delivered as text (e.g., Mistral where `` + /// is intercepted at the token ID level). + /// + /// For formats with end tags that appear in the text stream, the buffer + /// will already be empty at generation end, making this a no-op. + public func processEOS() { + guard state == .collectingToolCall || state == .potentialToolCall else { return } + guard !toolCallBuffer.isEmpty else { + state = .normal + return + } + + toolCalls.append(contentsOf: parser.parseEOS(toolCallBuffer, tools: tools)) + + toolCallBuffer = "" + state = .normal + } + // MARK: - Private Methods /// Process chunk for inline formats (no wrapper tags). diff --git a/Libraries/MLXVLM/Models/Mistral3.swift b/Libraries/MLXVLM/Models/Mistral3.swift index 2b8d634d..d64bac35 100644 --- a/Libraries/MLXVLM/Models/Mistral3.swift +++ b/Libraries/MLXVLM/Models/Mistral3.swift @@ -1003,7 +1003,11 @@ public struct Mistral3VLMProcessor: UserInputProcessor { if input.images.isEmpty { // No image - just apply chat template - let promptTokens = try tokenizer.applyChatTemplate(messages: messages) + let promptTokens = try tokenizer.applyChatTemplate( + messages: messages, + tools: input.tools, + additionalContext: input.additionalContext + ) let tokensArray = MLXArray(promptTokens).expandedDimensions(axis: 0) let mask = ones(like: tokensArray) return LMInput(text: .init(tokens: tokensArray, mask: mask), image: nil) @@ -1016,7 +1020,11 @@ public struct Mistral3VLMProcessor: UserInputProcessor { let patchSize = config.imageProcessor.patchSize // Apply chat template to get tokenized prompt with image placeholder - var promptTokens = try tokenizer.applyChatTemplate(messages: messages) + var promptTokens = try tokenizer.applyChatTemplate( + messages: messages, + tools: input.tools, + additionalContext: input.additionalContext + ) // Decode to find and replace image placeholder token let decoded = tokenizer.decode(tokens: promptTokens, skipSpecialTokens: false) diff --git a/Libraries/MLXVLM/VLMModelFactory.swift b/Libraries/MLXVLM/VLMModelFactory.swift index 42e594e8..b64084ba 100644 --- a/Libraries/MLXVLM/VLMModelFactory.swift +++ b/Libraries/MLXVLM/VLMModelFactory.swift @@ -329,6 +329,11 @@ public final class VLMModelFactory: ModelFactory { var mutableConfiguration = configuration mutableConfiguration.eosTokenIds = eosTokenIds + // Auto-detect tool call format from model type if not explicitly set + if mutableConfiguration.toolCallFormat == nil { + mutableConfiguration.toolCallFormat = ToolCallFormat.infer(from: baseConfig.modelType) + } + // Load tokenizer, processor config, and weights in parallel using async let. // Note: loadProcessorConfig does synchronous I/O but is marked async to enable // parallel scheduling. This may briefly block a cooperative thread pool thread, diff --git a/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift b/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift index 19dcac7e..3f40d988 100644 --- a/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift +++ b/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift @@ -4,6 +4,7 @@ import Foundation import MLX import MLXLLM import MLXLMCommon +import MLXVLM import XCTest /// Integration tests for tool call format auto-detection and end-to-end parsing. @@ -21,11 +22,13 @@ public class ToolCallIntegrationTests: XCTestCase { static let lfm2ModelId = "mlx-community/LFM2-2.6B-Exp-4bit" static let glm4ModelId = "mlx-community/GLM-4-9B-0414-4bit" + static let mistral3ModelId = "mlx-community/Ministral-3-3B-Instruct-2512-4bit" // MARK: - Shared State nonisolated(unsafe) static var lfm2Container: ModelContainer? nonisolated(unsafe) static var glm4Container: ModelContainer? + nonisolated(unsafe) static var mistral3Container: ModelContainer? // MARK: - Tool Schema @@ -61,6 +64,7 @@ public class ToolCallIntegrationTests: XCTestCase { let lfm2Expectation = XCTestExpectation(description: "Load LFM2") let glm4Expectation = XCTestExpectation(description: "Load GLM4") + let mistral3Expectation = XCTestExpectation(description: "Load Mistral3") Task { do { @@ -84,7 +88,19 @@ public class ToolCallIntegrationTests: XCTestCase { glm4Expectation.fulfill() } - _ = XCTWaiter.wait(for: [lfm2Expectation, glm4Expectation], timeout: 600) + Task { + do { + mistral3Container = try await VLMModelFactory.shared.loadContainer( + configuration: .init(id: mistral3ModelId) + ) + } catch { + print("Failed to load Mistral3: \(error)") + } + mistral3Expectation.fulfill() + } + + _ = XCTWaiter.wait( + for: [lfm2Expectation, glm4Expectation, mistral3Expectation], timeout: 600) } // MARK: - LFM2 Tests @@ -195,6 +211,120 @@ public class ToolCallIntegrationTests: XCTestCase { } } + // MARK: - Mistral3 Tests + + func testMistral3ToolCallFormatAutoDetection() async throws { + guard let container = Self.mistral3Container else { + throw XCTSkip("Mistral3 model not available") + } + + let config = await container.configuration + XCTAssertEqual( + config.toolCallFormat, .mistral, + "Mistral3 model should auto-detect .mistral tool call format" + ) + } + + func testMistral3EndToEndToolCallGeneration() async throws { + guard let container = Self.mistral3Container else { + throw XCTSkip("Mistral3 model not available") + } + + let input = UserInput( + chat: [ + .system( + "You are a helpful assistant with access to tools. When asked about weather, use the get_weather function." + ), + .user("What's the weather in Tokyo?"), + ], + tools: Self.weatherToolSchema + ) + + let (result, toolCalls) = try await generateWithTools( + container: container, + input: input, + maxTokens: 100 + ) + + print("Mistral3 Output: \(result)") + print("Mistral3 Tool Calls: \(toolCalls)") + + // Verify we got a tool call (model may or may not call the tool) + if !toolCalls.isEmpty { + let toolCall = toolCalls.first! + XCTAssertEqual(toolCall.function.name, "get_weather") + if let location = toolCall.function.arguments["location"]?.asString { + XCTAssertTrue( + location.lowercased().contains("tokyo"), + "Expected location to contain 'Tokyo', got: \(location)" + ) + } + } + } + + func testMistral3MultipleToolCallGeneration() async throws { + guard let container = Self.mistral3Container else { + throw XCTSkip("Mistral3 model not available") + } + + let multiToolSchema: [[String: any Sendable]] = + Self.weatherToolSchema + [ + [ + "type": "function", + "function": [ + "name": "get_time", + "description": "Get the current time in a given timezone", + "parameters": [ + "type": "object", + "properties": [ + "timezone": [ + "type": "string", + "description": + "The timezone, e.g. America/New_York, Asia/Tokyo", + ] as [String: any Sendable] + ] as [String: any Sendable], + "required": ["timezone"], + ] as [String: any Sendable], + ] as [String: any Sendable], + ] + ] + + let input = UserInput( + chat: [ + .system( + "You are a helpful assistant with access to tools. Always use the available tools to answer questions. Call multiple tools in parallel when needed." + ), + .user( + "What's the weather in Tokyo and what time is it there?" + ), + ], + tools: multiToolSchema + ) + + let (result, toolCalls) = try await generateWithTools( + container: container, + input: input, + maxTokens: 150 + ) + + print("Mistral3 Output: \(result)") + print("Mistral3 Calls: \(toolCalls)") + + // Verify all returned tool calls have valid names from our schema + let validNames: Set = ["get_weather", "get_time"] + for toolCall in toolCalls { + XCTAssertTrue( + validNames.contains(toolCall.function.name), + "Unexpected tool call: \(toolCall.function.name)" + ) + } + + // If the model made multiple calls, verify we got more than one + if toolCalls.count > 1 { + print("Successfully parsed \(toolCalls.count) tool calls from Mistral3") + } + } + // MARK: - Helper Methods /// Generate text and collect any tool calls diff --git a/Tests/MLXLMTests/ToolTests.swift b/Tests/MLXLMTests/ToolTests.swift index 634c2913..495449f5 100644 --- a/Tests/MLXLMTests/ToolTests.swift +++ b/Tests/MLXLMTests/ToolTests.swift @@ -422,6 +422,7 @@ struct ToolTests { #expect(ToolCallFormat.gemma.rawValue == "gemma") #expect(ToolCallFormat.kimiK2.rawValue == "kimi_k2") #expect(ToolCallFormat.minimaxM2.rawValue == "minimax_m2") + #expect(ToolCallFormat.mistral.rawValue == "mistral") // Test round-trip via raw value for format in ToolCallFormat.allCases { @@ -452,9 +453,137 @@ struct ToolTests { #expect(ToolCallFormat.infer(from: "gemma") == .gemma) #expect(ToolCallFormat.infer(from: "GEMMA") == .gemma) - // Unknown models should return nil (use default) + // Mistral3 models (prefix matching) + #expect(ToolCallFormat.infer(from: "mistral3") == .mistral) + #expect(ToolCallFormat.infer(from: "Mistral3") == .mistral) + #expect(ToolCallFormat.infer(from: "mistral3_text") == .mistral) + + // Unknown models should return nil (use default JSON format) #expect(ToolCallFormat.infer(from: "llama") == nil) #expect(ToolCallFormat.infer(from: "qwen2") == nil) #expect(ToolCallFormat.infer(from: "mistral") == nil) } + + // MARK: - Mistral Format Tests + + @Test("Test Mistral Tool Call Parser") + func testMistralParser() throws { + let parser = MistralToolCallParser() + let content = "[TOOL_CALLS]get_weather [ARGS]{\"location\": \"Paris\"}" + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "get_weather") + #expect(toolCall.function.arguments["location"] == .string("Paris")) + } + + @Test("Test Mistral Tool Call Parser - With Call ID") + func testMistralParserWithCallId() throws { + let parser = MistralToolCallParser() + let content = "[TOOL_CALLS]get_weather[CALL_ID]abc123xyz[ARGS]{\"location\": \"Paris\"}" + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "get_weather") + #expect(toolCall.function.arguments["location"] == .string("Paris")) + } + + @Test("Test Mistral Tool Call Parser - Preserves [TOOL_CALLS] in Arguments") + func testMistralParserPreservesStartTagInArguments() throws { + let parser = MistralToolCallParser() + let content = "get_note[ARGS]{\"text\": \"literal [TOOL_CALLS] marker\"}" + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "get_note") + #expect(toolCall.function.arguments["text"] == .string("literal [TOOL_CALLS] marker")) + } + + @Test("Test Mistral Tool Call Parser - Preserves in Arguments") + func testMistralParserPreservesEndTagInArguments() throws { + let parser = MistralToolCallParser() + let content = "get_note[ARGS]{\"text\": \"literal marker\"}" + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "get_note") + #expect(toolCall.function.arguments["text"] == .string("literal marker")) + } + + @Test("Test Mistral Format via ToolCallProcessor") + func testMistralFormatProcessor() throws { + let processor = ToolCallProcessor(format: .mistral) + let chunks: [String] = [ + "[TOOL", "_CALLS]", "get_weather", " [ARGS]", + "{\"location\":", " \"Tokyo\"}", + ] + + for chunk in chunks { + _ = processor.processChunk(chunk) + } + + // End tag never arrives in text, so tool call stays buffered until processEOS + #expect(processor.toolCalls.count == 0) + + processor.processEOS() + + #expect(processor.toolCalls.count == 1) + let toolCall = try #require(processor.toolCalls.first) + #expect(toolCall.function.name == "get_weather") + #expect(toolCall.function.arguments["location"] == .string("Tokyo")) + } + + @Test("Test Mistral Format Processor EOS") + func testMistralFormatProcessorEOS() throws { + let processor = ToolCallProcessor(format: .mistral) + let content = "[TOOL_CALLS]get_weather [ARGS]{\"location\": \"Berlin\"}" + + _ = processor.processChunk(content) + + // Before processEOS, no tool calls extracted (end tag never arrives) + #expect(processor.toolCalls.count == 0) + + // processEOS extracts the buffered tool call + processor.processEOS() + + #expect(processor.toolCalls.count == 1) + let toolCall = try #require(processor.toolCalls.first) + #expect(toolCall.function.name == "get_weather") + #expect(toolCall.function.arguments["location"] == .string("Berlin")) + } + + @Test("Test Mistral Format Processor Multiple Tool Calls") + func testMistralFormatProcessorMultipleToolCalls() throws { + let processor = ToolCallProcessor(format: .mistral) + let chunks: [String] = [ + "[TOOL_CALLS]get_weather[ARGS]", + "{\"location\": \"Paris\"}", + "[TOOL_CALLS]get_time", + "[ARGS]{\"timezone\": \"UTC\"}", + ] + + for chunk in chunks { + let result = processor.processChunk(chunk) + // All chunks should be buffered (nil) after the start tag + if chunk == chunks.first { + #expect(result == nil) + } + } + + // No tool calls before processEOS + #expect(processor.toolCalls.count == 0) + + processor.processEOS() + + // Both tool calls should be extracted + #expect(processor.toolCalls.count == 2) + + let first = try #require(processor.toolCalls.first) + #expect(first.function.name == "get_weather") + #expect(first.function.arguments["location"] == .string("Paris")) + + let second = processor.toolCalls[1] + #expect(second.function.name == "get_time") + #expect(second.function.arguments["timezone"] == .string("UTC")) + } } diff --git a/Tests/MLXLMTests/UserInputTests.swift b/Tests/MLXLMTests/UserInputTests.swift index 4f113ee0..ecb2d61f 100644 --- a/Tests/MLXLMTests/UserInputTests.swift +++ b/Tests/MLXLMTests/UserInputTests.swift @@ -95,6 +95,96 @@ public class UserInputTests: XCTestCase { assertEqual(expected, messages) } + // MARK: - Mistral3 Message Generator Tests + + public func testMistral3ConversionText() { + let chat: [Chat.Message] = [ + .system("You are a useful agent."), + .user("Tell me a story."), + ] + + let messages = Mistral3MessageGenerator().generate(messages: chat) + + let expected: [[String: any Sendable]] = [ + [ + "role": "system", + "content": [ + [ + "type": "text", + "text": "You are a useful agent.", + ] + ], + ], + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "Tell me a story.", + ] + ], + ], + ] + + assertEqual(expected, messages) + } + + public func testMistral3ConversionWithImage() { + let chat: [Chat.Message] = [ + .user( + "What is this?", + images: [ + .url( + URL( + string: "https://opensource.apple.com/images/projects/mlx.f5c59d8b.png")! + ) + ]) + ] + + let messages = Mistral3MessageGenerator().generate(messages: chat) + + let expected: [[String: any Sendable]] = [ + [ + "role": "user", + "content": [ + [ + "type": "image" + ], + [ + "type": "text", + "text": "What is this?", + ], + ], + ] + ] + + assertEqual(expected, messages) + } + + public func testMistral3ConversionToolRole() { + let chat: [Chat.Message] = [ + .tool("The weather is sunny, 14°C.") + ] + + let messages = Mistral3MessageGenerator().generate(messages: chat) + + let expected: [[String: any Sendable]] = [ + [ + "role": "tool", + "content": [ + [ + "type": "text", + "text": "The weather is sunny, 14°C.", + ] + ], + ] + ] + + assertEqual(expected, messages) + } + + // MARK: - Qwen2 Message Generator Tests + public func testQwen2ConversionImage() { let chat: [Chat.Message] = [ .system("You are a useful agent."),