@@ -13,13 +13,13 @@ use crate::compute_cap::{
 use crate::models::{
     BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel,
     JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, ModernBertConfig,
-    ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config,
+    ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config, Qwen3Config,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
     FlashBertModel, FlashDistilBertModel, FlashGTEModel, FlashJinaBertModel,
     FlashJinaCodeBertModel, FlashMistralModel, FlashModernBertModel, FlashNomicBertModel,
-    FlashQwen2Model,
+    FlashQwen2Model, FlashQwen3Model,
 };
 use anyhow::Context;
 use candle::{DType, Device};
@@ -446,6 +446,18 @@ impl CandleBackend {
                     FlashQwen2Model::load(vb, &config, model_type).s()?,
                 ))
             }
+            #[cfg(feature = "cuda")]
+            (Config::Qwen3(config), Device::Cuda(_)) => {
+                if dtype != DType::F16
+                    || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
+                {
+                    return Err(BackendError::Start("Qwen3 is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
+                }
+                tracing::info!("Starting FlashQwen3 model on {:?}", device);
+                Ok(Box::new(
+                    FlashQwen3Model::load(vb, &config, model_type).s()?,
+                ))
+            }
         };

         Ok(Self {
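The new `Config::Qwen3` arm mirrors the existing `FlashQwen2` arm: the model is only constructed on a CUDA device when the dtype is F16 and one of the flash-attention features was compiled in; otherwise a `BackendError::Start` is returned. Below is a minimal, self-contained sketch of that guard pattern. The `DType` and `BackendError` types here are local stand-ins for illustration only, not the real candle or backend types from the crate.

```rust
// Sketch of the dtype / flash-attention guard used by the new Qwen3 arm.
// The types below are illustrative stand-ins, not the crate's own API.
#[derive(Debug)]
enum BackendError {
    Start(String),
}

#[derive(Debug, Clone, Copy, PartialEq)]
enum DType {
    F16,
    F32,
}

fn check_qwen3_support(dtype: DType) -> Result<(), BackendError> {
    // Same condition as in the diff: reject anything that is not fp16 with a
    // flash-attention feature enabled, since the flash model has no fallback path.
    if dtype != DType::F16 || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1")) {
        return Err(BackendError::Start(
            "Qwen3 is only supported on Cuda devices in fp16 with flash attention v2 enabled"
                .to_string(),
        ));
    }
    Ok(())
}

fn main() {
    // Without either flash-attn feature enabled at compile time, both calls return Err.
    println!("{:?}", check_qwen3_support(DType::F16));
    println!("{:?}", check_qwen3_support(DType::F32));
}
```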