@@ -1,130 +1,75 @@
-
from fastdup.sentry import fastdup_capture_exception
from fastdup.definitions import MISSING_LABEL
from fastdup.galleries import fastdup_imread
- from tqdm import tqdm
import cv2

- def generate_labels(filenames, kwargs):
-     try:
-         from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
-         import torch
-     except Exception as e:
-         fastdup_capture_exception("Auto generate labels", e)
-         print("For auto captioning images need to install transforms and torch packages using `pip install transformers torch`")
-         return [MISSING_LABEL]*len(filenames)
-
-     try:
-         from PIL import Image
-         model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
-         feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
-         tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
-
-         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         model.to(device)
-         max_length = 16
-         num_beams = 4
-         gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
-
-         images = []
-         for image_path in tqdm(filenames):
-             i_image = fastdup_imread(image_path, None, kwargs=kwargs)
-             if i_image is not None:
-                 i_image = cv2.cvtColor(i_image, cv2.COLOR_BGR2RGB)
-                 im_pil = Image.fromarray(i_image)
-                 images.append(im_pil)
-             else:
-                 images.append(None)
-
-         pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
-         pixel_values = pixel_values.to(device)
-         output_ids = model.generate(pixel_values, **gen_kwargs)
-
-         preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-         preds = [pred.strip() for pred in preds]
-         return preds
-     except Exception as e:
-         fastdup_capture_exception("Auto caption image", e)
-         return [MISSING_LABEL]*len(filenames)

- def generate_blip_labels(filenames, kwargs):
+ def generate_labels(filenames, modelname='automatic', batch_size=8):
+     '''
+     This function generates captions for a given set of images. It takes the following arguments:
+     - filenames: the list of images passed to the function
+     - modelname: the captioning model to be used (default: 'automatic', which maps to ViT-GPT2)
+       currently available models are:
+         - ViT-GPT2: 'vitgpt2'
+         - BLIP-2: 'blip2'
+         - BLIP: 'blip'
+     - batch_size: the size of image batches to caption (default: 8)
+     '''

+     # confirm necessary dependencies are installed, and import them
    try:
-         from transformers import BlipProcessor, BlipForConditionalGeneration
+         from transformers import pipeline
+         import torch
        from PIL import Image
+         from tqdm import tqdm
    except Exception as e:
        fastdup_capture_exception("Auto generate labels", e)
-         print("For auto captioning images need to install transforms and torch packages using `pip install transformers`")
+         print("Auto captioning requires an installation of the following libraries:\n")
+         print(" huggingface transformers\n pytorch\n pillow\n tqdm\n")
+         print("to install, use `pip install transformers torch pillow tqdm`")
        return [MISSING_LABEL] * len(filenames)

-     try:
-         processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
-         model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
-         preds = []
-         for image_path in tqdm(filenames):
-             i_image = fastdup_imread(image_path, None, kwargs=kwargs)
-             if i_image is not None:
-                 i_image = cv2.cvtColor(i_image, cv2.COLOR_BGR2RGB)
-                 im_pil = Image.fromarray(i_image)
-                 inputs = processor(im_pil, return_tensors="pt")
-                 out = model.generate(**inputs)
-                 preds.append((processor.decode(out[0], skip_special_tokens=True)))
-             else:
-                 preds.append(MISSING_LABEL)
-         return preds
+     # dictionary of captioning models
+     models = {
+         'automatic': "nlpconnect/vit-gpt2-image-captioning",
+         'vitgpt2': "nlpconnect/vit-gpt2-image-captioning",
+         'blip2': "Salesforce/blip2-opt-2.7b",
+         'blip': "Salesforce/blip-image-captioning-large"
+     }

-     except Exception as e:
-         fastdup_capture_exception("Auto caption image blip", e)
-         return [MISSING_LABEL]*len(filenames)
-
- def generate_blip2_labels(filenames, kwargs, text=None):
-
-     try:
-         from transformers import Blip2Processor, Blip2Model
-         from PIL import Image
-         import torch
-     except Exception as e:
-         fastdup_capture_exception("Auto generate labels", e)
-         print("For auto captioning images need to install transforms and torch packages using `pip install transformers torch`")
-         return [MISSING_LABEL] * len(filenames)
+     model = models[modelname]

+     # generate captions
    try:
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         captioner = pipeline("image-to-text", model=model, device=device, batch_size=batch_size)

-         processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
-         model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
-         device = "cuda" if torch.cuda.is_available() else "cpu"
-         model.to(device)
-         preds = []
+         captions = []
        for image_path in tqdm(filenames):
-             i_image = fastdup_imread(image_path, None, kwargs=kwargs)
-             if i_image is not None:
-                 i_image = cv2.cvtColor(i_image, cv2.COLOR_BGR2RGB)
-                 im_pil = Image.fromarray(i_image)
-                 inputs = processor(images=im_pil, text=text, return_tensors="pt").to(device, torch.float16)
-                 generated_ids = model.generate(**inputs)
-                 generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
-                 preds.append(generated_text)
-             else:
-                 preds.append(MISSING_LABEL)
-         return preds
-
-     except Exception as e:
-         fastdup_capture_exception("Auto caption image blip", e)
-         return [MISSING_LABEL]*len(filenames)
-
-
+             img = Image.open(image_path)
+             pred = captioner(img)
+             caption = pred[0]['generated_text']
+             captions.append(caption)
+         return captions


+     except Exception as e:
+         fastdup_capture_exception("Auto caption image", e)
+         return [MISSING_LABEL] * len(filenames)


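# A minimal usage sketch of the new generate_labels API, with hypothetical
# image paths; any of the model keys above can be passed as modelname:
#
#     captions = generate_labels(["images/cat.jpg", "images/dog.jpg"],
#                                modelname='blip', batch_size=8)
#     # -> one caption string per image; on any failure the function falls
#     #    back to [MISSING_LABEL] * len(filenames)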
def generate_vqa_labels(filenames, text, kwargs):
+     # confirm necessary dependencies are installed, and import them
    try:
        from transformers import ViltProcessor, ViltForQuestionAnswering
+         import torch
        from PIL import Image
+         from tqdm import tqdm
    except Exception as e:
        fastdup_capture_exception("Auto generate labels", e)
-         print(
-             "For auto captioning images need to install transforms and torch packages using `pip install transformers`")
+         print("Auto captioning requires an installation of the following libraries:\n")
+         print(" huggingface transformers\n pytorch\n pillow\n tqdm\n")
+         print("to install, use `pip install transformers torch pillow tqdm`")
        return [MISSING_LABEL] * len(filenames)

    try:
@@ -150,15 +95,26 @@ def generate_vqa_labels(filenames, text, kwargs):

    except Exception as e:
        fastdup_capture_exception("Auto caption image vqa", e)
-         return [MISSING_LABEL]*len(filenames)
+         return [MISSING_LABEL] * len(filenames)


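# A minimal usage sketch for generate_vqa_labels, with hypothetical paths; the
# 'text' argument is the question posed to the ViLT VQA model for every image:
#
#     answers = generate_vqa_labels(["images/party.jpg", "images/beach.jpg"],
#                                   "how many people are in the image?",
#                                   kwargs={})
#     # -> one answer string per image, or MISSING_LABEL on failure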
def generate_age_labels(filenames, kwargs):
-     from transformers import ViTFeatureExtractor, ViTForImageClassification
-     model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
-     transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier')
+     # confirm necessary dependencies are installed, and import them
+     try:
+         from transformers import ViTFeatureExtractor, ViTForImageClassification
+         import torch
+         from PIL import Image
+         from tqdm import tqdm
+     except Exception as e:
+         fastdup_capture_exception("Auto generate labels", e)
+         print("Auto captioning requires an installation of the following libraries:\n")
+         print(" huggingface transformers\n pytorch\n pillow\n tqdm\n")
+         print("to install, use `pip install transformers torch pillow tqdm`")
+         return [MISSING_LABEL] * len(filenames)

    try:
+         model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
+         transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier')
        preds = []
        # Get example image from official fairface repo + read it in as an image
        for image_path in tqdm(filenames):
@@ -174,8 +130,9 @@ def generate_age_labels(filenames, kwargs):

            # Predicted Classes
            pred = int(proba.argmax(1)[0].int())
-             preds.append( model.config.id2label[pred])
+             preds.append(model.config.id2label[pred])
        return preds
    except Exception as e:
        fastdup_capture_exception("Age label", e)
        return [MISSING_LABEL] * len(filenames)
+
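# A minimal usage sketch for generate_age_labels, with hypothetical paths; each
# image is mapped to an age-bucket label by the 'nateraw/vit-age-classifier'
# checkpoint loaded above:
#
#     ages = generate_age_labels(["faces/person1.jpg", "faces/person2.jpg"],
#                                kwargs={})
#     # -> one predicted age-range label per image, or MISSING_LABEL on failure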