From 07279c8bf20e250bbef9bfca58f844ec4b38a1a4 Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date: Wed, 11 Jun 2025 07:05:55 +0000
Subject: [PATCH 1/6] Patch `HiddenAct::Swiglu` to use SiLU activation

---
 backends/candle/src/layers/linear.rs | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/backends/candle/src/layers/linear.rs b/backends/candle/src/layers/linear.rs
index 4c875a89..3be1173f 100644
--- a/backends/candle/src/layers/linear.rs
+++ b/backends/candle/src/layers/linear.rs
@@ -16,7 +16,9 @@ impl HiddenAct {
         match self {
             Self::Gelu => x.gelu(),
             Self::Relu => x.relu(),
-            Self::Swiglu => candle_nn::ops::swiglu(x),
+            // NOTE: use SiLU instead of candle's SwiGLU, as SwiGLU is SiLU + a down projection
+            // to half size, since we split on the intermediate dimension
+            Self::Swiglu => x.silu(),
         }
     }
 }
@@ -80,7 +82,9 @@ impl Linear {
             match act {
                 HiddenAct::Gelu => x.gelu(),
                 HiddenAct::Relu => x.relu(),
-                HiddenAct::Swiglu => candle_nn::ops::swiglu(&x),
+                // NOTE: use SiLU instead of candle's SwiGLU, as SwiGLU is SiLU + a down projection
+                // to half size, since we split on the intermediate dimension
+                HiddenAct::Swiglu => x.silu(),
             }
         } else {
             Ok(x)
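Context for the change above: the gated-MLP layers touched in the next commit project to 2 * intermediate_size and split the gate/up halves themselves, so the activation only needs plain SiLU on the gate half; candle's fused op chunks its input in half on its own, and combining it with the manual split would halve the width twice. A minimal sketch of the two equivalent shapes, assuming the candle_core crate (aliased as `candle` in some setups) and illustrative function names:

use candle_core::{D, Result, Tensor};

// Manual split, as the models in this series do it: `hidden` holds
// [gate | up] along the last dimension, 2 * intermediate_size wide.
fn gated_mlp_split(hidden: &Tensor, intermediate_size: usize) -> Result<Tensor> {
    let gate = hidden.narrow(D::Minus1, 0, intermediate_size)?;
    let up = hidden.narrow(D::Minus1, intermediate_size, intermediate_size)?;
    // Plain SiLU on the gate half, then the elementwise product: SwiGLU.
    gate.silu()? * up
}

// Fused form on the *unsplit* tensor: candle_nn::ops::swiglu chunks the
// last dimension in half itself, so it must not be fed an already-split half.
fn gated_mlp_fused(hidden: &Tensor) -> Result<Tensor> {
    candle_nn::ops::swiglu(hidden)
}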
From 9f85bf3b6e8d4ec1aa8c1ae56af3d603c907501b Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date: Wed, 11 Jun 2025 07:06:14 +0000
Subject: [PATCH 2/6] Use `HiddenAct.forward` instead of `match`

---
 backends/candle/src/models/flash_jina.rs      | 6 +-----
 backends/candle/src/models/flash_jina_code.rs | 6 +-----
 backends/candle/src/models/flash_mistral.rs   | 6 +-----
 backends/candle/src/models/flash_qwen2.rs     | 6 +-----
 backends/candle/src/models/jina.rs            | 6 +-----
 backends/candle/src/models/jina_code.rs       | 6 +-----
 backends/candle/src/models/modernbert.rs      | 6 +-----
 7 files changed, 7 insertions(+), 35 deletions(-)

diff --git a/backends/candle/src/models/flash_jina.rs b/backends/candle/src/models/flash_jina.rs
index 94878c5a..05341b84 100644
--- a/backends/candle/src/models/flash_jina.rs
+++ b/backends/candle/src/models/flash_jina.rs
@@ -176,11 +176,7 @@ impl JinaBertLayer {
         let hidden_states = self.gated.forward(&hidden_states)?;
 
         let gated = hidden_states.narrow(1, 0, self.intermediate_size)?;
-        let gated = match self.act {
-            HiddenAct::Gelu => gated.gelu(),
-            HiddenAct::Relu => gated.relu(),
-            HiddenAct::Swiglu => gated.silu(),
-        }?;
+        let gated = self.act.forward(&gated)?;
 
         let non_gated = hidden_states.narrow(1, self.intermediate_size, self.intermediate_size)?;
         let hidden_states = (gated * non_gated)?;
diff --git a/backends/candle/src/models/flash_jina_code.rs b/backends/candle/src/models/flash_jina_code.rs
index 745786dc..e00f758d 100644
--- a/backends/candle/src/models/flash_jina_code.rs
+++ b/backends/candle/src/models/flash_jina_code.rs
@@ -230,11 +230,7 @@ impl JinaCodeBertLayer {
         let hidden_states = self.up_gated_layer.forward(&hidden_states)?;
         let non_gated = hidden_states.narrow(1, 0, self.intermediate_size)?;
         let gated = hidden_states.narrow(1, self.intermediate_size, self.intermediate_size)?;
-        let gated = match self.act {
-            HiddenAct::Gelu => gated.gelu(),
-            HiddenAct::Relu => gated.relu(),
-            HiddenAct::Swiglu => gated.silu(),
-        }?;
+        let gated = self.act.forward(&gated)?;
 
         let hidden_states = (non_gated * gated)?;
         let hidden_states = self.down_layer.forward(&hidden_states)?;
diff --git a/backends/candle/src/models/flash_mistral.rs b/backends/candle/src/models/flash_mistral.rs
index 19955259..c8488f36 100644
--- a/backends/candle/src/models/flash_mistral.rs
+++ b/backends/candle/src/models/flash_mistral.rs
@@ -159,11 +159,7 @@ impl MistralMLP {
         let gate_states = gate_up_states.narrow(1, 0, self.intermediate_size)?;
         let up_states = gate_up_states.narrow(1, self.intermediate_size, self.intermediate_size)?;
 
-        let gate_states = match self.act {
-            HiddenAct::Gelu => gate_states.gelu(),
-            HiddenAct::Relu => gate_states.relu(),
-            HiddenAct::Swiglu => gate_states.silu(),
-        }?;
+        let gate_states = self.act.forward(&gate_states)?;
         let r = self.down_proj.forward(&(gate_states * up_states)?);
         r
     }
diff --git a/backends/candle/src/models/flash_qwen2.rs b/backends/candle/src/models/flash_qwen2.rs
index 904767ea..c9116311 100644
--- a/backends/candle/src/models/flash_qwen2.rs
+++ b/backends/candle/src/models/flash_qwen2.rs
@@ -167,11 +167,7 @@ impl Qwen2MLP {
         let gate_states = gate_up_states.narrow(1, 0, self.intermediate_size)?;
         let up_states = gate_up_states.narrow(1, self.intermediate_size, self.intermediate_size)?;
 
-        let gate_states = match self.act {
-            HiddenAct::Gelu => gate_states.gelu(),
-            HiddenAct::Relu => gate_states.relu(),
-            HiddenAct::Swiglu => gate_states.silu(),
-        }?;
+        let gate_states = self.act.forward(&gate_states)?;
         let r = self.down_proj.forward(&(gate_states * up_states)?);
         r
     }
diff --git a/backends/candle/src/models/jina.rs b/backends/candle/src/models/jina.rs
index b694befb..d54d49c6 100644
--- a/backends/candle/src/models/jina.rs
+++ b/backends/candle/src/models/jina.rs
@@ -294,11 +294,7 @@ impl JinaBertLayer {
         let hidden_states = self.gated.forward(&hidden_states)?;
 
         let gated = hidden_states.i((.., .., 0..self.intermediate_size))?;
-        let gated = match self.act {
-            HiddenAct::Gelu => gated.gelu(),
-            HiddenAct::Relu => gated.relu(),
-            HiddenAct::Swiglu => gated.silu(),
-        }?;
+        let gated = self.act.forward(&gated)?;
 
         let non_gated = hidden_states.i((.., .., self.intermediate_size..))?;
         let hidden_states = (gated * non_gated)?;
diff --git a/backends/candle/src/models/jina_code.rs b/backends/candle/src/models/jina_code.rs
index 8dadea00..8cb6f65a 100644
--- a/backends/candle/src/models/jina_code.rs
+++ b/backends/candle/src/models/jina_code.rs
@@ -284,11 +284,7 @@ impl JinaCodeBertLayer {
         let hidden_states = self.up_gated_layer.forward(&hidden_states)?;
         let non_gated = hidden_states.i((.., .., 0..self.intermediate_size))?;
         let gated = hidden_states.i((.., .., self.intermediate_size..))?;
-        let gated = match self.act {
-            HiddenAct::Gelu => gated.gelu(),
-            HiddenAct::Relu => gated.relu(),
-            HiddenAct::Swiglu => gated.silu(),
-        }?;
+        let gated = self.act.forward(&gated)?;
 
         let hidden_states = (non_gated * gated)?;
         let hidden_states = self.down_layer.forward(&hidden_states)?;
diff --git a/backends/candle/src/models/modernbert.rs b/backends/candle/src/models/modernbert.rs
index 046a1547..bc2fc996 100644
--- a/backends/candle/src/models/modernbert.rs
+++ b/backends/candle/src/models/modernbert.rs
@@ -120,11 +120,7 @@ impl ModernBertMLP {
             hidden_states.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?;
 
         let input = if let Some(activation) = &self.activation {
-            match activation {
-                HiddenAct::Gelu => input.gelu(),
-                HiddenAct::Relu => input.relu(),
-                HiddenAct::Swiglu => input.silu(),
-            }
+            activation.forward(&input)?
         } else {
             Ok(input)
         };
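The refactor above leaves one type error that the next commit fixes: `HiddenAct::forward` already returns `Result<Tensor>`, so the `?` in the ModernBertMLP branch unwraps it to `Tensor` while the `else` branch still yields `Ok(input)`, and the two arms of the `if let` no longer agree. A minimal sketch of the corrected control flow, with the enum reduced to the variants it has at this point in the series and an illustrative `apply` helper that is not in the codebase:

use candle_core::{Result, Tensor};

// Stand-in for the enum in backends/candle/src/layers/linear.rs
// as of PATCH 1 (Swiglu maps to plain SiLU here).
enum HiddenAct {
    Gelu,
    Relu,
    Swiglu,
}

impl HiddenAct {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        match self {
            Self::Gelu => x.gelu(),
            Self::Relu => x.relu(),
            Self::Swiglu => x.silu(),
        }
    }
}

fn apply(activation: &Option<HiddenAct>, input: Tensor) -> Result<Tensor> {
    // Both branches evaluate to Result<Tensor>: no `?` on the call,
    // which is exactly what the follow-up commit below restores.
    if let Some(act) = activation {
        act.forward(&input)
    } else {
        Ok(input)
    }
}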
From 0aef8dbe7f33e9bddc50c62e948525535cc00219 Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date: Wed, 11 Jun 2025 07:12:36 +0000
Subject: [PATCH 3/6] Fix `activation.forward` in ModernBertMLP

---
 backends/candle/src/models/modernbert.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backends/candle/src/models/modernbert.rs b/backends/candle/src/models/modernbert.rs
index bc2fc996..b94325ca 100644
--- a/backends/candle/src/models/modernbert.rs
+++ b/backends/candle/src/models/modernbert.rs
@@ -120,7 +120,7 @@ impl ModernBertMLP {
             hidden_states.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?;
 
         let input = if let Some(activation) = &self.activation {
-            activation.forward(&input)?
+            activation.forward(&input)
         } else {
             Ok(input)
         };

From 8f981a159536eb96574c783364563fc8cbfa9fd3 Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date: Wed, 11 Jun 2025 09:58:21 +0000
Subject: [PATCH 4/6] Add `Silu` variant in `HiddenAct` and remove `serde` alias

---
 backends/candle/src/layers/linear.rs | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/backends/candle/src/layers/linear.rs b/backends/candle/src/layers/linear.rs
index 3be1173f..824ac161 100644
--- a/backends/candle/src/layers/linear.rs
+++ b/backends/candle/src/layers/linear.rs
@@ -7,7 +7,7 @@ use serde::Deserialize;
 pub enum HiddenAct {
     Gelu,
     Relu,
-    #[serde(alias = "silu")]
+    Silu,
     Swiglu,
 }
 
@@ -16,9 +16,8 @@ impl HiddenAct {
         match self {
             Self::Gelu => x.gelu(),
             Self::Relu => x.relu(),
-            // NOTE: use SiLU instead of candle's SwiGLU, as SwiGLU is SiLU + a down projection
-            // to half size, since we split on the intermediate dimension
-            Self::Swiglu => x.silu(),
+            Self::Silu => x.silu(),
+            Self::Swiglu => candle_nn::ops::swiglu(&x),
         }
     }
 }
@@ -82,9 +81,8 @@ impl Linear {
             match act {
                 HiddenAct::Gelu => x.gelu(),
                 HiddenAct::Relu => x.relu(),
-                // NOTE: use SiLU instead of candle's SwiGLU, as SwiGLU is SiLU + a down projection
-                // to half size, since we split on the intermediate dimension
-                HiddenAct::Swiglu => x.silu(),
+                HiddenAct::Silu => x.silu(),
+                HiddenAct::Swiglu => candle_nn::ops::swiglu(&x),
             }
         } else {
             Ok(x)

From 272b2f762d0ad2b44f14a0fa97b0cd653852312c Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date: Wed, 11 Jun 2025 10:08:37 +0000
Subject: [PATCH 5/6] Add kozistr as co-author

Co-authored-by: Hyeongchan Kim

From 6c134d7db2839e95f533000512096ad31c34d60f Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date: Wed, 11 Jun 2025 10:11:56 +0000
Subject: [PATCH 6/6] Remove reference from `x` in `HiddenAct` match

---
 backends/candle/src/layers/linear.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backends/candle/src/layers/linear.rs b/backends/candle/src/layers/linear.rs
index 824ac161..3432634a 100644
--- a/backends/candle/src/layers/linear.rs
+++ b/backends/candle/src/layers/linear.rs
@@ -17,7 +17,7 @@ impl HiddenAct {
             Self::Gelu => x.gelu(),
             Self::Relu => x.relu(),
             Self::Silu => x.silu(),
-            Self::Swiglu => candle_nn::ops::swiglu(&x),
+            Self::Swiglu => candle_nn::ops::swiglu(x),
         }
     }
 }
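Net effect of the series on config parsing: `"silu"` in a model config now selects a dedicated `Silu` variant instead of aliasing to `Swiglu`, and `Swiglu` goes back to candle's fused op. A sketch of the round-trip; the `rename_all = "lowercase"` attribute and the `Config` wrapper are assumptions for illustration, not shown in the diffs:

use serde::Deserialize;

// Final shape of the enum after this series; the serde attribute is an
// assumption (the diffs only show the variants).
#[derive(Debug, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
enum HiddenAct {
    Gelu,
    Relu,
    Silu,
    Swiglu,
}

// Hypothetical config wrapper, for illustration only.
#[derive(Debug, Deserialize)]
struct Config {
    hidden_act: HiddenAct,
}

fn main() {
    // "silu" now selects the dedicated Silu variant; before PATCH 4 the
    // removed `#[serde(alias = "silu")]` routed it to Swiglu instead.
    let cfg: Config = serde_json::from_str(r#"{"hidden_act": "silu"}"#).unwrap();
    assert_eq!(cfg.hidden_act, HiddenAct::Silu);
}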