Skip to content

Commit 7d4d9ec

Browse files
authored
Enable ModernBert on metal (#562)
1 parent f99ce07 commit 7d4d9ec

File tree

1 file changed

+6
-13
lines changed

1 file changed

+6
-13
lines changed

backends/candle/src/lib.rs

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -277,19 +277,12 @@ impl CandleBackend {
277277
tracing::info!("Starting MPNet model on {:?}", device);
278278
Ok(Box::new(MPNetModel::load(vb, &config, model_type).s()?))
279279
}
280-
(Config::ModernBert(config), Device::Cpu | Device::Metal(_)) => match device {
281-
Device::Metal(_) => {
282-
return Err(BackendError::Start(
283-
"ModernBert is not currently supported on MPS device".to_string(),
284-
));
285-
}
286-
_ => {
287-
tracing::info!("Starting ModernBert model on {:?}", device);
288-
Ok(Box::new(
289-
ModernBertModel::load(vb, &config, model_type).s()?,
290-
))
291-
}
292-
},
280+
(Config::ModernBert(config), Device::Cpu | Device::Metal(_)) => {
281+
tracing::info!("Starting ModernBert model on {:?}", device);
282+
Ok(Box::new(
283+
ModernBertModel::load(vb, &config, model_type).s()?,
284+
))
285+
}
293286
#[cfg(feature = "cuda")]
294287
(Config::Bert(config), Device::Cuda(_)) => {
295288
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))

0 commit comments

Comments
 (0)