Fix {Bert,DistilBert}SpladeHead when loading from Safetensors #564
What does this PR do?
This PR fixes an issue that prevented loading BERT and DistilBERT models with SPLADE pooling. When a `pytorch_model.bin` file is converted into a `model.safetensors` file, tensors that share memory with another tensor are removed for safety, so the weights required for the SPLADE head were missing after conversion. Support for SPLADE was originally introduced for the models at https://huggingface.co/naver, which indeed ship `pytorch_model.bin` files. This PR therefore adds a check on whether the required tensors are present and, if not, falls back to the tensor that shares their memory instead.
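The fallback described above can be sketched in a few lines. This is a minimal illustration, not the actual Rust implementation in the PR; the tensor names used here (`cls.predictions.decoder.weight` falling back to the tied `bert.embeddings.word_embeddings.weight`) are assumptions chosen for the example.

```python
def get_tensor_with_fallback(tensors, name, shared_name):
    """Return tensors[name]; if the safetensors converter stripped it
    because it shared memory with another tensor, fall back to that
    other tensor instead."""
    if name in tensors:
        return tensors[name]
    if shared_name in tensors:
        return tensors[shared_name]
    raise KeyError(f"neither {name!r} nor {shared_name!r} found in checkpoint")

# Hypothetical key names for illustration only.
weights = {"bert.embeddings.word_embeddings.weight": "embedding-matrix"}
decoder = get_tensor_with_fallback(
    weights,
    "cls.predictions.decoder.weight",          # stripped during conversion
    "bert.embeddings.word_embeddings.weight",  # tensor sharing its memory
)
```

Because the two tensors shared the same underlying storage in the original checkpoint, reading the surviving one yields exactly the weights the SPLADE head needs.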
To reproduce the issue, grab any model under https://huggingface.co/naver, e.g. `naver/efficient-splade-V-large-query`, download the `pytorch_model.bin` file, and convert it into a `model.safetensors` file with the following script:

Then the following error will be raised:
And indeed, if we inspect the `model.safetensors` metadata, we'll see that the tensor that shares memory with a previous tensor is not there.

Fixes #548
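The metadata inspection can be done with the standard library alone, since a safetensors file starts with an 8-byte little-endian header length followed by a JSON header listing every stored tensor. This is a minimal sketch assuming that layout; the function name is hypothetical.

```python
import json
import struct

def safetensors_keys(path):
    """List the tensor names stored in a .safetensors file by reading
    its JSON header (8-byte little-endian length prefix, then JSON)."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    # "__metadata__" is an optional non-tensor entry in the header.
    return [name for name in header if name != "__metadata__"]
```

Running this on the converted `model.safetensors` would show that the SPLADE head's weight key is absent, while the tensor it shared memory with is still listed.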
Before submitting
Who can review?
@Narsil or @McPatate