@@ -193,7 +193,6 @@ impl GTEMLP {
193
193
let up_gate_proj_weight = vb
194
194
. pp ( "up_gate_proj" )
195
195
. get ( ( intermediate_size * 2 , config. hidden_size ) , "weight" ) ?;
196
-
197
196
let up_gate_proj = Linear :: new ( up_gate_proj_weight, None , None ) ;
198
197
199
198
let down_proj_weight = vb
@@ -216,16 +215,12 @@ impl GTEMLP {
216
215
217
216
let up_gate_states = self . up_gate_proj . forward ( hidden_states) ?;
218
217
let up_states = up_gate_states. narrow ( D :: Minus1 , 0 , self . intermediate_size ) ?;
219
- let gate_states =
220
- up_gate_states. narrow ( D :: Minus1 , self . intermediate_size , self . intermediate_size ) ?;
221
218
222
- let gate_states = match self . act {
223
- HiddenAct :: Gelu => gate_states. gelu ( ) ,
224
- HiddenAct :: Relu => gate_states. relu ( ) ,
225
- HiddenAct :: Swiglu => gate_states. silu ( ) ,
226
- } ?;
219
+ let gate =
220
+ up_gate_states. narrow ( D :: Minus1 , self . intermediate_size , self . intermediate_size ) ?;
221
+ let gate = self . act . forward ( & gate) ?;
227
222
228
- self . down_proj . forward ( & ( gate_states * up_states) ?)
223
+ self . down_proj . forward ( & ( gate * up_states) ?)
229
224
}
230
225
}
231
226
@@ -288,22 +283,25 @@ pub struct GTEClassificationHead {
288
283
}
289
284
290
285
impl GTEClassificationHead {
291
- #[ allow( dead_code) ]
286
+ fn inner_load ( vb : VarBuilder , config : & GTEConfig ) -> Option < Linear > {
287
+ let pooler_weight = vb
288
+ . pp ( "pooler.dense" )
289
+ . get ( ( config. hidden_size , config. hidden_size ) , "weight" )
290
+ . ok ( ) ?;
291
+ let pooler_bias = vb. pp ( "pooler.dense" ) . get ( config. hidden_size , "bias" ) . ok ( ) ?;
292
+ let pooler = Linear :: new ( pooler_weight, Some ( pooler_bias) , None ) ;
293
+
294
+ Some ( pooler)
295
+ }
296
+
292
297
pub ( crate ) fn load ( vb : VarBuilder , config : & GTEConfig ) -> Result < Self > {
293
298
let n_classes = match & config. id2label {
294
299
None => candle:: bail!( "`id2label` must be set for classifier models" ) ,
295
300
Some ( id2label) => id2label. len ( ) ,
296
301
} ;
297
302
298
- let pooler = if let Ok ( pooler_weight) = vb
299
- . pp ( "pooler.dense" )
300
- . get ( ( config. hidden_size , config. hidden_size ) , "weight" )
301
- {
302
- let pooler_bias = vb. pp ( "pooler.dense" ) . get ( config. hidden_size , "bias" ) ?;
303
- Some ( Linear :: new ( pooler_weight, Some ( pooler_bias) , None ) )
304
- } else {
305
- None
306
- } ;
303
+ let pooler =
304
+ Self :: inner_load ( vb. pp ( "new" ) , config) . or_else ( || Self :: inner_load ( vb. clone ( ) , config) ) ;
307
305
308
306
let classifier_weight = vb
309
307
. pp ( "classifier" )
@@ -322,13 +320,15 @@ impl GTEClassificationHead {
322
320
let _enter = self . span . enter ( ) ;
323
321
324
322
let mut hidden_states = hidden_states. unsqueeze ( 1 ) ?;
323
+
325
324
if let Some ( pooler) = self . pooler . as_ref ( ) {
326
325
hidden_states = pooler. forward ( & hidden_states) ?;
327
326
hidden_states = hidden_states. tanh ( ) ?;
328
327
}
329
328
330
329
let hidden_states = self . classifier . forward ( & hidden_states) ?;
331
330
let hidden_states = hidden_states. squeeze ( 1 ) ?;
331
+
332
332
Ok ( hidden_states)
333
333
}
334
334
}
0 commit comments