Commit 186a1fc
[Feature] Support multiple losses during training (open-mmlab#818)
* multiple losses
* fix lint error
* fix typos
* fix typos
* Adding Attribute
* Fixing loss_ prefix
* Fixing loss_ prefix
* Fixing loss_ prefix
* Add Same
* loss_name must has 'loss_' prefix
* Fix unittest
* Fix unittest
* Fix unittest
* Update mmseg/models/decode_heads/decode_head.py

Co-authored-by: Junjun2016 <[email protected]>
1 parent 0b11d58 commit 186a1fc

12 files changed: +297 additions, −33 deletions


docs/tutorials/training_tricks.md

Lines changed: 18 additions & 0 deletions
@@ -50,3 +50,21 @@ model=dict(
`class_weight` will be passed into `CrossEntropyLoss` as the `weight` argument. Please refer to the [PyTorch Doc](https://pytorch.org/docs/stable/nn.html?highlight=crossentropy#torch.nn.CrossEntropyLoss) for details.

## Multiple Losses

For loss calculation, we support training with multiple losses concurrently. Here is an example config for training `unet` on the `DRIVE` dataset, where the loss function is a `1:3` weighted sum of `CrossEntropyLoss` and `DiceLoss`:

```python
_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py'
model = dict(
    decode_head=dict(loss_decode=[
        dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
        dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)
    ]),
    auxiliary_head=dict(loss_decode=[
        dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
        dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)
    ]),
)
```

In this way, `loss_weight` and `loss_name` are, respectively, the weight of the corresponding loss and its name in the training log.

Note: If you want a loss item to be included into the backward graph, its name must start with the `loss_` prefix.
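
To see why the `loss_` prefix matters, here is a minimal sketch of the usual MMCV-style bookkeeping: every dict entry whose key starts with `loss_` is summed into the scalar that gets backpropagated, while other entries (such as `acc_seg`) are only logged. The `parse_losses` helper below is illustrative, not the actual mmseg implementation:

```python
# Minimal sketch (assumed helper, not the mmseg/MMCV code):
# only items whose key starts with 'loss_' contribute to the backward graph.
import torch


def parse_losses(losses: dict):
    """Sum every tensor whose key starts with 'loss_' into the total loss."""
    total = sum(value for key, value in losses.items() if key.startswith('loss_'))
    # Entries without the 'loss_' prefix (e.g. 'acc_seg') are logged only.
    return total


losses = {
    'loss_ce': torch.tensor(0.7, requires_grad=True),
    'loss_dice': torch.tensor(2.1, requires_grad=True),
    'acc_seg': torch.tensor(76.5),  # metric only, not backpropagated
}
total_loss = parse_losses(losses)  # 0.7 + 2.1
total_loss.backward()
```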

docs_zh-CN/tutorials/training_tricks.md

Lines changed: 19 additions & 0 deletions
@@ -49,3 +49,22 @@ model=dict(
`class_weight` will be passed to `CrossEntropyLoss` as the `weight` argument. Please refer to the [PyTorch documentation](https://pytorch.org/docs/stable/nn.html?highlight=crossentropy#torch.nn.CrossEntropyLoss) for details.

## Multiple Losses

For the loss calculation during training, we currently support using multiple loss functions at the same time. Taking training `unet` on the `DRIVE` dataset as an example, the loss function is a `1:3` weighted sum of `CrossEntropyLoss` and `DiceLoss`. The config file is written as:

```python
_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py'
model = dict(
    decode_head=dict(loss_decode=[
        dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
        dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)
    ]),
    auxiliary_head=dict(loss_decode=[
        dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
        dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)
    ]),
)
```

In this way, `loss_weight` sets the weight of each loss during training and `loss_name` sets its name in the training log.

Note: the value of `loss_name` must carry the `loss_` prefix so that the loss item is included in the backward graph.

mmseg/core/seg/sampler/ohem_pixel_sampler.py

Lines changed: 8 additions & 6 deletions
@@ -62,12 +62,14 @@ def sample(self, seg_logit, seg_label):
                 threshold = max(min_threshold, self.thresh)
                 valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
             else:
-                losses = self.context.loss_decode(
-                    seg_logit,
-                    seg_label,
-                    weight=None,
-                    ignore_index=self.context.ignore_index,
-                    reduction_override='none')
+                losses = 0.0
+                for loss_module in self.context.loss_decode:
+                    losses += loss_module(
+                        seg_logit,
+                        seg_label,
+                        weight=None,
+                        ignore_index=self.context.ignore_index,
+                        reduction_override='none')
                 # faster than topk according to https://github.com/pytorch/pytorch/issues/22812  # noqa
                 _, sort_indices = losses[valid_mask].sort(descending=True)
                 valid_seg_weight[sort_indices[:batch_kept]] = 1.
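
The change above replaces the single `loss_decode` call with a loop over all registered loss modules, summing their unreduced per-pixel loss maps before ranking pixels. A self-contained sketch of that idea (the `hardest_pixel_mask` helper and the shapes are illustrative, not mmseg API):

```python
# Minimal sketch of OHEM-style hard-pixel selection over multiple loss terms.
import torch
import torch.nn.functional as F


def hardest_pixel_mask(seg_logit, seg_label, loss_fns, batch_kept):
    """Keep the `batch_kept` pixels with the largest summed per-pixel loss."""
    losses = torch.zeros_like(seg_label, dtype=torch.float)
    for loss_fn in loss_fns:
        # Each loss_fn must return an unreduced (N, H, W) loss map,
        # mirroring reduction_override='none' in the sampler.
        losses = losses + loss_fn(seg_logit, seg_label)
    flat = losses.flatten()
    _, sort_indices = flat.sort(descending=True)  # hardest pixels first
    mask = torch.zeros_like(flat, dtype=torch.bool)
    mask[sort_indices[:batch_kept]] = True
    return mask.view_as(seg_label)


def ce(logit, label):
    return F.cross_entropy(logit, label, reduction='none')


# Usage with a single unreduced cross-entropy term:
logits = torch.randn(2, 19, 8, 8)
labels = torch.randint(0, 19, (2, 8, 8))
keep = hardest_pixel_mask(logits, labels, [ce], batch_kept=32)
print(keep.sum())  # tensor(32)
```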

mmseg/models/decode_heads/decode_head.py

Lines changed: 35 additions & 8 deletions
@@ -33,10 +33,17 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
                 a list and passed into decode head.
             None: Only one select feature map is allowed.
             Default: None.
-        loss_decode (dict): Config of decode loss.
+        loss_decode (dict | Sequence[dict]): Config of decode loss.
+            `loss_name` is a property of the corresponding loss function and
+            is shown in the training log. If you want a loss item to be
+            included into the backward graph, `loss_` must be the prefix of
+            its name (the built-in default is 'loss_ce').
+            E.g. dict(type='CrossEntropyLoss') or
+            [dict(type='CrossEntropyLoss', loss_name='loss_ce'),
+             dict(type='DiceLoss', loss_name='loss_dice')].
             Default: dict(type='CrossEntropyLoss').
         ignore_index (int | None): The label index to be ignored. When using
-            masked BCE loss, ignore_index should be set to None. Default: 255
+            masked BCE loss, ignore_index should be set to None. Default: 255.
         sampler (dict|None): The config of segmentation map sampler.
             Default: None.
         align_corners (bool): align_corners argument of F.interpolate.
@@ -73,9 +80,20 @@ def __init__(self,
         self.norm_cfg = norm_cfg
         self.act_cfg = act_cfg
         self.in_index = in_index
-        self.loss_decode = build_loss(loss_decode)
+
         self.ignore_index = ignore_index
         self.align_corners = align_corners
+        self.loss_decode = nn.ModuleList()
+
+        if isinstance(loss_decode, dict):
+            self.loss_decode.append(build_loss(loss_decode))
+        elif isinstance(loss_decode, (list, tuple)):
+            for loss in loss_decode:
+                self.loss_decode.append(build_loss(loss))
+        else:
+            raise TypeError(f'loss_decode must be a dict or sequence of dict,\
+                but got {type(loss_decode)}')
+
         if sampler is not None:
             self.sampler = build_pixel_sampler(sampler, context=self)
         else:
@@ -224,10 +242,19 @@ def losses(self, seg_logit, seg_label):
         else:
             seg_weight = None
         seg_label = seg_label.squeeze(1)
-        loss['loss_seg'] = self.loss_decode(
-            seg_logit,
-            seg_label,
-            weight=seg_weight,
-            ignore_index=self.ignore_index)
+        for loss_decode in self.loss_decode:
+            if loss_decode.loss_name not in loss:
+                loss[loss_decode.loss_name] = loss_decode(
+                    seg_logit,
+                    seg_label,
+                    weight=seg_weight,
+                    ignore_index=self.ignore_index)
+            else:
+                loss[loss_decode.loss_name] += loss_decode(
+                    seg_logit,
+                    seg_label,
+                    weight=seg_weight,
+                    ignore_index=self.ignore_index)
+
         loss['acc_seg'] = accuracy(seg_logit, seg_label)
         return loss
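
The `losses()` loop above keys each loss by its `loss_name` and sums values that share a name, which is why the duplicate-name test at the bottom of this commit returns three times the single-loss value. A stripped-down sketch of that accumulation (the toy `NamedCrossEntropy` class is illustrative, not the mmseg loss API):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NamedCrossEntropy(nn.Module):
    """Toy loss exposing the `loss_name` property convention."""

    def __init__(self, loss_weight=1.0, loss_name='loss_ce'):
        super().__init__()
        self.loss_weight = loss_weight
        self._loss_name = loss_name

    def forward(self, seg_logit, seg_label):
        return self.loss_weight * F.cross_entropy(seg_logit, seg_label)

    @property
    def loss_name(self):
        return self._loss_name


def accumulate_losses(loss_modules, seg_logit, seg_label):
    loss = dict()
    for loss_module in loss_modules:
        value = loss_module(seg_logit, seg_label)
        if loss_module.loss_name not in loss:
            loss[loss_module.loss_name] = value   # first occurrence of the name
        else:
            loss[loss_module.loss_name] += value  # same name: values are summed
    return loss


seg_logit = torch.randn(2, 19, 8, 8)
seg_label = torch.randint(0, 19, (2, 8, 8))
modules = nn.ModuleList([NamedCrossEntropy(loss_name='loss_ce'),
                         NamedCrossEntropy(loss_name='loss_ce'),
                         NamedCrossEntropy(loss_name='loss_aux')])
out = accumulate_losses(modules, seg_logit, seg_label)
print(sorted(out))  # ['loss_aux', 'loss_ce']; the two 'loss_ce' terms were summed
```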

mmseg/models/decode_heads/point_head.py

Lines changed: 3 additions & 2 deletions
@@ -249,8 +249,9 @@ def forward_test(self, inputs, prev_output, img_metas, test_cfg):
     def losses(self, point_logits, point_label):
         """Compute segmentation loss."""
         loss = dict()
-        loss['loss_point'] = self.loss_decode(
-            point_logits, point_label, ignore_index=self.ignore_index)
+        for loss_module in self.loss_decode:
+            loss['point' + loss_module.loss_name] = loss_module(
+                point_logits, point_label, ignore_index=self.ignore_index)
         loss['acc_point'] = accuracy(point_logits, point_label)
         return loss

mmseg/models/losses/cross_entropy_loss.py

Lines changed: 20 additions & 1 deletion
@@ -150,14 +150,18 @@ class CrossEntropyLoss(nn.Module):
         class_weight (list[float] | str, optional): Weight of each class. If in
             str format, read them from a file. Defaults to None.
         loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+        loss_name (str, optional): Name of the loss item. If you want this loss
+            item to be included into the backward graph, `loss_` must be the
+            prefix of the name. Defaults to 'loss_ce'.
     """
 
     def __init__(self,
                  use_sigmoid=False,
                  use_mask=False,
                  reduction='mean',
                  class_weight=None,
-                 loss_weight=1.0):
+                 loss_weight=1.0,
+                 loss_name='loss_ce'):
         super(CrossEntropyLoss, self).__init__()
         assert (use_sigmoid is False) or (use_mask is False)
         self.use_sigmoid = use_sigmoid
@@ -172,6 +176,7 @@ def __init__(self,
             self.cls_criterion = mask_cross_entropy
         else:
             self.cls_criterion = cross_entropy
+        self._loss_name = loss_name
 
     def forward(self,
                 cls_score,
@@ -197,3 +202,17 @@ def forward(self,
             avg_factor=avg_factor,
             **kwargs)
         return loss_cls
+
+    @property
+    def loss_name(self):
+        """Loss Name.
+
+        This function must be implemented and will return the name of this
+        loss function. This name will be used to combine different loss items
+        by simple sum operation. In addition, if you want this loss item to be
+        included into the backward graph, `loss_` must be the prefix of the
+        name.
+        Returns:
+            str: The name of this loss item.
+        """
+        return self._loss_name
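
Third-party losses can opt into the same mechanism by exposing a `loss_name` property of their own. A hypothetical example, assuming mmseg's `LOSSES` registry and the usual `@LOSSES.register_module()` decorator (the class name, default name, and simple cross-entropy body are illustrative only):

```python
import torch.nn as nn
import torch.nn.functional as F

from mmseg.models.builder import LOSSES


@LOSSES.register_module()
class MyCustomLoss(nn.Module):
    """Toy custom loss that follows the `loss_name` convention."""

    def __init__(self, loss_weight=1.0, loss_name='loss_my'):
        super(MyCustomLoss, self).__init__()
        self.loss_weight = loss_weight
        self._loss_name = loss_name

    def forward(self, cls_score, label, weight=None, ignore_index=255, **kwargs):
        # Extra arguments are accepted (and ignored here) so the loss can be
        # called the same way as the built-in losses.
        return self.loss_weight * F.cross_entropy(
            cls_score, label, ignore_index=ignore_index)

    @property
    def loss_name(self):
        # Must start with 'loss_' for the item to reach the backward graph.
        return self._loss_name
```

It could then appear in a config alongside the built-in losses, e.g. `loss_decode=[dict(type='CrossEntropyLoss', loss_weight=1.0), dict(type='MyCustomLoss', loss_name='loss_my', loss_weight=0.5)]`.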

mmseg/models/losses/dice_loss.py

Lines changed: 19 additions & 0 deletions
@@ -68,6 +68,9 @@ class DiceLoss(nn.Module):
             str format, read them from a file. Defaults to None.
         loss_weight (float, optional): Weight of the loss. Default to 1.0.
         ignore_index (int | None): The label index to be ignored. Default: 255.
+        loss_name (str, optional): Name of the loss item. If you want this loss
+            item to be included into the backward graph, `loss_` must be the
+            prefix of the name. Defaults to 'loss_dice'.
     """
 
     def __init__(self,
@@ -77,6 +80,7 @@ def __init__(self,
                  class_weight=None,
                  loss_weight=1.0,
                  ignore_index=255,
+                 loss_name='loss_dice',
                  **kwards):
         super(DiceLoss, self).__init__()
         self.smooth = smooth
@@ -85,6 +89,7 @@ def __init__(self,
         self.class_weight = get_class_weight(class_weight)
         self.loss_weight = loss_weight
         self.ignore_index = ignore_index
+        self._loss_name = loss_name
 
     def forward(self,
                 pred,
@@ -118,3 +123,17 @@ def forward(self,
             class_weight=class_weight,
             ignore_index=self.ignore_index)
         return loss
+
+    @property
+    def loss_name(self):
+        """Loss Name.
+
+        This function must be implemented and will return the name of this
+        loss function. This name will be used to combine different loss items
+        by simple sum operation. In addition, if you want this loss item to be
+        included into the backward graph, `loss_` must be the prefix of the
+        name.
+        Returns:
+            str: The name of this loss item.
+        """
+        return self._loss_name

mmseg/models/losses/lovasz_loss.py

Lines changed: 20 additions & 1 deletion
@@ -244,6 +244,9 @@ class LovaszLoss(nn.Module):
         class_weight (list[float] | str, optional): Weight of each class. If in
             str format, read them from a file. Defaults to None.
         loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+        loss_name (str, optional): Name of the loss item. If you want this loss
+            item to be included into the backward graph, `loss_` must be the
+            prefix of the name. Defaults to 'loss_lovasz'.
     """
 
     def __init__(self,
@@ -252,7 +255,8 @@ def __init__(self,
                  per_image=False,
                  reduction='mean',
                  class_weight=None,
-                 loss_weight=1.0):
+                 loss_weight=1.0,
+                 loss_name='loss_lovasz'):
         super(LovaszLoss, self).__init__()
         assert loss_type in ('binary', 'multi_class'), "loss_type should be \
             'binary' or 'multi_class'."
@@ -271,6 +275,7 @@ def __init__(self,
         self.reduction = reduction
         self.loss_weight = loss_weight
         self.class_weight = get_class_weight(class_weight)
+        self._loss_name = loss_name
 
     def forward(self,
                 cls_score,
@@ -302,3 +307,17 @@ def forward(self,
             avg_factor=avg_factor,
             **kwargs)
         return loss_cls
+
+    @property
+    def loss_name(self):
+        """Loss Name.
+
+        This function must be implemented and will return the name of this
+        loss function. This name will be used to combine different loss items
+        by simple sum operation. In addition, if you want this loss item to be
+        included into the backward graph, `loss_` must be the prefix of the
+        name.
+        Returns:
+            str: The name of this loss item.
+        """
+        return self._loss_name

tests/test_models/test_heads/test_decode_head.py

Lines changed: 89 additions & 0 deletions
@@ -74,3 +74,92 @@ def test_decode_head():
     assert head.input_transform == 'resize_concat'
     transformed_inputs = head._transform_inputs(inputs)
     assert transformed_inputs.shape == (1, 48, 45, 45)
+
+    # test multi-loss, loss_decode is dict
+    with pytest.raises(TypeError):
+        # loss_decode must be a dict or sequence of dict.
+        BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])
+
+    inputs = torch.randn(2, 19, 8, 8).float()
+    target = torch.ones(2, 1, 64, 64).long()
+    head = BaseDecodeHead(
+        3,
+        16,
+        num_classes=19,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(head, inputs)
+        head, target = to_cuda(head, target)
+    loss = head.losses(seg_logit=inputs, seg_label=target)
+    assert 'loss_ce' in loss
+
+    # test multi-loss, loss_decode is list of dict
+    inputs = torch.randn(2, 19, 8, 8).float()
+    target = torch.ones(2, 1, 64, 64).long()
+    head = BaseDecodeHead(
+        3,
+        16,
+        num_classes=19,
+        loss_decode=[
+            dict(type='CrossEntropyLoss', loss_name='loss_1'),
+            dict(type='CrossEntropyLoss', loss_name='loss_2')
+        ])
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(head, inputs)
+        head, target = to_cuda(head, target)
+    loss = head.losses(seg_logit=inputs, seg_label=target)
+    assert 'loss_1' in loss
+    assert 'loss_2' in loss
+
+    # 'loss_decode' must be a dict or sequence of dict
+    with pytest.raises(TypeError):
+        BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])
+    with pytest.raises(TypeError):
+        BaseDecodeHead(3, 16, num_classes=19, loss_decode=0)
+
+    # test multi-loss, loss_decode is list of dict
+    inputs = torch.randn(2, 19, 8, 8).float()
+    target = torch.ones(2, 1, 64, 64).long()
+    head = BaseDecodeHead(
+        3,
+        16,
+        num_classes=19,
+        loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_1'),
+                     dict(type='CrossEntropyLoss', loss_name='loss_2'),
+                     dict(type='CrossEntropyLoss', loss_name='loss_3')))
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(head, inputs)
+        head, target = to_cuda(head, target)
+    loss = head.losses(seg_logit=inputs, seg_label=target)
+    assert 'loss_1' in loss
+    assert 'loss_2' in loss
+    assert 'loss_3' in loss
+
+    # test multi-loss, loss_decode is list of dict, names of them are identical
+    inputs = torch.randn(2, 19, 8, 8).float()
+    target = torch.ones(2, 1, 64, 64).long()
+    head = BaseDecodeHead(
+        3,
+        16,
+        num_classes=19,
+        loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce'),
+                     dict(type='CrossEntropyLoss', loss_name='loss_ce'),
+                     dict(type='CrossEntropyLoss', loss_name='loss_ce')))
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(head, inputs)
+        head, target = to_cuda(head, target)
+    loss_3 = head.losses(seg_logit=inputs, seg_label=target)
+
+    head = BaseDecodeHead(
+        3,
+        16,
+        num_classes=19,
+        loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce')))
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(head, inputs)
+        head, target = to_cuda(head, target)
+    loss = head.losses(seg_logit=inputs, seg_label=target)
+    assert 'loss_ce' in loss
+    assert 'loss_ce' in loss_3
+    assert loss_3['loss_ce'] == 3 * loss['loss_ce']
