diff --git a/backends/candle/src/layers/linear.rs b/backends/candle/src/layers/linear.rs index 4c875a89..3432634a 100644 --- a/backends/candle/src/layers/linear.rs +++ b/backends/candle/src/layers/linear.rs @@ -7,7 +7,7 @@ use serde::Deserialize; pub enum HiddenAct { Gelu, Relu, - #[serde(alias = "silu")] + Silu, Swiglu, } @@ -16,6 +16,7 @@ impl HiddenAct { match self { Self::Gelu => x.gelu(), Self::Relu => x.relu(), + Self::Silu => x.silu(), Self::Swiglu => candle_nn::ops::swiglu(x), } } @@ -80,6 +81,7 @@ impl Linear { match act { HiddenAct::Gelu => x.gelu(), HiddenAct::Relu => x.relu(), + HiddenAct::Silu => x.silu(), HiddenAct::Swiglu => candle_nn::ops::swiglu(&x), } } else { diff --git a/backends/candle/src/models/flash_jina.rs b/backends/candle/src/models/flash_jina.rs index 94878c5a..05341b84 100644 --- a/backends/candle/src/models/flash_jina.rs +++ b/backends/candle/src/models/flash_jina.rs @@ -176,11 +176,7 @@ impl JinaBertLayer { let hidden_states = self.gated.forward(&hidden_states)?; let gated = hidden_states.narrow(1, 0, self.intermediate_size)?; - let gated = match self.act { - HiddenAct::Gelu => gated.gelu(), - HiddenAct::Relu => gated.relu(), - HiddenAct::Swiglu => gated.silu(), - }?; + let gated = self.act.forward(&gated)?; let non_gated = hidden_states.narrow(1, self.intermediate_size, self.intermediate_size)?; let hidden_states = (gated * non_gated)?; diff --git a/backends/candle/src/models/flash_jina_code.rs b/backends/candle/src/models/flash_jina_code.rs index 745786dc..e00f758d 100644 --- a/backends/candle/src/models/flash_jina_code.rs +++ b/backends/candle/src/models/flash_jina_code.rs @@ -230,11 +230,7 @@ impl JinaCodeBertLayer { let hidden_states = self.up_gated_layer.forward(&hidden_states)?; let non_gated = hidden_states.narrow(1, 0, self.intermediate_size)?; let gated = hidden_states.narrow(1, self.intermediate_size, self.intermediate_size)?; - let gated = match self.act { - HiddenAct::Gelu => gated.gelu(), - HiddenAct::Relu => gated.relu(), - HiddenAct::Swiglu => gated.silu(), - }?; + let gated = self.act.forward(&gated)?; let hidden_states = (non_gated * gated)?; let hidden_states = self.down_layer.forward(&hidden_states)?; diff --git a/backends/candle/src/models/flash_mistral.rs b/backends/candle/src/models/flash_mistral.rs index 19955259..c8488f36 100644 --- a/backends/candle/src/models/flash_mistral.rs +++ b/backends/candle/src/models/flash_mistral.rs @@ -159,11 +159,7 @@ impl MistralMLP { let gate_states = gate_up_states.narrow(1, 0, self.intermediate_size)?; let up_states = gate_up_states.narrow(1, self.intermediate_size, self.intermediate_size)?; - let gate_states = match self.act { - HiddenAct::Gelu => gate_states.gelu(), - HiddenAct::Relu => gate_states.relu(), - HiddenAct::Swiglu => gate_states.silu(), - }?; + let gate_states = self.act.forward(&gate_states)?; let r = self.down_proj.forward(&(gate_states * up_states)?); r } diff --git a/backends/candle/src/models/flash_qwen2.rs b/backends/candle/src/models/flash_qwen2.rs index 904767ea..c9116311 100644 --- a/backends/candle/src/models/flash_qwen2.rs +++ b/backends/candle/src/models/flash_qwen2.rs @@ -167,11 +167,7 @@ impl Qwen2MLP { let gate_states = gate_up_states.narrow(1, 0, self.intermediate_size)?; let up_states = gate_up_states.narrow(1, self.intermediate_size, self.intermediate_size)?; - let gate_states = match self.act { - HiddenAct::Gelu => gate_states.gelu(), - HiddenAct::Relu => gate_states.relu(), - HiddenAct::Swiglu => gate_states.silu(), - }?; + let gate_states = self.act.forward(&gate_states)?; let r = self.down_proj.forward(&(gate_states * up_states)?); r } diff --git a/backends/candle/src/models/jina.rs b/backends/candle/src/models/jina.rs index b694befb..d54d49c6 100644 --- a/backends/candle/src/models/jina.rs +++ b/backends/candle/src/models/jina.rs @@ -294,11 +294,7 @@ impl JinaBertLayer { let hidden_states = self.gated.forward(&hidden_states)?; let gated = hidden_states.i((.., .., 0..self.intermediate_size))?; - let gated = match self.act { - HiddenAct::Gelu => gated.gelu(), - HiddenAct::Relu => gated.relu(), - HiddenAct::Swiglu => gated.silu(), - }?; + let gated = self.act.forward(&gated)?; let non_gated = hidden_states.i((.., .., self.intermediate_size..))?; let hidden_states = (gated * non_gated)?; diff --git a/backends/candle/src/models/jina_code.rs b/backends/candle/src/models/jina_code.rs index 8dadea00..8cb6f65a 100644 --- a/backends/candle/src/models/jina_code.rs +++ b/backends/candle/src/models/jina_code.rs @@ -284,11 +284,7 @@ impl JinaCodeBertLayer { let hidden_states = self.up_gated_layer.forward(&hidden_states)?; let non_gated = hidden_states.i((.., .., 0..self.intermediate_size))?; let gated = hidden_states.i((.., .., self.intermediate_size..))?; - let gated = match self.act { - HiddenAct::Gelu => gated.gelu(), - HiddenAct::Relu => gated.relu(), - HiddenAct::Swiglu => gated.silu(), - }?; + let gated = self.act.forward(&gated)?; let hidden_states = (non_gated * gated)?; let hidden_states = self.down_layer.forward(&hidden_states)?; diff --git a/backends/candle/src/models/modernbert.rs b/backends/candle/src/models/modernbert.rs index 046a1547..b94325ca 100644 --- a/backends/candle/src/models/modernbert.rs +++ b/backends/candle/src/models/modernbert.rs @@ -120,11 +120,7 @@ impl ModernBertMLP { hidden_states.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?; let input = if let Some(activation) = &self.activation { - match activation { - HiddenAct::Gelu => input.gelu(), - HiddenAct::Relu => input.relu(), - HiddenAct::Swiglu => input.silu(), - } + activation.forward(&input) } else { Ok(input) };