Skip to content

Commit b1fd8b7

Browse files
committed
refine
1 parent 6d8cdf8 commit b1fd8b7

File tree

7 files changed

+26
-55
lines changed

7 files changed

+26
-55
lines changed

Augmentation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class AutoAugmenter(object):
4545
Returns:
4646
A tuple containing the augmented versions of `image` and `bboxes`.
4747
"""
48-
def __init__(self, augmentation_name='v2'):
48+
def __init__(self, augmentation_name='v4'):
4949
self.augmentation_name = augmentation_name
5050

5151
def normalizer(self, image, annots):

augmentation_zoo/Myautoaugment_utils.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import numpy as np
22
import inspect
33
from PIL import Image, ImageOps, ImageEnhance
4-
import cv2
5-
import matplotlib.pyplot as plt
64
import math
75
import random
86

@@ -135,6 +133,16 @@ def policy_v3():
135133
]
136134
return policy
137135

136+
def policy_v4():
137+
policy = [
138+
[('TranslateX_BBox', 0.6, 4), ('Equalize', 0.8, 10)],
139+
[('TranslateY_Only_BBoxes', 0.2, 2), ('Cutout', 0.8, 8)],
140+
[('ShearY_BBox', 1.0, 2), ('TranslateY_Only_BBoxes', 0.6, 6)],
141+
[('Rotate_BBox', 0.6, 10), ('Color', 1.0, 6)],
142+
[]
143+
]
144+
return policy
145+
138146
"""
139147
AutoContrast
140148
Equalize
@@ -164,10 +172,11 @@ def policy_v3():
164172
Cutout_Only_BBoxes
165173
"""
166174

175+
167176
def policy_vtest():
168177
""" (policy, pro, level)"""
169178
policy = [
170-
[('TranslateX_BBox', 1, 10), ('BBox_Cutout', 1, 10)],
179+
[('Flip', 1, 10)],
171180
]
172181
return policy
173182

@@ -1140,7 +1149,7 @@ def cutout_only_bboxes(image, bboxes, prob, pad_size, replace):
11401149
return _apply_multi_bbox_augmentation_wrapper(
11411150
image, bboxes, prob, cutout, func_changes_bbox, pad_size, replace)
11421151

1143-
def flip(image, bboxes, prob, level):
1152+
def flip(image, bboxes, prob):
11441153
if np.random.rand() < prob:
11451154
image = image[:, ::-1, :]
11461155

@@ -1397,7 +1406,7 @@ def distort_image_with_autoaugment(image, bboxes, augmentation_name):
13971406
添加新算法,除函数本身定义外,需在NAME_TO_FUNC、level_to_arg中对应添加。
13981407
"""
13991408
available_policies = {'v0': policy_v0, 'v1': policy_v1, 'v2': policy_v2,
1400-
'v3': policy_v3, 'test': policy_vtest,
1409+
'v3': policy_v3, 'v4': policy_v4, 'test': policy_vtest,
14011410
'custom': policy_custom}
14021411
if augmentation_name not in available_policies:
14031412
raise ValueError('Invalid augmentation_name: {}'.format(augmentation_name))

augmentation_zoo/mixup.py

Lines changed: 0 additions & 38 deletions
This file was deleted.

config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
""" AUTOAUGMENT """
2323
AUTOAUGMENT = False
24-
AUTO_POLICY = 'v1'
24+
AUTO_POLICY = 'v4'
2525

2626
""" GRIDMASK """
2727
GRID = False
@@ -38,7 +38,7 @@
3838
SOA_COPY_TIMES = 3
3939
SOA_EPOCHS = 30
4040
SOA_ONE_OBJECT = False
41-
SOA_ALL_OBJECTS = True
41+
SOA_ALL_OBJECTS = False
4242

4343
""" MIXUP """
4444
MIXUP = False

eval_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
from retinanet import model
33
from prepare_data import VocDataset, Normalizer, Resizer, AspectRatioBasedSampler, UnNormalizer, collater
4-
from Augmentation import autoaugmenter
4+
from Augmentation import AutoAugmenter
55
from torchvision import transforms
66
import config
77
from retinanet import csv_eval

test_augmentation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@ def _make_transform():
2626
voc_train = VocDataset(VOC_ROOT_DIR, 'train', transform=transforms.Compose(transform_list))
2727
kitti_train = KittiDataset(KITTI_ROOT_DIR, 'train', transforms.Compose(transform_list))
2828

29-
for i in range(voc_train.__len__()):
30-
print(i)
31-
sample = voc_train[i]
32-
# easy_visualization(sample)
29+
# for i in range(voc_train.__len__()):
30+
# print(i)
31+
# sample = voc_train[i]
32+
# easy_visualization(sample)
3333

34-
# sample = voc_train[19]
35-
# easy_visualization(sample)
34+
sample = voc_train[5]
35+
easy_visualization(sample)
3636

3737
# for i in range(kitti_train.__len__()):
3838
# sample = kitti_train[i]

train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ def main():
135135
retinanet.eval()
136136

137137
print('\nBest_mAP:', BEST_MAP_EPOCH)
138-
for label in range(retinanet.num_classes()):
139-
label_name = retinanet.label_to_name(label)
138+
for label in range(dataset_val.num_classes()):
139+
label_name = dataset_val.label_to_name(label)
140140
print('{}: {}'.format(label_name, best_average_precisions[label][0]))
141141
print('BEST MAP: ', BEST_MAP)
142142
# torch.save(retinanet, 'model_final.pt')

0 commit comments

Comments
 (0)