Skip to content

Commit d3cfde9

Browse files
Add interhand3d detector, head and config. (#624)
* Add interhand3d detector, head and config. Fix some bugs in multilabel_classification_head and interhand3d_dataset. * modify the url of ckpt and log of InterNet.
1 parent ebaede6 commit d3cfde9

22 files changed

+1213
-666
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ Supported algorithms:
8080
- [x] [RSN](configs/top_down/rsn/README.md) (ECCV'2020)
8181
- [x] [HMR](configs/mesh/hmr/README.md) (CVPR'2018)
8282
- [x] [Simple 3D Baseline](configs/body3d/simple_baseline/README.md) (ICCV'2017)
83+
- [x] [InterNet](configs/hand3d/InterNet/README.md) (ECCV'2020)
8384

8485
</details>
8586

README_CN.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ MMPose 是一款基于 PyTorch 的姿态分析的开源工具箱,是 [OpenMMLa
8181
- [x] [RSN](configs/top_down/rsn/README.md) (ECCV'2020)
8282
- [x] [HMR](configs/mesh/hmr/README.md) (CVPR'2018)
8383
- [x] [Simple 3D Baseline](configs/body3d/simple_baseline/README.md) (ICCV'2017)
84+
- [x] [InterNet](configs/hand3d/InterNet/README.md) (ECCV'2020)
8485

8586
</details>
8687

configs/hand3d/InterNet/README.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# InterHand2.6M: A Dataset and Baseline for 3D Interacting Hand Pose Estimation from a Single RGB Image
2+
3+
## Introduction
4+
5+
<!-- [ALGORITHM] -->
6+
7+
```bibtex
8+
@InProceedings{Moon_2020_ECCV_InterHand2.6M,
9+
author = {Moon, Gyeongsik and Yu, Shoou-I and Wen, He and Shiratori, Takaaki and Lee, Kyoung Mu},
10+
title = {InterHand2.6M: A Dataset and Baseline for 3D Interacting Hand Pose Estimation from a Single RGB Image},
11+
booktitle = {European Conference on Computer Vision (ECCV)},
12+
year = {2020}
13+
}
14+
```
15+
16+
## Results and models
17+
18+
### 3d Hand Pose Estimation
19+
20+
#### Results on InterHand2.6M val & test set
21+
22+
|Train Set| Set | Arch | Input Size | MPJPE-single | MPJPE-interacting | MPJPE-all | MRRPE | APh | ckpt | log |
23+
| :--- | :--- | :--------: | :--------: | :------: | :------: | :------: |:------: |:------: |:------: |:------: |
24+
| All | test(H+M) | [InterNet_resnet_50](/configs/hand3d/InterNet/interhand3d/res50_interhand3d_all_256x256.py) | 256x256 | 10.16 | 15.27 | 12.97 | 33.14 | 0.99 | [ckpt](https://download.openmmlab.com/mmpose/hand3d/internet/res50_intehand3d_all_256x256-b9c1cf4c_20210506.pth) | [log](https://download.openmmlab.com/mmpose/hand3d/internet/res50_intehand3d_all_256x256_20210506.log.json) |
25+
| All | val(M) | [InterNet_resnet_50](/configs/hand3d/InterNet/interhand3d/res50_interhand3d_all_256x256.py) | 256x256 | 12.03 | 17.88 | 14.84 | 34.93 | 0.99 | [ckpt](https://download.openmmlab.com/mmpose/hand3d/internet/res50_intehand3d_all_256x256-b9c1cf4c_20210506.pth) | [log](https://download.openmmlab.com/mmpose/hand3d/internet/res50_intehand3d_all_256x256_20210506.log.json) |
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
log_level = 'INFO'
2+
load_from = None
3+
resume_from = None
4+
dist_params = dict(backend='nccl')
5+
workflow = [('train', 1)]
6+
checkpoint_config = dict(interval=1)
7+
evaluation = dict(
8+
interval=1,
9+
metric=['MRRPE', 'MPJPE', 'Handedness_acc'],
10+
key_indicator='MPJPE_all')
11+
12+
optimizer = dict(
13+
type='Adam',
14+
lr=2e-4,
15+
)
16+
optimizer_config = dict(grad_clip=None)
17+
# learning policy
18+
lr_config = dict(policy='step', step=[15, 17])
19+
total_epochs = 20
20+
log_config = dict(
21+
interval=20,
22+
hooks=[
23+
dict(type='TextLoggerHook'),
24+
# dict(type='TensorboardLoggerHook')
25+
])
26+
27+
channel_cfg = dict(
28+
num_output_channels=42,
29+
dataset_joints=42,
30+
dataset_channel=[list(range(42))],
31+
inference_channel=list(range(42)))
32+
33+
# model settings
34+
model = dict(
35+
type='Interhand3D',
36+
pretrained='torchvision://resnet50',
37+
backbone=dict(type='ResNet', depth=50),
38+
keypoint_head=dict(
39+
type='Interhand3DHead',
40+
keypoint_head_cfg=dict(
41+
in_channels=2048,
42+
out_channels=21 * 64,
43+
depth_size=64,
44+
num_deconv_layers=3,
45+
num_deconv_filters=(256, 256, 256),
46+
num_deconv_kernels=(4, 4, 4),
47+
),
48+
root_head_cfg=dict(
49+
in_channels=2048,
50+
heatmap_size=64,
51+
hidden_dims=(512, ),
52+
),
53+
hand_type_head_cfg=dict(
54+
in_channels=2048,
55+
num_labels=2,
56+
hidden_dims=(512, ),
57+
),
58+
loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True),
59+
loss_root_depth=dict(type='L1Loss'),
60+
loss_hand_type=dict(type='BCELoss', use_target_weight=True),
61+
),
62+
train_cfg={},
63+
test_cfg=dict(flip_test=False))
64+
65+
data_cfg = dict(
66+
image_size=[256, 256],
67+
heatmap_size=[64, 64, 64],
68+
heatmap3d_depth_bound=400.0,
69+
heatmap_size_root=64,
70+
root_depth_bound=400.0,
71+
num_output_channels=channel_cfg['num_output_channels'],
72+
num_joints=channel_cfg['dataset_joints'],
73+
dataset_channel=channel_cfg['dataset_channel'],
74+
inference_channel=channel_cfg['inference_channel'])
75+
76+
train_pipeline = [
77+
dict(type='LoadImageFromFile'),
78+
dict(type='HandRandomFlip', flip_prob=0.5),
79+
dict(type='TopDownRandomTranslation', trans_factor=0.15),
80+
dict(
81+
type='TopDownGetRandomScaleRotation',
82+
rot_factor=45,
83+
scale_factor=0.25,
84+
rot_prob=0.6),
85+
# dict(type='MeshRandomChannelNoise', noise_factor=0.2),
86+
dict(type='TopDownAffine'),
87+
dict(type='ToTensor'),
88+
dict(
89+
type='NormalizeTensor',
90+
mean=[0.485, 0.456, 0.406],
91+
std=[0.229, 0.224, 0.225]),
92+
dict(
93+
type='MultitaskGatherTarget',
94+
pipeline_list=[
95+
[dict(
96+
type='Generate3DHeatmapTarget',
97+
sigma=2.5,
98+
max_bound=255,
99+
)], [dict(type='HandGenerateRelDepthTarget')],
100+
[
101+
dict(
102+
type='RenameKeys',
103+
key_pairs=[('hand_type', 'target'),
104+
('hand_type_valid', 'target_weight')])
105+
]
106+
],
107+
pipeline_indices=[0, 1, 2],
108+
),
109+
dict(
110+
type='Collect',
111+
keys=['img', 'target', 'target_weight'],
112+
meta_keys=[
113+
'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
114+
'rotation', 'flip_pairs'
115+
]),
116+
]
117+
118+
val_pipeline = [
119+
dict(type='LoadImageFromFile'),
120+
dict(type='TopDownAffine'),
121+
dict(type='ToTensor'),
122+
dict(
123+
type='NormalizeTensor',
124+
mean=[0.485, 0.456, 0.406],
125+
std=[0.229, 0.224, 0.225]),
126+
dict(
127+
type='Collect',
128+
keys=['img'],
129+
meta_keys=[
130+
'image_file', 'center', 'scale', 'rotation', 'flip_pairs',
131+
'heatmap3d_depth_bound', 'root_depth_bound'
132+
]),
133+
]
134+
135+
test_pipeline = val_pipeline
136+
137+
data_root = 'data/interhand2.6m'
138+
data = dict(
139+
samples_per_gpu=16,
140+
workers_per_gpu=2,
141+
train=dict(
142+
type='InterHand3DDataset',
143+
ann_file=f'{data_root}/annotations/all/'
144+
'InterHand2.6M_train_data.json',
145+
camera_file=f'{data_root}/annotations/all/'
146+
'InterHand2.6M_train_camera.json',
147+
joint_file=f'{data_root}/annotations/all/'
148+
'InterHand2.6M_train_joint_3d.json',
149+
img_prefix=f'{data_root}/images/train/',
150+
data_cfg=data_cfg,
151+
use_gt_root_depth=True,
152+
rootnet_result_file=None,
153+
pipeline=train_pipeline),
154+
val=dict(
155+
type='InterHand3DDataset',
156+
ann_file=f'{data_root}/annotations/machine_annot/'
157+
'InterHand2.6M_val_data.json',
158+
camera_file=f'{data_root}/annotations/machine_annot/'
159+
'InterHand2.6M_val_camera.json',
160+
joint_file=f'{data_root}/annotations/machine_annot/'
161+
'InterHand2.6M_val_joint_3d.json',
162+
img_prefix=f'{data_root}/images/val/',
163+
data_cfg=data_cfg,
164+
use_gt_root_depth=True,
165+
rootnet_result_file=None,
166+
pipeline=val_pipeline),
167+
test=dict(
168+
type='InterHand3DDataset',
169+
ann_file=f'{data_root}/annotations/all/'
170+
'InterHand2.6M_test_data.json',
171+
camera_file=f'{data_root}/annotations/all/'
172+
'InterHand2.6M_test_camera.json',
173+
joint_file=f'{data_root}/annotations/all/'
174+
'InterHand2.6M_test_joint_3d.json',
175+
img_prefix=f'{data_root}/images/test/',
176+
data_cfg=data_cfg,
177+
use_gt_root_depth=True,
178+
rootnet_result_file=None,
179+
pipeline=val_pipeline),
180+
)

mmpose/core/evaluation/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44
from .mesh_eval import compute_similarity_transform
55
from .pose3d_eval import keypoint_mpjpe
66
from .top_down_eval import (keypoint_auc, keypoint_epe, keypoint_pck_accuracy,
7-
keypoints_from_heatmaps, keypoints_from_regression,
7+
keypoints_from_heatmaps, keypoints_from_heatmaps3d,
8+
keypoints_from_regression,
9+
multilabel_classification_accuracy,
810
pose_pck_accuracy, post_dark_udp)
911

1012
__all__ = [
1113
'EvalHook', 'DistEvalHook', 'pose_pck_accuracy', 'keypoints_from_heatmaps',
1214
'keypoints_from_regression', 'keypoint_pck_accuracy', 'keypoint_auc',
1315
'keypoint_epe', 'get_group_preds', 'get_multi_stage_outputs',
1416
'aggregate_results', 'compute_similarity_transform', 'post_dark_udp',
15-
'keypoint_mpjpe'
17+
'keypoint_mpjpe', 'keypoints_from_heatmaps3d',
18+
'multilabel_classification_accuracy'
1619
]

mmpose/core/evaluation/top_down_eval.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,43 @@ def _get_max_preds(heatmaps):
9191
return preds, maxvals
9292

9393

94+
def _get_max_preds_3d(heatmaps):
95+
"""Get keypoint predictions from 3D score maps.
96+
97+
Note:
98+
batch size: N
99+
num keypoints: K
100+
heatmap depth size: D
101+
heatmap height: H
102+
heatmap width: W
103+
104+
Args:
105+
heatmaps (np.ndarray[N, K, D, H, W]): model predicted heatmaps.
106+
107+
Returns:
108+
tuple: A tuple containing aggregated results.
109+
- preds (np.ndarray[N, K, 3]): Predicted keypoint location.
110+
- maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.
111+
"""
112+
assert isinstance(heatmaps, np.ndarray), \
113+
('heatmaps should be numpy.ndarray')
114+
assert heatmaps.ndim == 5, 'heatmaps should be 5-ndim'
115+
116+
N, K, D, H, W = heatmaps.shape
117+
heatmaps_reshaped = heatmaps.reshape((N, K, -1))
118+
idx = np.argmax(heatmaps_reshaped, 2).reshape((N, K, 1))
119+
maxvals = np.amax(heatmaps_reshaped, 2).reshape((N, K, 1))
120+
121+
preds = np.zeros((N, K, 3), dtype=np.float32)
122+
_idx = idx[..., 0]
123+
preds[..., 2] = _idx // (H * W)
124+
preds[..., 1] = (_idx // W) % H
125+
preds[..., 0] = _idx % W
126+
127+
preds = np.where(maxvals > 0.0, preds, -1)
128+
return preds, maxvals
129+
130+
94131
def pose_pck_accuracy(output, target, mask, thr=0.05, normalize=None):
95132
"""Calculate the pose accuracy of PCK for each individual keypoint and the
96133
averaged accuracy across all keypoints from heatmaps.
@@ -574,3 +611,64 @@ def keypoints_from_heatmaps(heatmaps,
574611
maxvals = maxvals / 255.0 + 0.5
575612

576613
return preds, maxvals
614+
615+
616+
def keypoints_from_heatmaps3d(heatmaps, center, scale):
617+
"""Get final keypoint predictions from 3d heatmaps and transform them back
618+
to the image.
619+
620+
Note:
621+
batch size: N
622+
num keypoints: K
623+
heatmap depth size: D
624+
heatmap height: H
625+
heatmap width: W
626+
627+
Args:
628+
heatmaps (np.ndarray[N, K, D, H, W]): model predicted heatmaps.
629+
center (np.ndarray[N, 2]): Center of the bounding box (x, y).
630+
scale (np.ndarray[N, 2]): Scale of the bounding box
631+
wrt height/width.
632+
633+
Returns:
634+
tuple: A tuple containing keypoint predictions and scores.
635+
636+
- preds (np.ndarray[N, K, 3]): Predicted 3d keypoint location
637+
in images.
638+
- maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.
639+
"""
640+
N, K, D, H, W = heatmaps.shape
641+
preds, maxvals = _get_max_preds_3d(heatmaps)
642+
# Transform back to the image
643+
for i in range(N):
644+
preds[i, :, :2] = transform_preds(preds[i, :, :2], center[i], scale[i],
645+
[W, H])
646+
return preds, maxvals
647+
648+
649+
def multilabel_classification_accuracy(pred, gt, mask, thr=0.5):
650+
"""Get multi-label classification accuracy.
651+
Notes:
652+
batch size: N
653+
label number: L
654+
655+
Args:
656+
pred (np.ndarray[N, L, 2]): model predicted labels.
657+
gt (np.ndarray[N, L, 2]): ground-truth labels.
658+
mask (np.ndarray[N, 1] or np.ndarray[N, L] ): reliability of
659+
ground-truth labels.
660+
661+
Returns:
662+
acc (float): multi-label classification accuracy.
663+
"""
664+
# we only compute accuracy on the samples with ground-truth of all labels.
665+
valid = (mask > 0).min(axis=1) if mask.ndim == 2 else (mask > 0)
666+
pred, gt = pred[valid], gt[valid]
667+
668+
if pred.shape[0] == 0:
669+
acc = 0 # when no sample is with gt labels, set acc to 0.
670+
else:
671+
# The classification of a sample is regarded as correct
672+
# only if it's correct for all labels.
673+
acc = (((pred - thr) * (gt - thr)) > 0).all(axis=1).mean()
674+
return acc

mmpose/datasets/datasets/hand/interhand3d_dataset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,11 +188,11 @@ def _pixel2cam(pixel_coord, f, c):
188188
@staticmethod
189189
def _encode_handtype(hand_type):
190190
if hand_type == 'right':
191-
return np.array([1, 0], dtype=int)
191+
return np.array([1, 0], dtype=np.float32)
192192
elif hand_type == 'left':
193-
return np.array([0, 1], dtype=int)
193+
return np.array([0, 1], dtype=np.float32)
194194
elif hand_type == 'interacting':
195-
return np.array([1, 1], dtype=int)
195+
return np.array([1, 1], dtype=np.float32)
196196
else:
197197
assert 0, f'Not support hand type: {hand_type}'
198198

@@ -375,7 +375,7 @@ def evaluate(self, outputs, res_folder, metric='MPJPE', **kwargs):
375375
}
376376

377377
if preds is not None:
378-
kpt['keypoints'] = preds[i].tolist()
378+
kpt['keypoints'] = preds[i, :, :3].tolist()
379379
if hand_type is not None:
380380
kpt['hand_type'] = hand_type[i][0:2].tolist()
381381
kpt['hand_type_score'] = hand_type[i][2:4].tolist()

0 commit comments

Comments
 (0)