Commit 32f180f

* Update Cityscapes
  - 19 classes check
  - Transform method change
* Add ignore index on Cross Entropy
* Learning rate 10x for ASPP
* SGD optimizer
* Faster metric calculation
1 parent 83ded0a commit 32f180f

File tree

12 files changed: +229 -59 lines

.gitignore

Lines changed: 2 additions & 2 deletions
@@ -4,8 +4,8 @@
 cityscapes
 cityscapes/*
 
-voc
-voc/*
+pascalvoc
+pascalvoc/*
 
 outs
 outs/*

configs/default.yaml

Lines changed: 11 additions & 5 deletions
@@ -5,22 +5,28 @@ data:
   mode: "fine"
   target_type: "semantic"
   loader:
-    batch_size: 2
+    batch_size: 6
+    num_workers: 8
 
 net:
   deeplab:
     pretrained: False
-    resnet: "res53"
+    resnet: "res101"
     head_in_ch: 2048
-    num_classes: 34
+    num_classes: 19
   pointhead:
-    in_c: 546 # 512 + num_classes
-    num_classes: 34
+    in_c: 531 # 512 + num_classes
+    num_classes: 19
     k: 3
     beta: 0.75
 
 run:
   epochs: 200
 
+train:
+  lr: 0.01
+  momentum: 0.9
+  weight_decay: 0.0001
+
 apex:
   opt: "O0"
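Worth spelling out the in_c comment: PointHead consumes the 512-channel fine-grained features concatenated with the coarse prediction's class logits, so in_c = 512 + 19 = 531 (previously 512 + 34 = 546).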

configs/parser.py

Lines changed: 3 additions & 3 deletions
@@ -56,10 +56,10 @@ def __init__(self, path, args=None):
             raise NotImplementedError("Don't use args")
         assert isinstance(args, argparse.Namespace), "Check args"
 
-        path = f"{os.getcwd()}/{path}"
-        default_path = f"{os.getcwd()}/configs/default.yaml"
+        full_path = f"{os.getcwd()}/{path}"
+        default_path = full_path.replace(path, "configs/default.yaml")
         self.init_yaml(default_path)
-        self.update_yaml(path)
+        self.update_yaml(full_path)
 
     def init_yaml(self, path):
         with open(path, 'r') as f:
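Note that deriving default_path via full_path.replace(path, "configs/default.yaml") assumes the user-supplied path occurs exactly once in the absolute path; the effect is that the default config is resolved relative to the same working directory as the run config.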

datas/__init__.py

Lines changed: 9 additions & 6 deletions
@@ -2,24 +2,25 @@
 from torchvision.datasets.voc import VOCSegmentation
 from torchvision.datasets.cityscapes import Cityscapes
 
-from .transforms import Compose, Resize, ToTensor, Normalize, RandomCrop, RandomFlip
+from .transforms import Compose, Resize, ToTensor, Normalize, RandomCrop, RandomFlip, ConvertMaskID
 
 
 def get_voc(C, split="train"):
     if split == "train":
-        Compose([
+        transforms = Compose([
             ToTensor(),
-            RandomCrop((256, 512)),
+            RandomCrop((256, 256)),
+            Resize((256, 256)),
             Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-            ])
+        ])
     else:
         transforms = Compose([
             ToTensor(),
-            Resize((256, 512)),
+            Resize((256, 256)),
             Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         ])
 
-    return VOCSegmentation(**C, image_set=split, transforms=transforms)
+    return VOCSegmentation(C['root'], download=True, image_set=split, transforms=transforms)
 
 
 def get_cityscapes(C, split="train"):
@@ -28,13 +29,15 @@ def get_cityscapes(C, split="train"):
         transforms = Compose([
             ToTensor(),
             RandomCrop(768),
+            ConvertMaskID(Cityscapes.classes),
             RandomFlip(),
             Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         ])
     else:
         transforms = Compose([
             ToTensor(),
             Resize(768),
+            ConvertMaskID(Cityscapes.classes),
             Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         ])
     return Cityscapes(**C, split=split, transforms=transforms)
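For orientation, a minimal usage sketch of these factories (hypothetical values; the actual keys of C come from the data: section of configs/default.yaml, and the root key is assumed here):

# Hypothetical config dict; keys mirror the data: section of default.yaml.
C = {"root": "cityscapes", "mode": "fine", "target_type": "semantic"}
train_set = get_cityscapes(C, split="train")  # expands to Cityscapes(**C, split="train", ...)
val_set = get_cityscapes(C, split="val")      # val masks also pass through ConvertMaskID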

datas/transforms.py

Lines changed: 24 additions & 2 deletions
@@ -15,16 +15,21 @@ def __init__(self, shape):
     def __call__(self, img, mask):
         img, mask = img.unsqueeze(0), mask.unsqueeze(0).float()
         img = F.interpolate(img, size=self.shape, mode="bilinear", align_corners=False)
-        mask = F.interpolate(mask, size=self.shape, mode="bilinear", align_corners=False)
+        mask = F.interpolate(mask, size=self.shape, mode="nearest")
         return img[0], mask[0].byte()
 
 
 class RandomCrop:
     def __init__(self, shape):
         self.shape = [shape, shape] if isinstance(shape, int) else shape
+        self.fill = 0
+        self.padding_mode = 'constant'
 
     def _get_range(self, shape, crop_shape):
-        start = random.randint(0, shape - crop_shape)
+        if shape == crop_shape:
+            start = 0
+        else:
+            start = random.randint(0, shape - crop_shape)
         end = start + crop_shape
         return start, end
 
@@ -79,3 +84,20 @@ def __call__(self, img, mask):
         for t in self.transforms:
             img, mask = t(img, mask)
         return img, mask
+
+
+class ConvertMaskID:
+    """
+    Convert 34 classes to 19 classes
+
+    Change the `id` value of CityscapesClass to `train_id`
+    """
+    def __init__(self, classes):
+        self.classes = classes
+
+    def __call__(self, img, mask):
+        mask_train_id = mask.clone()
+        for c in self.classes:
+            mask_train_id[mask == c.id] = c.train_id
+
+        return img, mask_train_id
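This mapping is what the commit message's "Add ignore index on Cross Entropy" relies on: in torchvision's Cityscapes.classes table the void classes carry train_id 255, so after ConvertMaskID those pixels can be excluded from the loss. A minimal sketch, assuming 255 is the ignore value used here:

import torch.nn as nn

# Void pixels end up as train_id 255 after ConvertMaskID (255 assumed from
# torchvision's Cityscapes class table); the loss skips them entirely.
criterion = nn.CrossEntropyLoss(ignore_index=255)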

infer.py

Lines changed: 5 additions & 2 deletions
@@ -1,13 +1,16 @@
-import torch
+import time
 import logging
 
+import torch
+
 from utils.metrics import ConfusionMatrix
 
 
 @torch.no_grad()
 def infer(loader, net, device):
     net.eval()
-    metric = ConfusionMatrix(len(loader.dataset.classes) - 1)
+    num_classes = 19  # Hard coding for Cityscapes
+    metric = ConfusionMatrix(num_classes)
     for i, (x, gt) in enumerate(loader):
         x = x.to(device, non_blocking=True)
         gt = gt.squeeze(1).to(device, dtype=torch.long, non_blocking=True)
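utils.metrics.ConfusionMatrix itself is not part of this diff. A common way to get the "faster metric calculation" behaviour is to accumulate the confusion matrix on-device with torch.bincount; the following is a sketch under that assumption, not the repo's actual implementation:

import torch

class ConfusionMatrix:
    # Sketch of a bincount-based confusion matrix for mIoU; interface assumed.
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.mat = None

    def update(self, gt, pred):
        # gt, pred: flattened long tensors of labels and predictions
        n = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros(n, n, dtype=torch.int64, device=gt.device)
        valid = (gt >= 0) & (gt < n)        # drops ignore labels such as 255
        idx = n * gt[valid] + pred[valid]   # encode each (gt, pred) pair as one index
        self.mat += torch.bincount(idx, minlength=n * n).reshape(n, n)

    def miou(self):
        h = self.mat.float()
        iou = torch.diag(h) / (h.sum(0) + h.sum(1) - torch.diag(h))
        return iou.mean().item()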

main.py

Lines changed: 5 additions & 1 deletion
@@ -75,7 +75,11 @@ def set_loggging(save_dir):
         PointHead(**C.net.pointhead)
     ).to(device)
 
-    optim = torch.optim.AdamW(net.parameters())
+    params = [{"params": net.backbone.backbone.parameters(), "lr": C.train.lr},
+              {"params": net.head.parameters(), "lr": C.train.lr},
+              {"params": net.backbone.classifier.parameters(), "lr": C.train.lr * 10}]
+
+    optim = torch.optim.SGD(params, momentum=C.train.momentum, weight_decay=C.train.weight_decay)
 
     net, optim = amp.initialize(net, optim, opt_level=C.apex.opt)
     if args.distributed:
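Here net.backbone.classifier is the DeepLabHead, which wraps the ASPP module, so this parameter group carries the 10x learning rate called out in the commit message; the ResNet backbone and the PointHead stay at the base C.train.lr (0.01 per the config above).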

model/deeplab.py

Lines changed: 5 additions & 5 deletions
@@ -6,6 +6,7 @@
 from torchvision.models.segmentation.deeplabv3 import DeepLabHead
 from torchvision.models.segmentation.fcn import FCNHead
 from .resnet import resnet103, resnet53
+from torchvision.models import resnet50, resnet101
 
 
 class SmallDeepLab(_SimpleSegmentationModel):
@@ -17,8 +18,10 @@ def forward(self, input_):
 
 def deeplabv3(pretrained=False, resnet="res103", head_in_ch=2048, num_classes=21):
     resnet = {
-        "res53": resnet53,
-        "res103": resnet103
+        "res53": resnet53,
+        "res103": resnet103,
+        "res50": resnet50,
+        "res101": resnet101
     }[resnet]
 
     net = SmallDeepLab(
@@ -28,9 +31,6 @@ def deeplabv3(pretrained=False, resnet="res103", head_in_ch=2048, num_classes=21
         ),
         classifier=DeepLabHead(head_in_ch, num_classes)
     )
-    if pretrained:
-        state_dict = load_state_dict_from_url('https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth', progress=True)
-        net.load_state_dict(state_dict)
     return net
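The removed block hard-coded the DeepLabV3-ResNet101 COCO checkpoint regardless of which backbone was selected; after this change the pretrained flag no longer triggers that download.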

model/pointrend.py

Lines changed: 8 additions & 5 deletions
@@ -37,8 +37,11 @@ def forward(self, x, res2, out):
 
     @torch.no_grad()
     def inference(self, x, res2, out):
-        stride = x.shape[-1] // out.shape[-1]
-        num_points = x.shape[-1] // stride
+        """
+        During inference, subdivision uses N=8096
+        (i.e., the number of points in the stride 16 map of a 1024×2048 image)
+        """
+        num_points = 8096
 
         while out.shape[-1] != x.shape[-1]:
             out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=True)
@@ -54,9 +57,9 @@ def inference(self, x, res2, out):
 
         B, C, H, W = out.shape
         points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
-        out = out.reshape(B, C, -1)
-        out = out.scatter_(2, points_idx, rend)
-        out = out.view(B, C, H, W)
+        out = (out.reshape(B, C, -1)
+                  .scatter_(2, points_idx, rend)
+                  .view(B, C, H, W))
 
         return {"fine": out}
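A small aside on the hard-coded N=8096: the docstring echoes the PointRend paper's wording, but a 1024×2048 image at stride 16 has 64 × 128 = 8192 points, so 8096 appears to be a figure carried over from the paper rather than the exact stride-16 count.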

tests/test_cityscapes.ipynb

Lines changed: 142 additions & 0 deletions
Large diffs are not rendered by default.
