2
2
import random
3
3
4
4
class SmallObjectAugmentation (object ):
5
- def __init__ (self , thresh = 64 * 64 , prob = 0 , copy_times = 3 , all_objects = False , one_object = False ):
5
+ def __init__ (self , thresh = 64 * 64 , prob = 0.5 , copy_times = 3 , epochs = 30 , all_objects = False , one_object = False ):
6
6
"""
7
+ sample = {'img':img, 'annot':annots}
7
8
img = [height, width, 3]
8
9
annot = [xmin, ymin, xmax, ymax, label]
9
- thresh:小目标边界
10
-
10
+ thresh:the detection threshold of the small object. If annot_h * annot_w < thresh, the object is small
11
+ prob: the prob to do small object augmentation
12
+ epochs: the epochs to do
11
13
"""
12
14
self .thresh = thresh
13
15
self .prob = prob
14
16
self .copy_times = copy_times
17
+ self .epochs = epochs
15
18
self .all_objects = all_objects
16
19
self .one_object = one_object
17
20
if self .all_objects or self .one_object :
@@ -42,27 +45,22 @@ def donot_overlap(self, new_annot, annots):
42
45
43
46
def create_copy_annot (self , h , w , annot , annots ):
44
47
annot = annot .astype (np .int )
45
- new_annot = list ()
46
48
annot_h , annot_w = annot [3 ] - annot [1 ], annot [2 ] - annot [0 ]
47
- random_x , random_y = np .random .randint (int (annot_w / 2 ), int (w - annot_w / 2 )), \
48
- np .random .randint (int (annot_h / 2 ), int (h - annot_h / 2 ))
49
-
50
- if np .int (random_x - annot_w / 2 ) < 0 or np .floor (random_x + annot_w / 2 ) > w or \
51
- np .int (random_y - annot_h / 2 ) < 0 or np .floor (random_y + annot_h / 2 ) > h :
52
- return self .create_copy_annot (h ,w , annot , annots )
53
-
54
- xmin , ymin = random_x - annot_w / 2 , random_y - annot_h / 2
55
- xmax , ymax = xmin + annot_w , ymin + annot_h
56
- new_annot .append (xmin ), new_annot .append (ymin )
57
- new_annot .append (xmax ), new_annot .append (ymax )
58
- new_annot .append (annot [4 ])
59
-
60
- new_annot = np .array (new_annot ).astype (np .int )
61
-
62
- if self .donot_overlap (new_annot , annots ) is False :
63
- return self .create_copy_annot (h , w , annot , annots )
64
-
65
- return new_annot
49
+ for epoch in range (self .epochs ):
50
+ random_x , random_y = np .random .randint (int (annot_w / 2 ), int (w - annot_w / 2 )), \
51
+ np .random .randint (int (annot_h / 2 ), int (h - annot_h / 2 ))
52
+ xmin , ymin = random_x - annot_w / 2 , random_y - annot_h / 2
53
+ xmax , ymax = xmin + annot_w , ymin + annot_h
54
+ if np .int (xmin ) < 0 or np .floor (xmax ) > w or \
55
+ np .int (ymin ) < 0 or np .floor (ymax ) > h :
56
+ continue
57
+ new_annot = np .array ([xmin , ymin , xmax , ymax , annot [4 ]]).astype (np .int )
58
+
59
+ if self .donot_overlap (new_annot , annots ) is False :
60
+ continue
61
+
62
+ return new_annot
63
+ return None
66
64
67
65
def add_patch_in_img (self , annot , copy_annot , image ):
68
66
copy_annot = copy_annot .astype (np .int )
@@ -74,15 +72,32 @@ def __call__(self, sample):
74
72
if np .random .rand () > self .prob : return sample
75
73
76
74
img , annots = sample ['img' ], sample ['annot' ]
77
- h , w , l = img .shape [0 ], img .shape [1 ], annots .shape [0 ]
75
+ h , w = img .shape [0 ], img .shape [1 ]
76
+
77
+ small_object_list = list ()
78
+ for idx in range (annots .shape [0 ]):
79
+ annot = annots [idx ]
80
+ annot_h , annot_w = annot [3 ] - annot [1 ], annot [2 ] - annot [0 ]
81
+ if self .issmallobject (annot_h , annot_w ):
82
+ small_object_list .append (idx )
78
83
84
+ l = len (small_object_list )
85
+ # No Small Object
86
+ if l == 0 : return sample
87
+
88
+ # Refine the copy_object by the given policy
89
+ # Policy 2:
79
90
copy_object_num = np .random .randint (0 , l )
91
+ # Policy 3:
80
92
if self .all_objects :
81
93
copy_object_num = l
94
+ # Policy 1:
82
95
if self .one_object :
83
96
copy_object_num = 1
97
+
84
98
random_list = random .sample (range (l ), copy_object_num )
85
- select_annots = annots [random_list , :]
99
+ annot_idx_of_small_object = [small_object_list [idx ] for idx in random_list ]
100
+ select_annots = annots [annot_idx_of_small_object , :]
86
101
annots = annots .tolist ()
87
102
for idx in range (copy_object_num ):
88
103
annot = select_annots [idx ]
@@ -91,7 +106,7 @@ def __call__(self, sample):
91
106
if self .issmallobject (annot_h , annot_w ) is False : continue
92
107
93
108
for i in range (self .copy_times ):
94
- new_annot = self .create_copy_annot (h , w , annot , annots )
109
+ new_annot = self .create_copy_annot (h , w , annot , annots , )
95
110
if new_annot is not None :
96
111
img = self .add_patch_in_img (new_annot , annot , img )
97
112
annots .append (new_annot )
0 commit comments