Skip to content

Commit 83ded0a

Browse files
committed
Fix data preparation, mask uncertainty value
1 parent 2168d6f commit 83ded0a

File tree

5 files changed

+15
-10
lines changed

5 files changed

+15
-10
lines changed

configs/default.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@ net:
1212
pretrained: False
1313
resnet: "res53"
1414
head_in_ch: 2048
15-
num_classes: 20
15+
num_classes: 34
1616
pointhead:
17-
in_c: 532 # 512 + 20
18-
num_classes: 20
17+
in_c: 546 # 512 + num_classes
18+
num_classes: 34
1919
k: 3
2020
beta: 0.75
2121

2222
run:
2323
epochs: 200
2424

2525
apex:
26-
opt: "O0"
26+
opt: "O0"

datas/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def get_voc(C, split="train"):
1111
ToTensor(),
1212
RandomCrop((256, 512)),
1313
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
14-
])
14+
])
1515
else:
1616
transforms = Compose([
1717
ToTensor(),

datas/transforms.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import random
22

3+
import numpy as np
4+
5+
import torch
36
import torch.nn.functional as F
47
from torchvision import transforms
58
from torchvision.transforms.functional import normalize
@@ -10,10 +13,10 @@ def __init__(self, shape):
1013
self.shape = [shape, shape] if isinstance(shape, int) else shape
1114

1215
def __call__(self, img, mask):
13-
img, mask = img.unsqueeze(0), mask.unsqueeze(0)
16+
img, mask = img.unsqueeze(0), mask.unsqueeze(0).float()
1417
img = F.interpolate(img, size=self.shape, mode="bilinear", align_corners=False)
1518
mask = F.interpolate(mask, size=self.shape, mode="bilinear", align_corners=False)
16-
return img[0], mask[0]
19+
return img[0], mask[0].byte()
1720

1821

1922
class RandomCrop:
@@ -54,7 +57,9 @@ def __init__(self):
5457
self.to_tensor = transforms.ToTensor()
5558

5659
def __call__(self, img, mask):
57-
return self.to_tensor(img), self.to_tensor(mask)
60+
img = self.to_tensor(img)
61+
mask = torch.from_numpy(np.array(mask))
62+
return img, mask[None]
5863

5964

6065
class Normalize:

infer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
@torch.no_grad()
88
def infer(loader, net, device):
99
net.eval()
10-
metric = ConfusionMatrix(21)
10+
metric = ConfusionMatrix(len(loader.dataset.classes) - 1)
1111
for i, (x, gt) in enumerate(loader):
1212
x = x.to(device, non_blocking=True)
1313
gt = gt.squeeze(1).to(device, dtype=torch.long, non_blocking=True)

model/sampling_points.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,11 @@ def sampling_points(mask, N, k=3, beta=0.75, training=True):
5252
assert mask.dim() == 4, "Dim must be N(Batch)CHW"
5353
device = mask.device
5454
B, _, H, W = mask.shape
55+
mask, _ = mask.sort(1, descending=True)
5556

5657
if not training:
5758
H_step, W_step = 1 / H, 1 / W
5859
N = min(H * W, N)
59-
6060
uncertainty_map = -1 * (mask[:, 0] - mask[:, 1])
6161
_, idx = uncertainty_map.view(B, -1).topk(N, dim=1)
6262

0 commit comments

Comments
 (0)