Skip to content

Commit 656e3ff

Browse files
committed
Add FlashQwen3Model (WIP)
1 parent d51a8b9 commit 656e3ff

File tree

4 files changed

+508
-2
lines changed

4 files changed

+508
-2
lines changed

backends/candle/src/lib.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ use crate::compute_cap::{
1313
use crate::models::{
1414
BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel,
1515
JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, ModernBertConfig,
16-
ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config,
16+
ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config, Qwen3Config,
1717
};
1818
#[cfg(feature = "cuda")]
1919
use crate::models::{
2020
FlashBertModel, FlashDistilBertModel, FlashGTEModel, FlashJinaBertModel,
2121
FlashJinaCodeBertModel, FlashMistralModel, FlashModernBertModel, FlashNomicBertModel,
22-
FlashQwen2Model,
22+
FlashQwen2Model, FlashQwen3Model,
2323
};
2424
use anyhow::Context;
2525
use candle::{DType, Device};
@@ -446,6 +446,18 @@ impl CandleBackend {
446446
FlashQwen2Model::load(vb, &config, model_type).s()?,
447447
))
448448
}
449+
#[cfg(feature = "cuda")]
450+
(Config::Qwen3(config), Device::Cuda(_)) => {
451+
if dtype != DType::F16
452+
|| !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
453+
{
454+
return Err(BackendError::Start("Qwen3 is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
455+
}
456+
tracing::info!("Starting FlashQwen3 model on {:?}", device);
457+
Ok(Box::new(
458+
FlashQwen3Model::load(vb, &config, model_type).s()?,
459+
))
460+
}
449461
};
450462

451463
Ok(Self {

0 commit comments

Comments
 (0)