Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions Libraries/MLXLMCommon/Evaluate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,8 @@ private func generateLoopTask<Handler: TokenLoopHandler>(
}
}

handler.onGenerationEnd(emit: continuation.yield)

let now = Date.timeIntervalSinceReferenceDate
let generateTime = now - start

Expand Down Expand Up @@ -1292,6 +1294,11 @@ private protocol TokenLoopHandler: Sendable {
emit: (sending Output) -> AsyncStream<Output>.Continuation.YieldResult
) -> Bool

/// Called after the token loop finishes, before the info event.
///
/// Gives a handler the opportunity to emit any output it is still holding
/// once no further tokens will be delivered.
///
/// - Parameter emit: Yields a single `Output` value to the stream and
///   returns the continuation's `YieldResult` for that yield.
mutating func onGenerationEnd(
emit: (sending Output) -> AsyncStream<Output>.Continuation.YieldResult
)

/// Wraps the completion info for a finished generation in the handler's `Output` type.
///
/// - Parameter info: The completion info to wrap.
/// - Returns: The `Output` value that carries `info`.
func infoEvent(_ info: GenerateCompletionInfo) -> Output
}

Expand Down Expand Up @@ -1337,6 +1344,18 @@ private struct TextToolTokenLoopHandler: TokenLoopHandler, @unchecked Sendable {
true
}

/// Flushes the tool-call processor at end of generation and emits every
/// tool call it is holding.
///
/// - Parameter emit: Yields one `Generation` value to the stream; emission
///   stops as soon as it reports `.terminated`.
mutating func onGenerationEnd(
    emit: (sending Generation) -> AsyncStream<Generation>.Continuation.YieldResult
) {
    // Let the processor parse whatever is still buffered before reading its results.
    toolCallProcessor.processEOS()

    for pendingCall in toolCallProcessor.toolCalls {
        // Any result other than `.terminated` means we keep going.
        guard case .terminated = emit(.toolCall(pendingCall)) else { continue }
        break
    }
}

/// Wraps the completion info in a `Generation.info` event.
func infoEvent(_ info: GenerateCompletionInfo) -> Generation {
    return Generation.info(info)
}
Expand Down Expand Up @@ -1365,6 +1384,10 @@ private struct RawTokenLoopHandler: TokenLoopHandler {
return true
}

/// Intentionally empty: this handler emits nothing at the end of generation.
mutating func onGenerationEnd(
emit: (sending TokenGeneration) -> AsyncStream<TokenGeneration>.Continuation.YieldResult
) {}

/// Wraps the completion info in a `TokenGeneration.info` event.
func infoEvent(_ info: GenerateCompletionInfo) -> TokenGeneration {
    return TokenGeneration.info(info)
}
Expand Down
63 changes: 63 additions & 0 deletions Libraries/MLXLMCommon/Tool/Parsers/MistralToolCallParser.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright © 2025 Apple Inc.

import Foundation

/// Parses tool calls written in the Mistral V13 layout:
/// `[TOOL_CALLS]name[ARGS]{"json_args"}`.
///
/// Mistral3/Ministral-3 2512 models and Devstral 2 use this layout. The
/// special tokens `[TOOL_CALLS]` (token ID 9) and `[ARGS]` (ID 32) act as
/// delimiters, and several calls are expressed by repeating `[TOOL_CALLS]`.
///
/// The older V11 layout is accepted as well: it may place a `[CALL_ID]`
/// section between the function name and `[ARGS]` (V13 does not), and that
/// section is discarded.
///
/// Examples:
/// - `[TOOL_CALLS]get_weather[ARGS]{"location": "Tokyo"}`
/// - `[TOOL_CALLS]fn1[ARGS]{...}[TOOL_CALLS]fn2[ARGS]{...}` (multiple calls)
///
/// There is no end tag — a tool call runs until EOS. Because stop tokens are
/// intercepted at the token ID level before detokenization, calls are
/// extracted through `ToolCallProcessor.processEOS()` when generation ends.
public struct MistralToolCallParser: ToolCallParser, Sendable {
    public let startTag: String? = "[TOOL_CALLS]"
    public let endTag: String? = nil

    public init() {}

    /// Parses a single `name[ARGS]{json}` segment, optionally prefixed by the start tag.
    ///
    /// - Parameters:
    ///   - content: Text of one tool call.
    ///   - tools: Tool schemas; not consulted by this parser.
    /// - Returns: The parsed `ToolCall`, or `nil` when `[ARGS]` is missing, the
    ///   function name is empty, or the arguments are not a JSON object.
    public func parse(content: String, tools: [[String: any Sendable]]?) -> ToolCall? {
        let blanks = CharacterSet.whitespacesAndNewlines

        // Remove the start tag only when it leads the text, so the same
        // characters inside an argument value are left alone.
        var remainder = content.trimmingCharacters(in: blanks)
        if let tag = startTag, remainder.hasPrefix(tag) {
            remainder = String(remainder.dropFirst(tag.count)).trimmingCharacters(in: blanks)
        }

        // The first [ARGS] separates the header (name) from the JSON payload.
        guard let delimiter = remainder.range(of: "[ARGS]") else { return nil }
        let header = String(remainder[..<delimiter.lowerBound])
        let payload = String(remainder[delimiter.upperBound...]).trimmingCharacters(in: blanks)

        // Anything from an optional [CALL_ID] onwards is not part of the name.
        let nameEnd = header.range(of: "[CALL_ID]")?.lowerBound ?? header.endIndex
        let functionName = String(header[..<nameEnd]).trimmingCharacters(in: blanks)
        if functionName.isEmpty { return nil }

        // Arguments must decode to a JSON object (tryParseJSON comes from ParserUtilities).
        guard let arguments = tryParseJSON(payload) as? [String: any Sendable] else {
            return nil
        }

        return ToolCall(
            function: ToolCall.Function(name: functionName, arguments: arguments)
        )
    }
}
35 changes: 35 additions & 0 deletions Libraries/MLXLMCommon/Tool/ToolCallFormat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,30 @@ public protocol ToolCallParser: Sendable {
/// - tools: Optional tool schemas for type-aware parsing
/// - Returns: A `ToolCall` if parsing succeeds, `nil` otherwise
func parse(content: String, tools: [[String: any Sendable]]?) -> ToolCall?

/// Parse remaining buffered content at end-of-sequence.
///
/// Called when generation ends to extract any tool calls still in the buffer.
/// The default implementation splits on `startTag` (if present) and parses
/// each segment individually.
///
/// - Parameters:
///   - toolCallBuffer: The text still buffered when generation ended.
///   - tools: Optional tool schemas, passed through to `parse(content:tools:)`.
/// - Returns: Every tool call that could be parsed; empty when none could.
func parseEOS(_ toolCallBuffer: String, tools: [[String: any Sendable]]?) -> [ToolCall]
}

extension ToolCallParser {
    /// Default end-of-sequence parsing.
    ///
    /// Without a `startTag` the whole buffer is parsed as one tool call. With a
    /// `startTag` the buffer is split on it, empty pieces are dropped, and each
    /// remaining piece is parsed on its own; pieces that fail to parse are ignored.
    public func parseEOS(_ toolCallBuffer: String, tools: [[String: any Sendable]]?) -> [ToolCall] {
        guard let startTag else {
            // Untagged format: the buffer holds at most one call.
            return parse(content: toolCallBuffer, tools: tools).map { [$0] } ?? []
        }

        var parsedCalls: [ToolCall] = []
        for segment in toolCallBuffer.components(separatedBy: startTag) where !segment.isEmpty {
            if let call = parse(content: segment, tools: tools) {
                parsedCalls.append(call)
            }
        }
        return parsedCalls
    }
}

// MARK: - ToolCallFormat Enum
Expand Down Expand Up @@ -66,6 +90,10 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable {
/// Example: `<invoke name="f"><parameter name="k">v</parameter></invoke>`
case minimaxM2 = "minimax_m2"

/// Mistral V11+ format with [TOOL_CALLS] and [ARGS] delimiters.
/// Example: `[TOOL_CALLS]get_weather [ARGS]{"location": "Tokyo"}`
case mistral

// MARK: - Factory Methods

/// Create the appropriate parser for this format.
Expand All @@ -87,6 +115,8 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable {
return KimiK2ToolCallParser()
case .minimaxM2:
return MiniMaxM2ToolCallParser()
case .mistral:
return MistralToolCallParser()
}
}

Expand Down Expand Up @@ -115,6 +145,11 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable {
return .gemma
}

// Mistral3 family (mistral3, mistral3_text, etc.)
if type.hasPrefix("mistral3") {
return .mistral
}

return nil
}
}
25 changes: 23 additions & 2 deletions Libraries/MLXLMCommon/Tool/ToolCallProcessor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ public class ToolCallProcessor {

// MARK: - Computed Properties

/// Whether this processor uses inline format (no start tag).
///
/// Only the start tag is consulted: a parser that defines a start tag but no
/// end tag is handled as a tagged format, not an inline one.
private var isInlineFormat: Bool {
    parser.startTag == nil
}

/// The first character of the start tag for quick detection.
Expand All @@ -77,6 +77,27 @@ public class ToolCallProcessor {
return processTaggedChunk(chunk)
}

/// Handles end-of-sequence by parsing whatever is still buffered as tool call(s).
///
/// Invoke this once generation has ended (e.g., on the EOS token). It exists
/// for formats whose end tag never arrives as text — e.g., Mistral, where
/// `</s>` is intercepted at the token ID level.
///
/// Formats whose end tag does appear in the text stream will already have an
/// empty buffer at this point, so the call does nothing for them.
public func processEOS() {
    let isBuffering = state == .potentialToolCall || state == .collectingToolCall
    guard isBuffering else { return }

    // Whatever happens below, the processor returns to the normal state.
    defer { state = .normal }

    if toolCallBuffer.isEmpty { return }

    let recoveredCalls = parser.parseEOS(toolCallBuffer, tools: tools)
    toolCalls.append(contentsOf: recoveredCalls)
    toolCallBuffer = ""
}

// MARK: - Private Methods

/// Process chunk for inline formats (no wrapper tags).
Expand Down
12 changes: 10 additions & 2 deletions Libraries/MLXVLM/Models/Mistral3.swift
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes in this file were the initial changes I made. Without these, the Mistral3 models never even received the tools from the caller.

Original file line number Diff line number Diff line change
Expand Up @@ -1003,7 +1003,11 @@ public struct Mistral3VLMProcessor: UserInputProcessor {

if input.images.isEmpty {
// No image - just apply chat template
let promptTokens = try tokenizer.applyChatTemplate(messages: messages)
let promptTokens = try tokenizer.applyChatTemplate(
messages: messages,
tools: input.tools,
additionalContext: input.additionalContext
)
let tokensArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
let mask = ones(like: tokensArray)
return LMInput(text: .init(tokens: tokensArray, mask: mask), image: nil)
Expand All @@ -1016,7 +1020,11 @@ public struct Mistral3VLMProcessor: UserInputProcessor {
let patchSize = config.imageProcessor.patchSize

// Apply chat template to get tokenized prompt with image placeholder
var promptTokens = try tokenizer.applyChatTemplate(messages: messages)
var promptTokens = try tokenizer.applyChatTemplate(
messages: messages,
tools: input.tools,
additionalContext: input.additionalContext
)

// Decode to find and replace image placeholder token
let decoded = tokenizer.decode(tokens: promptTokens, skipSpecialTokens: false)
Expand Down
5 changes: 5 additions & 0 deletions Libraries/MLXVLM/VLMModelFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,11 @@ public final class VLMModelFactory: ModelFactory {
var mutableConfiguration = configuration
mutableConfiguration.eosTokenIds = eosTokenIds

// Auto-detect tool call format from model type if not explicitly set
if mutableConfiguration.toolCallFormat == nil {
mutableConfiguration.toolCallFormat = ToolCallFormat.infer(from: baseConfig.modelType)
}
Comment on lines +332 to +335
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seemed like an oversight. It should always have been present, as far as I could tell.


// Load tokenizer, processor config, and weights in parallel using async let.
// Note: loadProcessorConfig does synchronous I/O but is marked async to enable
// parallel scheduling. This may briefly block a cooperative thread pool thread,
Expand Down
Loading
Loading