diff --git a/backends/candle/src/models/bert.rs b/backends/candle/src/models/bert.rs
index 32880d44..1720ce9d 100644
--- a/backends/candle/src/models/bert.rs
+++ b/backends/candle/src/models/bert.rs
@@ -474,11 +474,12 @@ pub struct BertSpladeHead {
 
 impl BertSpladeHead {
     pub(crate) fn load(vb: VarBuilder, config: &BertConfig) -> Result<Self> {
-        let vb = vb.pp("cls.predictions");
         let transform_weight = vb
-            .pp("transform.dense")
+            .pp("cls.predictions.transform.dense")
             .get((config.hidden_size, config.hidden_size), "weight")?;
-        let transform_bias = vb.pp("transform.dense").get(config.hidden_size, "bias")?;
+        let transform_bias = vb
+            .pp("cls.predictions.transform.dense")
+            .get(config.hidden_size, "bias")?;
         let transform = Linear::new(
             transform_weight,
             Some(transform_bias),
@@ -486,15 +487,25 @@ impl BertSpladeHead {
         );
 
         let transform_layer_norm = LayerNorm::load(
-            vb.pp("transform.LayerNorm"),
+            vb.pp("cls.predictions.transform.LayerNorm"),
             config.hidden_size,
             config.layer_norm_eps as f32,
         )?;
 
-        let decoder_weight = vb
-            .pp("decoder")
-            .get((config.vocab_size, config.hidden_size), "weight")?;
-        let decoder_bias = vb.get(config.vocab_size, "bias")?;
+        // When `pytorch_model.bin` contains `cls.predictions.decoder.weight` but that tensor
+        // shares its memory with `bert.embeddings.word_embeddings.weight` (e.g. it is a view of
+        // the same storage), the shared tensor is dropped when the file is converted from BIN
+        // to Safetensors, so fall back to the word embedding weights in that case
+        let decoder_weight = if vb.contains_tensor("cls.predictions.decoder.weight") {
+            vb.pp("cls.predictions.decoder")
+                .get((config.vocab_size, config.hidden_size), "weight")?
+        } else {
+            vb.pp("bert.embeddings.word_embeddings")
+                .get((config.vocab_size, config.hidden_size), "weight")?
+        };
+        // The same applies to `cls.predictions.decoder.bias`, which shares its memory with
+        // `cls.predictions.bias` and is removed in the BIN to Safetensors conversion
+        let decoder_bias = vb.pp("cls.predictions").get(config.vocab_size, "bias")?;
         let decoder = Linear::new(decoder_weight, Some(decoder_bias), Some(HiddenAct::Relu));
 
         Ok(Self {
diff --git a/backends/candle/src/models/distilbert.rs b/backends/candle/src/models/distilbert.rs
index 7b39f7b9..b7b43893 100644
--- a/backends/candle/src/models/distilbert.rs
+++ b/backends/candle/src/models/distilbert.rs
@@ -391,9 +391,17 @@ impl DistilBertSpladeHead {
             Some(config.activation.clone()),
         );
 
-        let vocab_projector_weight = vb
-            .pp("vocab_projector")
-            .get((config.vocab_size, config.dim), "weight")?;
+        // When `pytorch_model.bin` contains `vocab_projector.weight` but that tensor shares its
+        // memory with `distilbert.embeddings.word_embeddings.weight` (e.g. it is a view of the
+        // same storage), the shared tensor is dropped when the file is converted from BIN to
+        // Safetensors, so fall back to the word embedding weights in that case
+        let vocab_projector_weight = if vb.contains_tensor("vocab_projector.weight") {
+            vb.pp("vocab_projector")
+                .get((config.vocab_size, config.dim), "weight")?
+        } else {
+            vb.pp("distilbert.embeddings.word_embeddings")
+                .get((config.vocab_size, config.dim), "weight")?
+        };
         let vocab_projector_bias = vb.pp("vocab_projector").get(config.vocab_size, "bias")?;
         let vocab_projector = Linear::new(
             vocab_projector_weight,
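
Note (not part of the patch): both hunks use the same fallback pattern, sketched below under the assumption of a candle `VarBuilder` and the crate names `candle_core`/`candle_nn`; the helper name `weight_or_shared` is hypothetical and only illustrates checking for the checkpoint tensor before reusing the weight it shares memory with.

use candle_core::{Result, Tensor};
use candle_nn::VarBuilder;

// Hypothetical helper: load `<primary>.weight`, falling back to `<fallback>.weight`
// when the BIN -> Safetensors conversion dropped the memory-shared tensor.
fn weight_or_shared(
    vb: &VarBuilder,
    primary: &str,
    fallback: &str,
    shape: (usize, usize),
) -> Result<Tensor> {
    if vb.contains_tensor(&format!("{primary}.weight")) {
        vb.pp(primary).get(shape, "weight")
    } else {
        vb.pp(fallback).get(shape, "weight")
    }
}

// e.g. weight_or_shared(&vb, "vocab_projector", "distilbert.embeddings.word_embeddings",
//                       (config.vocab_size, config.dim))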