
Commit 3bee281

Make the repo installable as a package (#26)
1 parent: 578a1d3

10 files changed: +70 −5 lines changed

.gitignore

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+__pycache__/
+*.py[cod]
+*$py.class
+*.egg-info
+.pytest_cache
+.ipynb_checkpoints
+
+thumbs.db
+.DS_Store
+.idea

MANIFEST.in

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+include clip/bpe_simple_vocab_16e6.txt.gz
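Shipping the vocabulary via MANIFEST.in (together with include_package_data=True in setup.py below) only works if the tokenizer resolves the file relative to the installed package rather than the current working directory. A minimal sketch of that lookup, in the spirit of the repo's simple_tokenizer.py (treat the exact helper shown here as an approximation, not the repo's verbatim code):

import gzip
import os

def default_bpe():
    # Locate the gzipped BPE merges next to this module, wherever
    # the package happens to be installed.
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")

# e.g. read the merge list shipped with the package
with gzip.open(default_bpe(), "rt", encoding="utf-8") as f:
    merges = f.read().split("\n")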

clip/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .clip import *
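With the new package layout and this star re-export, a plain `import clip` exposes load, tokenize, and available_models at the top level. A minimal usage sketch mirroring the test added below ("ViT-B/32" is an assumed model name; check clip.available_models() for what is actually published):

import torch
from PIL import Image

import clip

device = "cpu"
model, transform = clip.load("ViT-B/32", device=device, jit=False)  # model name assumed

image = transform(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).numpy()  # probabilities over the three captions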
bpe_simple_vocab_16e6.txt.gz renamed to clip/bpe_simple_vocab_16e6.txt.gz

File renamed without changes.

clip.py renamed to clip/clip.py

Lines changed: 7 additions & 5 deletions

@@ -9,8 +9,8 @@
 from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
 from tqdm import tqdm

-from model import build_model
-from simple_tokenizer import SimpleTokenizer as _Tokenizer
+from .model import build_model
+from .simple_tokenizer import SimpleTokenizer as _Tokenizer

 __all__ = ["available_models", "load", "tokenize"]
 _tokenizer = _Tokenizer()
@@ -24,7 +24,7 @@
 def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
     os.makedirs(root, exist_ok=True)
     filename = os.path.basename(url)
-
+
     expected_sha256 = url.split("/")[-2]
     download_target = os.path.join(root, filename)

@@ -38,7 +38,7 @@ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
             warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

     with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
-        with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop:
+        with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop:
             while True:
                 buffer = source.read(8192)
                 if not buffer:
@@ -75,6 +75,8 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a

     if not jit:
         model = build_model(model.state_dict()).to(device)
+        if str(device) == "cpu":
+            model.float()
         return model, transform

     # patch the device names
@@ -96,7 +98,7 @@ def patch_device(module):
     patch_device(model.encode_text)

     # patch dtype to float32 on CPU
-    if device == "cpu":
+    if str(device) == "cpu":
         float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
         float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
         float_node = float_input.node()
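Two things happen in this file beyond the relative imports. First, `device` may arrive as either the string "cpu" or a torch.device object, and comparing a torch.device directly against a string is unreliable, so both checks now normalize with str(device). Second, the released weights are stored in half precision, which CPUs handle poorly, so the non-JIT path casts the model to float32 on CPU. A small sketch of the normalization (the helper name is mine, not the repo's):

import torch

def is_cpu(device):
    # Accept both "cpu" and torch.device("cpu"); normalizing to a
    # string makes the comparison behave the same for either form.
    return str(device) == "cpu"

assert is_cpu("cpu")
assert is_cpu(torch.device("cpu"))
assert not is_cpu(torch.device("cuda"))  # str() yields "cuda"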

model.py renamed to clip/model.py

File renamed without changes.

simple_tokenizer.py renamed to clip/simple_tokenizer.py

File renamed without changes.

requirements.txt

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+ftfy
+regex
+tqdm
+torch>=1.7.1,<1.7.2
+torchvision==0.8.2

setup.py

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+import os
+
+import pkg_resources
+from setuptools import setup, find_packages
+
+setup(
+    name="clip",
+    py_modules=["clip"],
+    version="1.0",
+    description="",
+    author="OpenAI",
+    packages=find_packages(exclude=["tests*"]),
+    install_requires=[
+        str(r)
+        for r in pkg_resources.parse_requirements(
+            open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
+        )
+    ],
+    include_package_data=True,
+    extras_require={'dev': ['pytest']},
+)
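Reading install_requires out of requirements.txt keeps the two dependency lists from drifting apart. A quick sketch of what that comprehension evaluates to (output indicative only; pkg_resources may normalize specifier order):

import pkg_resources

with open("requirements.txt") as f:
    reqs = [str(r) for r in pkg_resources.parse_requirements(f)]

print(reqs)
# e.g. ['ftfy', 'regex', 'tqdm', 'torch<1.7.2,>=1.7.1', 'torchvision==0.8.2']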

tests/test_consistency.py

Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+import numpy as np
+import pytest
+import torch
+from PIL import Image
+
+import clip
+
+
+@pytest.mark.parametrize('model_name', clip.available_models())
+def test_consistency(model_name):
+    device = "cpu"
+    jit_model, transform = clip.load(model_name, device=device)
+    py_model, _ = clip.load(model_name, device=device, jit=False)
+
+    image = transform(Image.open("CLIP.png")).unsqueeze(0).to(device)
+    text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
+
+    with torch.no_grad():
+        logits_per_image, _ = jit_model(image, text)
+        jit_probs = logits_per_image.softmax(dim=-1).cpu().numpy()
+
+        logits_per_image, _ = py_model(image, text)
+        py_probs = logits_per_image.softmax(dim=-1).cpu().numpy()
+
+    assert np.allclose(jit_probs, py_probs, atol=0.01, rtol=0.1)
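Since pytest only arrives via the dev extra (extras_require={'dev': ['pytest']}), the suite assumes a dev install; it also downloads every available model and expects CLIP.png in the working directory, so the first run is slow. One way to invoke it programmatically (a sketch; running pytest from the repo root works just as well):

import sys
import pytest

# Exit code 0 means the JIT and eager models agreed for every model name.
sys.exit(pytest.main(["-q", "tests/test_consistency.py"]))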
