
Fix {Bert,DistilBert}SpladeHead when loading from Safetensors #564


Merged · 5 commits · Apr 8, 2025
27 changes: 19 additions & 8 deletions backends/candle/src/models/bert.rs
@@ -474,27 +474,38 @@ pub struct BertSpladeHead {

 impl BertSpladeHead {
     pub(crate) fn load(vb: VarBuilder, config: &BertConfig) -> Result<Self> {
-        let vb = vb.pp("cls.predictions");
         let transform_weight = vb
-            .pp("transform.dense")
+            .pp("cls.predictions.transform.dense")
             .get((config.hidden_size, config.hidden_size), "weight")?;
-        let transform_bias = vb.pp("transform.dense").get(config.hidden_size, "bias")?;
+        let transform_bias = vb
+            .pp("cls.predictions.transform.dense")
+            .get(config.hidden_size, "bias")?;
         let transform = Linear::new(
             transform_weight,
             Some(transform_bias),
             Some(config.hidden_act.clone()),
         );

         let transform_layer_norm = LayerNorm::load(
-            vb.pp("transform.LayerNorm"),
+            vb.pp("cls.predictions.transform.LayerNorm"),
             config.hidden_size,
             config.layer_norm_eps as f32,
         )?;

-        let decoder_weight = vb
-            .pp("decoder")
-            .get((config.vocab_size, config.hidden_size), "weight")?;
-        let decoder_bias = vb.get(config.vocab_size, "bias")?;
+        // When `pytorch_model.bin` contains `cls.predictions.decoder.weight` whose content
+        // shares memory with `bert.embeddings.word_embeddings.weight` (e.g. it is a view of
+        // the same storage), the BIN to Safetensors conversion drops the tensor that
+        // aliases the other
+        let decoder_weight = if vb.contains_tensor("cls.predictions.decoder.weight") {
+            vb.pp("cls.predictions.decoder")
+                .get((config.vocab_size, config.hidden_size), "weight")?
+        } else {
+            vb.pp("bert.embeddings.word_embeddings")
+                .get((config.vocab_size, config.hidden_size), "weight")?
+        };
+        // The same applies to `cls.predictions.decoder.bias`, which is shared with
+        // `cls.predictions.bias` and removed in the BIN to Safetensors conversion
+        let decoder_bias = vb.pp("cls.predictions").get(config.vocab_size, "bias")?;
         let decoder = Linear::new(decoder_weight, Some(decoder_bias), Some(HiddenAct::Relu));

         Ok(Self {
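To see why the fallback is needed, you can inspect which tensor names actually survived a BIN to Safetensors conversion. A minimal sketch using the `safetensors` crate (the crate dependency and the `model.safetensors` file name are assumptions, not part of this PR):

use safetensors::SafeTensors;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Read the converted checkpoint and parse its header.
    let data = std::fs::read("model.safetensors")?;
    let tensors = SafeTensors::deserialize(&data)?;

    // After conversion, only one of `cls.predictions.decoder.weight` and
    // `bert.embeddings.word_embeddings.weight` is expected to be listed here,
    // since aliased tensors are dropped when the BIN file is converted.
    for name in tensors.names() {
        println!("{name}");
    }
    Ok(())
}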
14 changes: 11 additions & 3 deletions backends/candle/src/models/distilbert.rs
@@ -391,9 +391,17 @@ impl DistilBertSpladeHead {
             Some(config.activation.clone()),
         );

-        let vocab_projector_weight = vb
-            .pp("vocab_projector")
-            .get((config.vocab_size, config.dim), "weight")?;
+        // When `pytorch_model.bin` contains `vocab_projector.weight` whose content shares
+        // memory with `distilbert.embeddings.word_embeddings.weight` (e.g. it is a view of
+        // the same storage), the BIN to Safetensors conversion drops the tensor that
+        // aliases the other
+        let vocab_projector_weight = if vb.contains_tensor("vocab_projector.weight") {
+            vb.pp("vocab_projector")
+                .get((config.vocab_size, config.dim), "weight")?
+        } else {
+            vb.pp("distilbert.embeddings.word_embeddings")
+                .get((config.vocab_size, config.dim), "weight")?
+        };
         let vocab_projector_bias = vb.pp("vocab_projector").get(config.vocab_size, "bias")?;
         let vocab_projector = Linear::new(
             vocab_projector_weight,
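The same guard pattern can be exercised end to end with candle's `VarBuilder`. A hedged sketch, not part of this PR: the `model.safetensors` path, the shapes, and the `tied_weight` helper are illustrative assumptions, while `from_mmaped_safetensors` and `contains_tensor` are the candle-nn APIs the patch relies on.

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;

// Standalone version of the fallback used in both heads: prefer the projector
// weight if the checkpoint kept it, otherwise reuse the word-embedding matrix
// it was tied to before conversion.
fn tied_weight(
    vb: &VarBuilder,
    primary: &str,
    fallback: &str,
    shape: (usize, usize),
) -> Result<Tensor> {
    if vb.contains_tensor(&format!("{primary}.weight")) {
        vb.pp(primary).get(shape, "weight")
    } else {
        vb.pp(fallback).get(shape, "weight")
    }
}

fn main() -> Result<()> {
    // Memory-mapping is `unsafe` because the file must not change while mapped.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &Device::Cpu)?
    };
    let w = tied_weight(
        &vb,
        "vocab_projector",
        "distilbert.embeddings.word_embeddings",
        (30522, 768), // (vocab_size, dim) for the stock DistilBERT config
    )?;
    println!("loaded tied weight with shape {:?}", w.shape());
    Ok(())
}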