Skip to content

Commit 11a3e35

Browse files
authored
[Fix] Compatible with mme in model loading pretrained. (#1329)
* [Fix] Compatible with mme in model loading pretrained. * add unittest of pretrained * fix unittest error
1 parent 3336abf commit 11a3e35

File tree

10 files changed

+47
-15
lines changed

10 files changed

+47
-15
lines changed

mmpose/models/detectors/associative_embedding.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ def __init__(self,
6565
self.test_cfg = test_cfg
6666
self.use_udp = test_cfg.get('use_udp', False)
6767
self.parser = HeatmapParser(self.test_cfg)
68-
self.init_weights(pretrained=pretrained)
68+
self.pretrained = pretrained
69+
self.init_weights()
6970

7071
@property
7172
def with_keypoint(self):
@@ -74,7 +75,9 @@ def with_keypoint(self):
7475

7576
def init_weights(self, pretrained=None):
7677
"""Weight initialization for model."""
77-
self.backbone.init_weights(pretrained)
78+
if pretrained is not None:
79+
self.pretrained = pretrained
80+
self.backbone.init_weights(self.pretrained)
7881
if self.with_keypoint:
7982
self.keypoint_head.init_weights()
8083

mmpose/models/detectors/mesh.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,14 @@ def __init__(self,
7272
self.test_cfg = test_cfg
7373

7474
self.loss_mesh = builder.build_loss(loss_mesh)
75-
self.init_weights(pretrained=pretrained)
75+
self.pretrained = pretrained
76+
self.init_weights()
7677

7778
def init_weights(self, pretrained=None):
7879
"""Weight initialization for model."""
79-
self.backbone.init_weights(pretrained)
80+
if pretrained is not None:
81+
self.pretrained = pretrained
82+
self.backbone.init_weights(self.pretrained)
8083
self.mesh_head.init_weights()
8184
if self.with_gan:
8285
self.discriminator.init_weights()

mmpose/models/detectors/multi_task.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ def __init__(self,
4646
for head in heads:
4747
assert head is not None
4848
self.heads.append(builder.build_head(head))
49-
50-
self.init_weights(pretrained=pretrained)
49+
self.pretrained = pretrained
50+
self.init_weights()
5151

5252
@property
5353
def with_necks(self):
@@ -56,7 +56,9 @@ def with_necks(self):
5656

5757
def init_weights(self, pretrained=None):
5858
"""Weight initialization for model."""
59-
self.backbone.init_weights(pretrained)
59+
if pretrained is not None:
60+
self.pretrained = pretrained
61+
self.backbone.init_weights(self.pretrained)
6062
if self.with_necks:
6163
for neck in self.necks:
6264
if hasattr(neck, 'init_weights'):

mmpose/models/detectors/pose_lifter.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ def __init__(self,
8686
if self.semi:
8787
assert keypoint_head is not None and traj_head is not None
8888
self.loss_semi = builder.build_loss(loss_semi)
89-
90-
self.init_weights(pretrained=pretrained)
89+
self.pretrained = pretrained
90+
self.init_weights()
9191

9292
@property
9393
def with_neck(self):
@@ -125,13 +125,15 @@ def causal(self):
125125

126126
def init_weights(self, pretrained=None):
127127
"""Weight initialization for model."""
128-
self.backbone.init_weights(pretrained)
128+
if pretrained is not None:
129+
self.pretrained = pretrained
130+
self.backbone.init_weights(self.pretrained)
129131
if self.with_neck:
130132
self.neck.init_weights()
131133
if self.with_keypoint:
132134
self.keypoint_head.init_weights()
133135
if self.with_traj_backbone:
134-
self.traj_backbone.init_weights(pretrained)
136+
self.traj_backbone.init_weights(self.pretrained)
135137
if self.with_traj_neck:
136138
self.traj_neck.init_weights()
137139
if self.with_traj:

mmpose/models/detectors/top_down.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ def __init__(self,
6666
keypoint_head['loss_keypoint'] = loss_pose
6767

6868
self.keypoint_head = builder.build_head(keypoint_head)
69-
70-
self.init_weights(pretrained=pretrained)
69+
self.pretrained = pretrained
70+
self.init_weights()
7171

7272
@property
7373
def with_neck(self):
@@ -81,7 +81,9 @@ def with_keypoint(self):
8181

8282
def init_weights(self, pretrained=None):
8383
"""Weight initialization for model."""
84-
self.backbone.init_weights(pretrained)
84+
if pretrained is not None:
85+
self.pretrained = pretrained
86+
self.backbone.init_weights(self.pretrained)
8587
if self.with_neck:
8688
self.neck.init_weights()
8789
if self.with_keypoint:

tests/test_models/test_bottom_up_forward.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import numpy as np
3+
import pytest
34
import torch
45

56
from mmpose.models.detectors import AssociativeEmbedding
@@ -60,6 +61,9 @@ def test_ae_forward():
6061
model_cfg['test_cfg'],
6162
model_cfg['pretrained'])
6263

64+
with pytest.raises(TypeError):
65+
detector.init_weights(pretrained=dict())
66+
detector.pretrained = model_cfg['pretrained']
6367
detector.init_weights()
6468

6569
input_shape = (1, 3, 256, 256)

tests/test_models/test_mesh_forward.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import tempfile
44

55
import numpy as np
6+
import pytest
67
import torch
78

89
from mmpose.core.optimizer import build_optimizers
@@ -49,6 +50,10 @@ def test_parametric_mesh_forward():
4950
loss_gan=None)
5051

5152
detector = ParametricMesh(**model_cfg)
53+
54+
with pytest.raises(TypeError):
55+
detector.init_weights(pretrained=dict())
56+
detector.pretrained = model_cfg['pretrained']
5257
detector.init_weights()
5358

5459
optimizers_config = dict(generator=dict(type='Adam', lr=0.0001))

tests/test_models/test_multitask_forward.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import numpy as np
3+
import pytest
34
import torch
45

56
from mmpose.models.detectors import MultiTask
@@ -24,7 +25,9 @@ def test_multitask_forward():
2425
pretrained=None,
2526
)
2627
model = MultiTask(**model_cfg)
27-
28+
with pytest.raises(TypeError):
29+
model.init_weights(pretrained=dict())
30+
model.pretrained = model_cfg['pretrained']
2831
# build inputs and target
2932
mm_inputs = _demo_mm_inputs()
3033
inputs = mm_inputs['img']

tests/test_models/test_pose_lifter_forward.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import mmcv
33
import numpy as np
4+
import pytest
45
import torch
56

67
from mmpose.models import build_posenet
@@ -71,6 +72,9 @@ def test_pose_lifter_forward():
7172
cfg = mmcv.Config({'model': model_cfg})
7273
detector = build_posenet(cfg.model)
7374

75+
with pytest.raises(TypeError):
76+
detector.init_weights(pretrained=dict())
77+
detector.pretrained = model_cfg['pretrained']
7478
detector.init_weights()
7579

7680
inputs = _create_inputs(

tests/test_models/test_top_down_forward.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import copy
33

44
import numpy as np
5+
import pytest
56
import torch
67

78
from mmpose.models.detectors import PoseWarper, TopDown
@@ -79,6 +80,9 @@ def test_topdown_forward():
7980
model_cfg['train_cfg'], model_cfg['test_cfg'],
8081
model_cfg['pretrained'])
8182

83+
with pytest.raises(TypeError):
84+
detector.init_weights(pretrained=dict())
85+
detector.pretrained = model_cfg['pretrained']
8286
detector.init_weights()
8387

8488
input_shape = (1, 3, 256, 256)

0 commit comments

Comments
 (0)