Skip to content

Commit 4d120f3

Browse files
authored
Add PyTorch Hub configuration file (#259)
1 parent b46f5ac commit 4d120f3

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

hubconf.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from clip.clip import tokenize as _tokenize, load as _load, available_models as _available_models
2+
import re
3+
import string
4+
5+
dependencies = ["torch", "torchvision", "ftfy", "regex", "tqdm"]
6+
7+
# For compatibility (cannot include special characters in function name)
8+
model_functions = { model: re.sub(f'[{string.punctuation}]', '_', model) for model in _available_models()}
9+
10+
def _create_hub_entrypoint(model):
11+
def entrypoint(**kwargs):
12+
return _load(model, **kwargs)
13+
14+
entrypoint.__doc__ = f"""Loads the {model} CLIP model
15+
16+
Parameters
17+
----------
18+
device : Union[str, torch.device]
19+
The device to put the loaded model
20+
21+
jit : bool
22+
Whether to load the optimized JIT model or more hackable non-JIT model (default).
23+
24+
download_root: str
25+
path to download the model files; by default, it uses "~/.cache/clip"
26+
27+
Returns
28+
-------
29+
model : torch.nn.Module
30+
The {model} CLIP model
31+
32+
preprocess : Callable[[PIL.Image], torch.Tensor]
33+
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
34+
"""
35+
return entrypoint
36+
37+
def tokenize():
38+
return _tokenize
39+
40+
_entrypoints = {model_functions[model]: _create_hub_entrypoint(model) for model in _available_models()}
41+
42+
globals().update(_entrypoints)

0 commit comments

Comments
 (0)