@@ -474,27 +474,38 @@ pub struct BertSpladeHead {
474474
475475impl BertSpladeHead {
476476 pub ( crate ) fn load ( vb : VarBuilder , config : & BertConfig ) -> Result < Self > {
477- let vb = vb. pp ( "cls.predictions" ) ;
478477 let transform_weight = vb
479- . pp ( "transform.dense" )
478+ . pp ( "cls.predictions. transform.dense" )
480479 . get ( ( config. hidden_size , config. hidden_size ) , "weight" ) ?;
481- let transform_bias = vb. pp ( "transform.dense" ) . get ( config. hidden_size , "bias" ) ?;
480+ let transform_bias = vb
481+ . pp ( "cls.predictions.transform.dense" )
482+ . get ( config. hidden_size , "bias" ) ?;
482483 let transform = Linear :: new (
483484 transform_weight,
484485 Some ( transform_bias) ,
485486 Some ( config. hidden_act . clone ( ) ) ,
486487 ) ;
487488
488489 let transform_layer_norm = LayerNorm :: load (
489- vb. pp ( "transform.LayerNorm" ) ,
490+ vb. pp ( "cls.predictions. transform.LayerNorm" ) ,
490491 config. hidden_size ,
491492 config. layer_norm_eps as f32 ,
492493 ) ?;
493494
494- let decoder_weight = vb
495- . pp ( "decoder" )
496- . get ( ( config. vocab_size , config. hidden_size ) , "weight" ) ?;
497- let decoder_bias = vb. get ( config. vocab_size , "bias" ) ?;
495+ // When `pytorch_model.bin` contains `cls.predictions.decoder.weight` but that tensor shares
496+ // its memory with `bert.embeddings.word_embeddings.weight` (e.g. it is a subset of the
497+ // original tensor), converting the file from BIN to Safetensors removes the tensor that
498+ // shares memory, so fall back to the word embeddings weight when it is missing
499+ let decoder_weight = if vb. contains_tensor ( "cls.predictions.decoder.weight" ) {
500+ vb. pp ( "cls.predictions.decoder" )
501+ . get ( ( config. vocab_size , config. hidden_size ) , "weight" ) ?
502+ } else {
503+ vb. pp ( "bert.embeddings.word_embeddings" )
504+ . get ( ( config. vocab_size , config. hidden_size ) , "weight" ) ?
505+ } ;
506+ // The same applies to the tensor `cls.predictions.decoder.bias`, which is shared with
507+ // `cls.predictions.bias` and removed in the BIN to Safetensors conversion
508+ let decoder_bias = vb. pp ( "cls.predictions" ) . get ( config. vocab_size , "bias" ) ?;
498509 let decoder = Linear :: new ( decoder_weight, Some ( decoder_bias) , Some ( HiddenAct :: Relu ) ) ;
499510
500511 Ok ( Self {
0 commit comments