Skip to content

Commit 3e25f5d

Browse files
committed
refine small_object_augmentation
1 parent 61158a1 commit 3e25f5d

File tree

5 files changed

+46
-29
lines changed

5 files changed

+46
-29
lines changed

augmentation_zoo/MyGridMask.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def __init__(self, use_h, use_w, rotate=1, offset=False, ratio=0.5, mode=0, prob
1212
self.mode = mode
1313
self.st_prob = prob
1414
self.prob = prob
15+
1516
def set_prob(self, epoch, max_epoch):
1617
self.prob = self.st_prob * epoch / max_epoch
1718

augmentation_zoo/SmallObjectAugmentation.py

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,19 @@
22
import random
33

44
class SmallObjectAugmentation(object):
5-
def __init__(self, thresh=64*64, prob=0, copy_times=3, all_objects=False, one_object=False):
5+
def __init__(self, thresh=64*64, prob=0.5, copy_times=3, epochs=30, all_objects=False, one_object=False):
66
"""
7+
sample = {'img':img, 'annot':annots}
78
img = [height, width, 3]
89
annot = [xmin, ymin, xmax, ymax, label]
9-
thresh:小目标边界
10-
10+
thresh:the detection threshold of the small object. If annot_h * annot_w < thresh, the object is small
11+
prob: the prob to do small object augmentation
12+
epochs: the epochs to do
1113
"""
1214
self.thresh = thresh
1315
self.prob = prob
1416
self.copy_times = copy_times
17+
self.epochs = epochs
1518
self.all_objects = all_objects
1619
self.one_object = one_object
1720
if self.all_objects or self.one_object:
@@ -42,27 +45,22 @@ def donot_overlap(self, new_annot, annots):
4245

4346
def create_copy_annot(self, h, w, annot, annots):
4447
annot = annot.astype(np.int)
45-
new_annot = list()
4648
annot_h, annot_w = annot[3] - annot[1], annot[2] - annot[0]
47-
random_x, random_y = np.random.randint(int(annot_w/2), int(w-annot_w/2)), \
48-
np.random.randint(int(annot_h/2), int(h-annot_h/2))
49-
50-
if np.int(random_x - annot_w/2) < 0 or np.floor(random_x + annot_w/2) > w or \
51-
np.int(random_y - annot_h/2) < 0 or np.floor(random_y + annot_h/2) > h:
52-
return self.create_copy_annot(h ,w, annot, annots)
53-
54-
xmin, ymin = random_x - annot_w/2, random_y - annot_h/2
55-
xmax, ymax = xmin + annot_w, ymin + annot_h
56-
new_annot.append(xmin), new_annot.append(ymin)
57-
new_annot.append(xmax), new_annot.append(ymax)
58-
new_annot.append(annot[4])
59-
60-
new_annot = np.array(new_annot).astype(np.int)
61-
62-
if self.donot_overlap(new_annot, annots) is False:
63-
return self.create_copy_annot(h, w, annot, annots)
64-
65-
return new_annot
49+
for epoch in range(self.epochs):
50+
random_x, random_y = np.random.randint(int(annot_w / 2), int(w - annot_w / 2)), \
51+
np.random.randint(int(annot_h / 2), int(h - annot_h / 2))
52+
xmin, ymin = random_x - annot_w / 2, random_y - annot_h / 2
53+
xmax, ymax = xmin + annot_w, ymin + annot_h
54+
if np.int(xmin) < 0 or np.floor(xmax) > w or \
55+
np.int(ymin) < 0 or np.floor(ymax) > h:
56+
continue
57+
new_annot = np.array([xmin, ymin, xmax, ymax, annot[4]]).astype(np.int)
58+
59+
if self.donot_overlap(new_annot, annots) is False:
60+
continue
61+
62+
return new_annot
63+
return None
6664

6765
def add_patch_in_img(self, annot, copy_annot, image):
6866
copy_annot = copy_annot.astype(np.int)
@@ -74,15 +72,32 @@ def __call__(self, sample):
7472
if np.random.rand() > self.prob: return sample
7573

7674
img, annots = sample['img'], sample['annot']
77-
h, w, l = img.shape[0], img.shape[1], annots.shape[0]
75+
h, w= img.shape[0], img.shape[1]
76+
77+
small_object_list = list()
78+
for idx in range(annots.shape[0]):
79+
annot = annots[idx]
80+
annot_h, annot_w = annot[3] - annot[1], annot[2] - annot[0]
81+
if self.issmallobject(annot_h, annot_w):
82+
small_object_list.append(idx)
7883

84+
l = len(small_object_list)
85+
# No Small Object
86+
if l == 0: return sample
87+
88+
# Refine the copy_object by the given policy
89+
# Policy 2:
7990
copy_object_num = np.random.randint(0, l)
91+
# Policy 3:
8092
if self.all_objects:
8193
copy_object_num = l
94+
# Policy 1:
8295
if self.one_object:
8396
copy_object_num = 1
97+
8498
random_list = random.sample(range(l), copy_object_num)
85-
select_annots = annots[random_list, :]
99+
annot_idx_of_small_object = [small_object_list[idx] for idx in random_list]
100+
select_annots = annots[annot_idx_of_small_object, :]
86101
annots = annots.tolist()
87102
for idx in range(copy_object_num):
88103
annot = select_annots[idx]
@@ -91,7 +106,7 @@ def __call__(self, sample):
91106
if self.issmallobject(annot_h, annot_w) is False: continue
92107

93108
for i in range(self.copy_times):
94-
new_annot = self.create_copy_annot(h, w, annot, annots)
109+
new_annot = self.create_copy_annot(h, w, annot, annots,)
95110
if new_annot is not None:
96111
img = self.add_patch_in_img(new_annot, annot, img)
97112
annots.append(new_annot)

config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,9 @@
3636
SOA_THRESH = 64*64
3737
SOA_PROB = 1
3838
SOA_COPY_TIMES = 3
39+
SOA_EPOCHS = 30
3940
SOA_ONE_OBJECT = False
40-
SOA_ALL_OBJECTS = False
41+
SOA_ALL_OBJECTS = True
4142

4243
""" MIXUP """
4344
MIXUP = False

test_augmentation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def _make_transform():
1818
if cfg.RANDOM_FLIP:
1919
transform_list.append(RandomFlip())
2020
if cfg.SMALL_OBJECT_AUGMENTATION:
21-
transform_list.append(SmallObjectAugmentation(cfg.SOA_THRESH, cfg.SOA_PROB, cfg.SOA_COPY_TIMES, cfg.SOA_ALL_OBJECTS, cfg.SOA_ONE_OBJECT))
21+
transform_list.append(SmallObjectAugmentation(cfg.SOA_THRESH, cfg.SOA_PROB, cfg.SOA_COPY_TIMES, cfg.SOA_EPOCHS, cfg.SOA_ALL_OBJECTS, cfg.SOA_ONE_OBJECT))
2222
return transform_list
2323

2424
if __name__ == '__main__':

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def _make_transform():
3131
if cfg.RANDOM_FLIP:
3232
transform_list.append(RandomFlip())
3333
if cfg.SMALL_OBJECT_AUGMENTATION:
34-
transform_list.append(SmallObjectAugmentation(cfg.SOA_THRESH, cfg.SOA_PROB, cfg.SOA_COPY_TIMES, cfg.SOA_ALL_OBJECTS, cfg.SOA_ONE_OBJECT))
34+
transform_list.append(SmallObjectAugmentation(cfg.SOA_THRESH, cfg.SOA_PROB, cfg.SOA_COPY_TIMES, cfg.SOA_EPOCHS, cfg.SOA_ALL_OBJECTS, cfg.SOA_ONE_OBJECT))
3535
transform_list.append(Normalizer())
3636
transform_list.append(Resizer())
3737
return transform_list

0 commit comments

Comments
 (0)