
Fix {Bert,DistilBert}SpladeHead when loading from Safetensors #564


Merged
merged 5 commits into main from patch-splade-from-safetensors
Apr 8, 2025

Conversation

alvarobartt
Member

@alvarobartt alvarobartt commented Apr 7, 2025

What does this PR do?

This PR fixes an issue that prevented loading BERT and DistilBERT models with SPLADE pooling. When a pytorch_model.bin file is converted into a model.safetensors file, tensors whose contents share memory with another tensor are removed for safety, so the weights required for the SPLADE head were missing. Support for SPLADE was originally introduced for the models at https://huggingface.co/naver, which are indeed distributed as pytorch_model.bin files.

To work around this, this PR adds a check on whether the required tensors are present and, if not, falls back to the tensor that shares memory with them instead.
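A minimal sketch of that fallback logic in Python (the actual fix lives in the Rust model-loading code of text-embeddings-inference; the helper function and the hard-coded weight value below are purely illustrative, while the tensor names come from the error message in this PR):

```python
# Illustrative sketch of the fallback: if the SPLADE head tensor was dropped
# during safetensors conversion, use the tensor it shared memory with.
def get_tensor_with_fallback(tensors: dict, name: str, fallback: str):
    """Return tensors[name] if present, else the surviving shared-memory alias."""
    if name in tensors:
        return tensors[name]
    # safetensors drops tensors that share memory, so only the alias
    # (e.g. the tied embedding weights) survives in the file.
    return tensors[fallback]

# The converted file only keeps the embedding weights, not the projector.
weights = {"distilbert.embeddings.word_embeddings.weight": "W"}
w = get_tensor_with_fallback(
    weights,
    "vocab_projector.weight",
    "distilbert.embeddings.word_embeddings.weight",
)
```

Here `w` resolves to the shared embedding weights, which is exactly what the tied `vocab_projector.weight` pointed at before conversion.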

To reproduce the issue, grab any model under https://huggingface.co/naver, e.g. naver/efficient-splade-V-large-query, download the pytorch_model.bin file, and then convert it into a model.safetensors file with the following script:

import torch
from safetensors.torch import save_file

# Load the original PyTorch checkpoint on CPU and make every tensor
# contiguous before serializing the state dict to safetensors.
model_state_dict = torch.load("pytorch_model.bin", map_location=torch.device("cpu"))
contiguous_state_dict = {k: v.contiguous() for k, v in model_state_dict.items()}

save_file(contiguous_state_dict, "model.safetensors")

Then the following error will be raised:

Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'distilbert.embeddings.word_embeddings.weight', 'vocab_projector.weight'}].
A potential way to correctly save your model is to use `save_model`.
More information at https://huggingface.co/docs/safetensors/torch_shared_tensors

And indeed, if we inspect the model.safetensors metadata, we'll see that the tensor that shares memory with a previous tensor is not there.
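The metadata check can be done with nothing but the standard library, since a safetensors file starts with an 8-byte little-endian header length followed by a JSON header mapping tensor names to their metadata (normally you would just call `safetensors.safe_open(...).keys()`). The sketch below builds a tiny in-memory file containing only the surviving tensor, then reads the names back from its header; the tensor names are taken from the error message above:

```python
import json
import struct

# Build a minimal safetensors blob by hand: 8-byte little-endian header
# length, JSON header (name -> dtype/shape/offsets), then raw tensor bytes.
# Only the surviving tensor is present, because safetensors drops aliases
# that share memory with it.
data = bytes(8 * 4)  # eight little-endian f32 zeros, i.e. a 4x2 tensor
header = {
    "distilbert.embeddings.word_embeddings.weight": {
        "dtype": "F32",
        "shape": [4, 2],
        "data_offsets": [0, len(data)],
    }
}
header_bytes = json.dumps(header).encode("utf-8")
blob = struct.pack("<Q", len(header_bytes)) + header_bytes + data

# Read the tensor names back from the header, as safe_open would.
(n,) = struct.unpack("<Q", blob[:8])
names = set(json.loads(blob[8 : 8 + n]))

print("vocab_projector.weight" in names)  # False: the tied tensor is gone
```

Running the same key listing against a real converted naver checkpoint shows the same thing: the tied projector weight is absent while the embedding weights remain.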

Fixes #548

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@Narsil or @McPatate

Collaborator

@Narsil Narsil left a comment


LGTM. These loads from many different locations are getting quite annoying; at some point we should figure out a nicer way to abstract them.

But this looks good.

@Narsil Narsil merged commit 3c50308 into main Apr 8, 2025
14 checks passed
@Narsil Narsil deleted the patch-splade-from-safetensors branch April 8, 2025 08:44
Development

Successfully merging this pull request may close these issues.

cannot find tensor cls.predictions.decoder.weight