diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index cdce8188303..d92adc4042f 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -6,7 +6,7 @@ use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
 use crate::Tool;
 use crate::{
     ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
-    Message, PrefillToken, Token,
+    Message, PrefillToken, TextMessage, Token,
 };
 use async_stream::stream;
 use async_trait::async_trait;
@@ -68,17 +68,29 @@ impl Infer {
         tokenizer_config: HubTokenizerConfig,
         processor_config: HubProcessorConfig,
     ) -> Self {
-        let chat_template = tokenizer_config
-            .chat_template
-            .or(processor_config.chat_template)
-            .and_then(|t| match t {
-                ChatTemplateVersions::Single(template) => Some(template),
-                ChatTemplateVersions::Multiple(templates) => templates
-                    .into_iter()
-                    .find(|t| t.name == "default")
-                    .map(|t| t.template),
-            })
-            .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
+        let chat_template = if matches!(
+            processor_config.processor_class.as_deref(),
+            Some("Llama4Processor")
+                | Some("LlavaNextProcessor")
+                | Some("Idefics2Processor")
+                | Some("Idefics3Processor")
+        ) {
+            None // Do not use chat_template
+        } else {
+            tokenizer_config
+                .chat_template
+                .or(processor_config.chat_template)
+                .and_then(|t| match t {
+                    ChatTemplateVersions::Single(template) => Some(template),
+                    ChatTemplateVersions::Multiple(templates) => templates
+                        .into_iter()
+                        .find(|t| t.name == "default")
+                        .map(|t| t.template),
+                })
+                .map(|t| {
+                    ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
+                })
+        };
 
         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
@@ -229,6 +241,15 @@ impl Infer {
         messages: Vec<Message>,
         tools_and_prompt: Option<(Vec<Tool>, String)>,
     ) -> Result<String, InferError> {
+        if self.chat_template.is_none() {
+            let textmessages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
+            let message_str = textmessages
+                .iter()
+                .map(|msg| msg.content.clone()) // Extract content from each `TextMessage`
+                .collect::<Vec<String>>() // Collect all content into a vector
+                .join("\n"); // Join all content into a single string separated by newlines
+            return Ok(message_str);
+        }
         self.chat_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
diff --git a/router/src/lib.rs b/router/src/lib.rs
index e5622fc22e8..0e51bb25aa4 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -210,7 +210,7 @@ pub struct Llama4Processor {
 #[derive(Debug, Clone, Deserialize, Default)]
 pub struct HubProcessorConfig {
     pub chat_template: Option<ChatTemplateVersions>,
-    pub image_seq_len: usize,
+    pub image_seq_len: Option<usize>,
     pub processor_class: Option<String>,
 }
 
@@ -1008,7 +1008,7 @@ impl ChatRequest {
         Ok((
             GenerateRequest {
                 inputs: inputs.to_string(),
-                add_special_tokens: false,
+                add_special_tokens: infer.chat_template.is_none(),
                 parameters: GenerateParameters {
                     best_of: None,
                     temperature,
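
Note: the template-free fallback added to `apply_chat_template` above reduces to a plain newline join over the message contents. Below is a minimal, self-contained sketch of that behavior, assuming a simplified `TextMessage` with only a `content` field (the real struct in router/src/lib.rs carries a role as well); it is an illustration of the fallback, not the actual router code.

// Sketch of the chat_template-is-None fallback, under the assumptions above.
struct TextMessage {
    content: String,
}

// Mirrors the early-return path: with no chat template, message contents
// are joined with newlines and returned as the raw prompt string.
fn render_without_template(messages: Vec<TextMessage>) -> String {
    messages
        .iter()
        .map(|msg| msg.content.clone())
        .collect::<Vec<String>>()
        .join("\n")
}

fn main() {
    let messages = vec![
        TextMessage { content: "You are a helpful assistant.".to_string() },
        TextMessage { content: "Describe this image.".to_string() },
    ];
    assert_eq!(
        render_without_template(messages),
        "You are a helpful assistant.\nDescribe this image."
    );
}

This also explains the `add_special_tokens: infer.chat_template.is_none()` change in router/src/lib.rs: the joined string carries no BOS/EOS markers of its own, so the tokenizer is asked to add special tokens on this path, whereas a rendered chat template already embeds them.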