Skip to content

Commit 8f981a1

Browse files
committed
Add Silu variant in HiddenAct and remove serde alias
1 parent 0aef8db commit 8f981a1

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

backends/candle/src/layers/linear.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use serde::Deserialize;
77
pub enum HiddenAct {
88
Gelu,
99
Relu,
10-
#[serde(alias = "silu")]
10+
Silu,
1111
Swiglu,
1212
}
1313

@@ -16,9 +16,8 @@ impl HiddenAct {
1616
match self {
1717
Self::Gelu => x.gelu(),
1818
Self::Relu => x.relu(),
19-
// NOTE: use SiLU instead candle's SwiGLU, as SwiGLU is SiLU + down projection
20-
// to half size since we split on intermediate dimension
21-
Self::Swiglu => x.silu(),
19+
Self::Silu => x.silu(),
20+
Self::Swiglu => candle_nn::ops::swiglu(&x),
2221
}
2322
}
2423
}
@@ -82,9 +81,8 @@ impl Linear {
8281
match act {
8382
HiddenAct::Gelu => x.gelu(),
8483
HiddenAct::Relu => x.relu(),
85-
// NOTE: use SiLU instead candle's SwiGLU, as SwiGLU is SiLU + down projection
86-
// to half size since we split on intermediate dimension
87-
HiddenAct::Swiglu => x.silu(),
84+
HiddenAct::Silu => x.silu(),
85+
HiddenAct::Swiglu => candle_nn::ops::swiglu(&x),
8886
}
8987
} else {
9088
Ok(x)

0 commit comments

Comments
 (0)