@@ -193,7 +193,6 @@ impl GTEMLP {
193193 let up_gate_proj_weight = vb
194194 . pp ( "up_gate_proj" )
195195 . get ( ( intermediate_size * 2 , config. hidden_size ) , "weight" ) ?;
196-
197196 let up_gate_proj = Linear :: new ( up_gate_proj_weight, None , None ) ;
198197
199198 let down_proj_weight = vb
@@ -216,16 +215,12 @@ impl GTEMLP {
216215
217216 let up_gate_states = self . up_gate_proj . forward ( hidden_states) ?;
218217 let up_states = up_gate_states. narrow ( D :: Minus1 , 0 , self . intermediate_size ) ?;
219- let gate_states =
220- up_gate_states. narrow ( D :: Minus1 , self . intermediate_size , self . intermediate_size ) ?;
221218
222- let gate_states = match self . act {
223- HiddenAct :: Gelu => gate_states. gelu ( ) ,
224- HiddenAct :: Relu => gate_states. relu ( ) ,
225- HiddenAct :: Swiglu => gate_states. silu ( ) ,
226- } ?;
219+ let gate =
220+ up_gate_states. narrow ( D :: Minus1 , self . intermediate_size , self . intermediate_size ) ?;
221+ let gate = self . act . forward ( & gate) ?;
227222
228- self . down_proj . forward ( & ( gate_states * up_states) ?)
223+ self . down_proj . forward ( & ( gate * up_states) ?)
229224 }
230225}
231226
@@ -288,22 +283,25 @@ pub struct GTEClassificationHead {
288283}
289284
290285impl GTEClassificationHead {
291- #[ allow( dead_code) ]
286+ fn inner_load ( vb : VarBuilder , config : & GTEConfig ) -> Option < Linear > {
287+ let pooler_weight = vb
288+ . pp ( "pooler.dense" )
289+ . get ( ( config. hidden_size , config. hidden_size ) , "weight" )
290+ . ok ( ) ?;
291+ let pooler_bias = vb. pp ( "pooler.dense" ) . get ( config. hidden_size , "bias" ) . ok ( ) ?;
292+ let pooler = Linear :: new ( pooler_weight, Some ( pooler_bias) , None ) ;
293+
294+ Some ( pooler)
295+ }
296+
292297 pub ( crate ) fn load ( vb : VarBuilder , config : & GTEConfig ) -> Result < Self > {
293298 let n_classes = match & config. id2label {
294299 None => candle:: bail!( "`id2label` must be set for classifier models" ) ,
295300 Some ( id2label) => id2label. len ( ) ,
296301 } ;
297302
298- let pooler = if let Ok ( pooler_weight) = vb
299- . pp ( "pooler.dense" )
300- . get ( ( config. hidden_size , config. hidden_size ) , "weight" )
301- {
302- let pooler_bias = vb. pp ( "pooler.dense" ) . get ( config. hidden_size , "bias" ) ?;
303- Some ( Linear :: new ( pooler_weight, Some ( pooler_bias) , None ) )
304- } else {
305- None
306- } ;
303+ let pooler =
304+ Self :: inner_load ( vb. pp ( "new" ) , config) . or_else ( || Self :: inner_load ( vb. clone ( ) , config) ) ;
307305
308306 let classifier_weight = vb
309307 . pp ( "classifier" )
@@ -322,13 +320,15 @@ impl GTEClassificationHead {
322320 let _enter = self . span . enter ( ) ;
323321
324322 let mut hidden_states = hidden_states. unsqueeze ( 1 ) ?;
323+
325324 if let Some ( pooler) = self . pooler . as_ref ( ) {
326325 hidden_states = pooler. forward ( & hidden_states) ?;
327326 hidden_states = hidden_states. tanh ( ) ?;
328327 }
329328
330329 let hidden_states = self . classifier . forward ( & hidden_states) ?;
331330 let hidden_states = hidden_states. squeeze ( 1 ) ?;
331+
332332 Ok ( hidden_states)
333333 }
334334}
0 commit comments