
Commit d620db4

Commit message: adding

1 parent 7d919bf · commit d620db4

File tree: 5 files changed, +252 −116 lines


fastdup/captions.py

Lines changed: 61 additions & 104 deletions
@@ -1,130 +1,75 @@
-
 from fastdup.sentry import fastdup_capture_exception
 from fastdup.definitions import MISSING_LABEL
 from fastdup.galleries import fastdup_imread
-from tqdm import tqdm
 import cv2
 
-def generate_labels(filenames, kwargs):
-    try:
-        from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
-        import torch
-    except Exception as e:
-        fastdup_capture_exception("Auto generate labels", e)
-        print("For auto captioning images need to install transforms and torch packages using `pip install transformers torch`")
-        return [MISSING_LABEL]*len(filenames)
-
-    try:
-        from PIL import Image
-        model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
-        feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
-        tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
-
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        model.to(device)
-        max_length = 16
-        num_beams = 4
-        gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
-
-        images = []
-        for image_path in tqdm(filenames):
-            i_image = fastdup_imread(image_path, None, kwargs=kwargs)
-            if i_image is not None:
-                i_image = cv2.cvtColor(i_image, cv2.COLOR_BGR2RGB)
-                im_pil = Image.fromarray(i_image)
-                images.append(im_pil)
-            else:
-                images.append(None)
-
-        pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
-        pixel_values = pixel_values.to(device)
-        output_ids = model.generate(pixel_values, **gen_kwargs)
-
-        preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-        preds = [pred.strip() for pred in preds]
-        return preds
-    except Exception as e:
-        fastdup_capture_exception("Auto caption image", e)
-        return [MISSING_LABEL]*len(filenames)
 
-def generate_blip_labels(filenames, kwargs):
+def generate_labels(filenames, modelname='automatic', batch_size=8):
+    '''
+    This function generates captions for a given set of images, and takes the following arguments:
+    - filenames: the list of images passed to the function
+    - modelname: the captioning model to be used (default: 'automatic', which maps to ViT-GPT2)
+      currently available models are:
+        - ViT-GPT2: 'vitgpt2'
+        - BLIP-2: 'blip2'
+        - BLIP: 'blip'
+    - batch_size: the size of image batches to caption (default: 8)
+    '''
 
+    # confirm necessary dependencies are installed, and import them
     try:
-        from transformers import BlipProcessor, BlipForConditionalGeneration
+        from transformers import pipeline
+        import torch
         from PIL import Image
+        from tqdm import tqdm
     except Exception as e:
         fastdup_capture_exception("Auto generate labels", e)
-        print("For auto captioning images need to install transforms and torch packages using `pip install transformers`")
+        print("Auto captioning requires an installation of the following libraries:\n")
+        print("   huggingface transformers\n   pytorch\n   pillow\n   tqdm\n")
+        print("to install, use `pip install transformers torch pillow tqdm`")
         return [MISSING_LABEL] * len(filenames)
 
-    try:
-        processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
-        model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
-        preds = []
-        for image_path in tqdm(filenames):
-            i_image = fastdup_imread(image_path, None, kwargs=kwargs)
-            if i_image is not None:
-                i_image = cv2.cvtColor(i_image, cv2.COLOR_BGR2RGB)
-                im_pil = Image.fromarray(i_image)
-                inputs = processor(im_pil, return_tensors="pt")
-                out = model.generate(**inputs)
-                preds.append((processor.decode(out[0], skip_special_tokens=True)))
-            else:
-                preds.append(MISSING_LABEL)
-        return preds
+    # dictionary of captioning models
+    models = {
+        'automatic': "nlpconnect/vit-gpt2-image-captioning",
+        'vitgpt2': "nlpconnect/vit-gpt2-image-captioning",
+        'blip2': "Salesforce/blip2-opt-2.7b",
+        'blip': "Salesforce/blip-image-captioning-large"
+    }
 
-    except Exception as e:
-        fastdup_capture_exception("Auto caption image blip", e)
-        return [MISSING_LABEL]*len(filenames)
-
-def generate_blip2_labels(filenames, kwargs, text=None):
-
-    try:
-        from transformers import Blip2Processor, Blip2Model
-        from PIL import Image
-        import torch
-    except Exception as e:
-        fastdup_capture_exception("Auto generate labels", e)
-        print("For auto captioning images need to install transforms and torch packages using `pip install transformers torch`")
-        return [MISSING_LABEL] * len(filenames)
+    model = models[modelname]
 
+    # generate captions
     try:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        captioner = pipeline("image-to-text", model=model, device=device, batch_size=batch_size)
 
-        processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
-        model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        model.to(device)
-        preds = []
+        captions = []
         for image_path in tqdm(filenames):
-            i_image = fastdup_imread(image_path, None, kwargs=kwargs)
-            if i_image is not None:
-                i_image = cv2.cvtColor(i_image, cv2.COLOR_BGR2RGB)
-                im_pil = Image.fromarray(i_image)
-                inputs = processor(images=im_pil, text=text, return_tensors="pt").to(device, torch.float16)
-                generated_ids = model.generate(**inputs)
-                generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
-                preds.append(generated_text)
-            else:
-                preds.append(MISSING_LABEL)
-        return preds
-
-    except Exception as e:
-        fastdup_capture_exception("Auto caption image blip", e)
-        return [MISSING_LABEL]*len(filenames)
-
-
+            img = Image.open(image_path)
+            pred = captioner(img)
+            caption = pred[0]['generated_text']
+            captions.append(caption)
+        return captions
 
 
+    except Exception as e:
+        fastdup_capture_exception("Auto caption image", e)
+        return [MISSING_LABEL] * len(filenames)
 
 
 def generate_vqa_labels(filenames, text, kwargs):
+    # confirm necessary dependencies are installed, and import them
     try:
         from transformers import ViltProcessor, ViltForQuestionAnswering
+        import torch
         from PIL import Image
+        from tqdm import tqdm
     except Exception as e:
         fastdup_capture_exception("Auto generate labels", e)
-        print(
-            "For auto captioning images need to install transforms and torch packages using `pip install transformers`")
+        print("Auto captioning requires an installation of the following libraries:\n")
+        print("   huggingface transformers\n   pytorch\n   pillow\n   tqdm\n")
+        print("to install, use `pip install transformers torch pillow tqdm`")
         return [MISSING_LABEL] * len(filenames)
 
     try:
@@ -150,15 +95,26 @@ def generate_vqa_labels(filenames, text, kwargs):
 
     except Exception as e:
         fastdup_capture_exception("Auto caption image vqa", e)
-        return [MISSING_LABEL]*len(filenames)
+        return [MISSING_LABEL] * len(filenames)
 
 
 def generate_age_labels(filenames, kwargs):
-    from transformers import ViTFeatureExtractor, ViTForImageClassification
-    model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
-    transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier')
+    # confirm necessary dependencies are installed, and import them
+    try:
+        from transformers import ViTFeatureExtractor, ViTForImageClassification
+        import torch
+        from PIL import Image
+        from tqdm import tqdm
+    except Exception as e:
+        fastdup_capture_exception("Auto generate labels", e)
+        print("Auto captioning requires an installation of the following libraries:\n")
+        print("   huggingface transformers\n   pytorch\n   pillow\n   tqdm\n")
+        print("to install, use `pip install transformers torch pillow tqdm`")
+        return [MISSING_LABEL] * len(filenames)
 
     try:
+        model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
+        transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier')
         preds = []
         # Get example image from official fairface repo + read it in as an image
         for image_path in tqdm(filenames):
@@ -174,8 +130,9 @@ def generate_age_labels(filenames, kwargs):
 
             # Predicted Classes
             pred = int(proba.argmax(1)[0].int())
-            preds.append( model.config.id2label[pred])
+            preds.append(model.config.id2label[pred])
         return preds
     except Exception as e:
         fastdup_capture_exception("Age label", e)
         return [MISSING_LABEL] * len(filenames)
+
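For reference, a minimal usage sketch of the new generate_labels signature introduced in this diff. The image paths below are hypothetical placeholders, not part of the commit:

    from fastdup.captions import generate_labels

    # hypothetical image paths; any list of readable image files works
    filenames = ["images/cat.jpg", "images/dog.jpg"]

    # 'automatic' (the default) resolves to the ViT-GPT2 checkpoint in the
    # models dictionary above; 'vitgpt2', 'blip' and 'blip2' are also accepted
    captions = generate_labels(filenames, modelname='automatic', batch_size=8)
    for path, caption in zip(filenames, captions):
        print(path, '->', caption)

If any image fails to open, the exception is captured via fastdup_capture_exception and the function falls back to returning [MISSING_LABEL] * len(filenames).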
0 commit comments