Skip to content

Commit 5590b41

Browse files
authored
[pose-detection]Change Blazepose fullbody and upperbody selection fro… (#658)
PROCESS
1 parent 30d3aba commit 5590b41

File tree

12 files changed

+90
-73
lines changed

12 files changed

+90
-73
lines changed

pose-detection/demo/src/index.js

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ async function createDetector() {
3737
inputResolution: {width: 500, height: 500},
3838
multiplier: 0.75
3939
});
40-
case posedetection.SupportedModels.MediapipeBlazepose:
41-
return posedetection.createDetector(
42-
STATE.model.model, {quantBytes: 4, upperBodyOnly: false});
40+
case posedetection.SupportedModels.MediapipeBlazeposeUpperBody:
41+
case posedetection.SupportedModels.MediapipeBlazeposeFullBody:
42+
return posedetection.createDetector(STATE.model.model, {quantBytes: 4});
4343
case posedetection.SupportedModels.MoveNet:
4444
const modelType = STATE.model.type == 'lightning' ?
4545
posedetection.movenet.modelType.SINGLEPOSE_LIGHTNING :

pose-detection/demo/src/option_panel.js

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
* =============================================================================
1616
*/
1717
import * as posedetection from '@tensorflow-models/pose-detection';
18+
import {type} from 'os';
1819

1920
import * as params from './params';
2021

@@ -81,11 +82,23 @@ function addMoveNetControllers(modelFolder, type) {
8182

8283
// The Blazepose model config folder contains options for Blazepose config
8384
// settings.
84-
function addBlazePoseControllers(modelFolder) {
85-
params.STATE.model = {
86-
model: posedetection.SupportedModels.MediapipeBlazepose,
87-
...params.BLAZEPOSE_CONFIG
88-
};
85+
function addBlazePoseControllers(modelFolder, type) {
86+
params.STATE.model = {...params.BLAZEPOSE_CONFIG};
87+
88+
params.STATE.model.model = type === 'upperbody' ?
89+
posedetection.SupportedModels.MediapipeBlazeposeUpperBody :
90+
posedetection.SupportedModels.MediapipeBlazeposeFullBody;
91+
92+
params.STATE.model.type = type === 'upperbody' ? 'upperbody' : 'fullbody';
93+
94+
const typeController =
95+
modelFolder.add(params.STATE.model, 'type', ['fullbody', 'upperbody']);
96+
typeController.onChange(type => {
97+
params.STATE.changeToModel = type;
98+
params.STATE.model.model = type === 'upperbody' ?
99+
posedetection.SupportedModels.MediapipeBlazeposeUpperBody :
100+
posedetection.SupportedModels.MediapipeBlazeposeFullBody;
101+
})
89102

90103
modelFolder.add(params.STATE.model, 'scoreThreshold', 0, 1);
91104
}

pose-detection/src/blazepose/blazepose_test.ts

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ import {expectArraysClose} from '@tensorflow/tfjs-core/dist/test_util';
2424
import * as poseDetection from '../index';
2525
import {getXYPerFrame, KARMA_SERVER, loadImage, loadVideo} from '../test_util';
2626

27-
const UPPERBODY_ONLY = [false, true];
27+
const MODEL_LIST = [
28+
poseDetection.SupportedModels.MediapipeBlazeposeUpperBody,
29+
poseDetection.SupportedModels.MediapipeBlazeposeFullBody
30+
];
2831
const EPSILON_IMAGE = 10;
2932
const EPSILON_VIDEO = 50;
3033
// ref:
@@ -66,10 +69,9 @@ describeWithFlags('Blazepose', ALL_ENVS, () => {
6669
// Note: this makes a network request for model assets.
6770
const modelConfig: poseDetection.BlazeposeModelConfig = {
6871
quantBytes: 4,
69-
upperBodyOnly: false
7072
};
7173
detector = await poseDetection.createDetector(
72-
poseDetection.SupportedModels.MediapipeBlazepose, modelConfig);
74+
poseDetection.SupportedModels.MediapipeBlazeposeFullBody, modelConfig);
7375
});
7476

7577
it('estimatePoses does not leak memory', async () => {
@@ -104,15 +106,13 @@ describeWithFlags('Blazepose static image ', BROWSER_ENVS, () => {
104106
jasmine.DEFAULT_TIMEOUT_INTERVAL = timeout;
105107
});
106108

107-
UPPERBODY_ONLY.forEach(upperBodyOnly => {
109+
MODEL_LIST.forEach(model => {
108110
it('test.', async () => {
109111
const startTensors = tf.memory().numTensors;
110112

111113
// Note: this makes a network request for model assets.
112-
const modelConfig:
113-
poseDetection.BlazeposeModelConfig = {quantBytes: 4, upperBodyOnly};
114-
detector = await poseDetection.createDetector(
115-
poseDetection.SupportedModels.MediapipeBlazepose, modelConfig);
114+
const modelConfig: poseDetection.BlazeposeModelConfig = {quantBytes: 4};
115+
detector = await poseDetection.createDetector(model, modelConfig);
116116

117117
const beforeTensors = tf.memory().numTensors;
118118

@@ -122,8 +122,10 @@ describeWithFlags('Blazepose static image ', BROWSER_ENVS, () => {
122122
poseDetection.BlazeposeEstimationConfig);
123123
const xy =
124124
result[0].keypoints.map((keypoint) => [keypoint.x, keypoint.y]);
125-
const expected = upperBodyOnly ? EXPECTED_UPPERBODY_LANDMARKS :
126-
EXPECTED_FULLBODY_LANDMARKS;
125+
const expected =
126+
model === poseDetection.SupportedModels.MediapipeBlazeposeUpperBody ?
127+
EXPECTED_UPPERBODY_LANDMARKS :
128+
EXPECTED_FULLBODY_LANDMARKS;
127129
expectArraysClose(xy, expected, EPSILON_IMAGE);
128130

129131
expect(tf.memory().numTensors).toEqual(beforeTensors);
@@ -160,16 +162,18 @@ describeWithFlags('Blazepose video ', BROWSER_ENVS, () => {
160162
jasmine.DEFAULT_TIMEOUT_INTERVAL = timeout;
161163
});
162164

163-
UPPERBODY_ONLY.forEach(upperBodyOnly => {
165+
MODEL_LIST.forEach(model => {
164166
it('test.', async () => {
165167
// Note: this makes a network request for model assets.
166-
const modelConfig:
167-
poseDetection.BlazeposeModelConfig = {quantBytes: 4, upperBodyOnly};
168-
detector = await poseDetection.createDetector(
169-
poseDetection.SupportedModels.MediapipeBlazepose, modelConfig);
168+
169+
const modelConfig: poseDetection.BlazeposeModelConfig = {quantBytes: 4};
170+
detector = await poseDetection.createDetector(model, modelConfig);
170171

171172
const result: number[][][] = [];
172-
const expected = upperBodyOnly ? expectedUpperBody : expectedFullBody;
173+
const expected =
174+
model === poseDetection.SupportedModels.MediapipeBlazeposeUpperBody ?
175+
expectedUpperBody :
176+
expectedFullBody;
173177

174178
const callback = async(video: HTMLVideoElement, timestamp: number):
175179
Promise<poseDetection.Pose[]> => {
@@ -185,8 +189,7 @@ describeWithFlags('Blazepose video ', BROWSER_ENVS, () => {
185189
// `ffmpeg -i original_pose.mp4 -r 5 -vcodec libx264 -crf 28 -profile:v
186190
// baseline pose_squats.mp4`
187191
await loadVideo(
188-
'pose_squats.mp4', 5 /* fps */, callback, expected,
189-
poseDetection.SupportedModels.MediapipeBlazepose);
192+
'pose_squats.mp4', 5 /* fps */, callback, expected, model);
190193

191194
expectArraysClose(result, expected, EPSILON_VIDEO);
192195

pose-detection/src/blazepose/constants.ts

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,6 @@
1515
* =============================================================================
1616
*/
1717

18-
import {BlazeposeModelConfig} from './types';
19-
20-
export const DEFAULT_BLAZEPOSE_FULLBODY_CONFIG: BlazeposeModelConfig = {
21-
upperBodyOnly: false
22-
};
2318
export const DEFAULT_BLAZEPOSE_DETECTOR_MODEL_URL =
2419
'https://storage.googleapis.com/tfjs-models/savedmodel/blazepose/detector/model.json';
2520
export const DEFAULT_BLAZEPOSE_LANDMARK_FULL_BODY_MODEL_URL =

pose-detection/src/blazepose/detector.ts

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ type PoseLandmarkByRoiResult = {
5757
* Blazepose detector class.
5858
*/
5959
export class BlazeposeDetector extends BasePoseDetector {
60-
private readonly upperBodyOnly: boolean;
6160
private readonly anchors: Rect[];
6261
private readonly anchorTensor: AnchorTensor;
6362

@@ -75,11 +74,9 @@ export class BlazeposeDetector extends BasePoseDetector {
7574
private constructor(
7675
private readonly detectorModel: tfconv.GraphModel,
7776
private readonly landmarkModel: tfconv.GraphModel,
78-
config: BlazeposeModelConfig) {
77+
private readonly upperBodyOnly: boolean) {
7978
super();
8079

81-
this.upperBodyOnly = config.upperBodyOnly;
82-
8380
this.anchors =
8481
createSsdAnchors(constants.BLAZEPOSE_DETECTOR_ANCHOR_CONFIGURATION);
8582
const anchorW = tf.tensor1d(this.anchors.map(a => a.width));
@@ -98,15 +95,16 @@ export class BlazeposeDetector extends BasePoseDetector {
9895
* the Blazepose loading process. Please find more details of each parameters
9996
* in the documentation of the `BlazeposeModelConfig` interface.
10097
*/
101-
static async load(modelConfig: BlazeposeModelConfig): Promise<PoseDetector> {
102-
const config = validateModelConfig(modelConfig);
98+
static async load(modelConfig: BlazeposeModelConfig, upperBodyOnly = false):
99+
Promise<PoseDetector> {
100+
const config = validateModelConfig(modelConfig, upperBodyOnly);
103101

104102
const [detectorModel, landmarkModel] = await Promise.all([
105103
tfconv.loadGraphModel(config.detectorModelUrl),
106104
tfconv.loadGraphModel(config.landmarkModelUrl)
107105
]);
108106

109-
return new BlazeposeDetector(detectorModel, landmarkModel, config);
107+
return new BlazeposeDetector(detectorModel, landmarkModel, upperBodyOnly);
110108
}
111109

112110
/**

pose-detection/src/blazepose/detector_utils.ts

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,20 @@
1515
* =============================================================================
1616
*/
1717

18-
import {DEFAULT_BLAZEPOSE_DETECTOR_MODEL_URL, DEFAULT_BLAZEPOSE_ESTIMATION_CONFIG, DEFAULT_BLAZEPOSE_FULLBODY_CONFIG, DEFAULT_BLAZEPOSE_LANDMARK_FULL_BODY_MODEL_URL, DEFAULT_BLAZEPOSE_LANDMARK_UPPER_BODY_MODEL_URL} from './constants';
18+
import {DEFAULT_BLAZEPOSE_DETECTOR_MODEL_URL, DEFAULT_BLAZEPOSE_ESTIMATION_CONFIG, DEFAULT_BLAZEPOSE_LANDMARK_FULL_BODY_MODEL_URL, DEFAULT_BLAZEPOSE_LANDMARK_UPPER_BODY_MODEL_URL} from './constants';
1919
import {BlazeposeEstimationConfig, BlazeposeModelConfig} from './types';
2020

21-
export function validateModelConfig(modelConfig: BlazeposeModelConfig):
22-
BlazeposeModelConfig {
21+
export function validateModelConfig(
22+
modelConfig: BlazeposeModelConfig,
23+
upperBodyOnly: boolean): BlazeposeModelConfig {
2324
let config;
2425

2526
if (modelConfig == null) {
26-
config = {...DEFAULT_BLAZEPOSE_FULLBODY_CONFIG};
27+
config = {};
2728
} else {
2829
config = {...modelConfig};
2930
}
3031

31-
if (config.upperBodyOnly == null) {
32-
config.upperBodyOnly = false;
33-
}
34-
3532
if (config.quantBytes == null) {
3633
config.quantBytes = 4;
3734
}
@@ -41,7 +38,7 @@ export function validateModelConfig(modelConfig: BlazeposeModelConfig):
4138
}
4239

4340
if (config.landmarkModelUrl == null) {
44-
if (config.upperBodyOnly) {
41+
if (upperBodyOnly) {
4542
config.landmarkModelUrl = DEFAULT_BLAZEPOSE_LANDMARK_UPPER_BODY_MODEL_URL;
4643
} else {
4744
config.landmarkModelUrl = DEFAULT_BLAZEPOSE_LANDMARK_FULL_BODY_MODEL_URL;

pose-detection/src/blazepose/types.ts

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,6 @@ import {EstimationConfig, ModelConfig} from '../types';
2020
/**
2121
* Additional Blazepose model loading config.
2222
*
23-
* `upperBodyOnly`: Optional. Default to false. If upperBody is true, then the
24-
* detector only detects 25 keypoints of the upperbody. The upperbody model
25-
* has a higher accuracy for upperbody prediction than the fullbody model. If
26-
* upperBody is false, then the detector detects 33 keypoints of the full body.
27-
*
2823
* `detectorModelUrl`: Optional. An optional string that specifies custom url of
2924
* the detector model. This is useful for area/countries that don't have access
3025
* to the model hosted on GCP.
@@ -34,7 +29,6 @@ import {EstimationConfig, ModelConfig} from '../types';
3429
* to the model hosted on GCP.
3530
*/
3631
export interface BlazeposeModelConfig extends ModelConfig {
37-
upperBodyOnly?: boolean;
3832
lite?: boolean;
3933
detectorModelUrl?: string;
4034
landmarkModelUrl?: string;

pose-detection/src/constants.ts

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ export const COCO_KEYPOINTS = [
1919
'rightShoulder', 'leftElbow', 'rightElbow', 'leftWrist', 'rightWrist',
2020
'leftHip', 'rightHip', 'leftKnee', 'rightKnee', 'leftAnkle', 'rightAnkle'
2121
];
22-
export const BLAZEPOSE_KEYPOINTS_UPPERBODY = [
22+
export const BLAZEPOSE_KEYPOINTS_UPPER_BODY = [
2323
'nose', 'rightEyeInner', 'rightEye', 'rightEyeOuter',
2424
'leftEyeInner', 'leftEye', 'leftEyeOuter', 'rightEar',
2525
'leftEar', 'mouthRight', 'mouthLeft', 'rightShoulder',
@@ -28,15 +28,20 @@ export const BLAZEPOSE_KEYPOINTS_UPPERBODY = [
2828
'leftIndex', 'rightThumb', 'leftThumb', 'rightHip',
2929
'leftHip'
3030
];
31-
export const BLAZEPOSE_KEYPOINTS_FULLBODY = [
32-
...BLAZEPOSE_KEYPOINTS_UPPERBODY, 'rightKnee', 'leftKnee', 'rightAnkle',
31+
export const BLAZEPOSE_KEYPOINTS_FULL_BODY = [
32+
...BLAZEPOSE_KEYPOINTS_UPPER_BODY, 'rightKnee', 'leftKnee', 'rightAnkle',
3333
'leftAnkle', 'rightHeel', 'leftHeel', 'rightFoot', 'leftFoot'
3434
];
35-
export const BLAZEPOSE_KEYPOINTS_BY_SIDE = {
35+
export const BLAZEPOSE_KEYPOINTS_BY_SIDE_FULL_BODY = {
3636
left: [4, 5, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32],
3737
right: [1, 2, 3, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31],
3838
middle: [0]
3939
};
40+
export const BLAZEPOSE_KEYPOINTS_BY_SIDE_UPPER_BODY = {
41+
left: [4, 5, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24],
42+
right: [1, 2, 3, 7, 9, 11, 13, 15, 17, 19, 21, 23],
43+
middle: [0]
44+
};
4045
export const COCO_KEYPOINTS_BY_SIDE = {
4146
left: [1, 3, 5, 7, 9, 11, 13, 15],
4247
right: [2, 4, 6, 8, 10, 12, 14, 16],
@@ -46,14 +51,20 @@ export const COCO_CONNECTED_KEYPOINTS_PAIRS = [
4651
[0, 1], [0, 2], [1, 3], [2, 4], [5, 6], [5, 7], [5, 11], [6, 8], [6, 12],
4752
[7, 9], [8, 10], [11, 12], [11, 13], [12, 14], [13, 15], [14, 16]
4853
];
49-
export const BLAZEPOSE_CONNECTED_KEYPOINTS_PAIRS = [
54+
export const BLAZEPOSE_CONNECTED_KEYPOINTS_PAIRS_FULL_BODY = [
5055
[0, 1], [0, 4], [1, 2], [2, 3], [3, 7], [4, 5],
5156
[5, 6], [6, 8], [9, 10], [11, 12], [11, 13], [11, 23],
5257
[12, 14], [14, 16], [12, 24], [13, 15], [15, 17], [16, 18],
5358
[16, 20], [15, 17], [15, 19], [15, 21], [16, 22], [17, 19],
5459
[18, 20], [23, 25], [23, 24], [24, 26], [25, 27], [26, 28],
5560
[27, 29], [28, 30], [27, 31], [28, 32], [29, 31], [30, 32]
5661
];
62+
export const BLAZEPOSE_CONNECTED_KEYPOINTS_PAIRS_UPPER_BODY = [
63+
[0, 1], [0, 4], [1, 2], [2, 3], [3, 7], [4, 5], [5, 6],
64+
[6, 8], [9, 10], [11, 12], [11, 13], [11, 23], [12, 14], [14, 16],
65+
[12, 24], [13, 15], [15, 17], [16, 18], [16, 20], [15, 17], [15, 19],
66+
[15, 21], [16, 22], [17, 19], [18, 20], [23, 24]
67+
];
5768
export const COCO_KEYPOINTS_NAMED_MAP: {[index: string]: number} = {
5869
nose: 0,
5970
left_eye: 1,

pose-detection/src/create_detector.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@ export async function createDetector(
3636
switch (model) {
3737
case SupportedModels.PoseNet:
3838
return PosenetDetector.load(modelConfig as PosenetModelConfig);
39-
case SupportedModels.MediapipeBlazepose:
39+
case SupportedModels.MediapipeBlazeposeUpperBody:
40+
return BlazeposeDetector.load(
41+
modelConfig as BlazeposeModelConfig, true /* upperBodyOnly */);
42+
case SupportedModels.MediapipeBlazeposeFullBody:
4043
return BlazeposeDetector.load(modelConfig as BlazeposeModelConfig);
4144
case SupportedModels.MoveNet:
4245
return MoveNetDetector.load(modelConfig as MoveNetModelConfig);

pose-detection/src/test_util.ts

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -136,15 +136,13 @@ function drawSkeleton(
136136
ctx.lineWidth = 2;
137137

138138
poseDetection.util.getAdjacentPairs(model).forEach(([i, j]) => {
139-
if (i < keypoints.length && j < keypoints.length) {
140-
const kp1 = keypoints[i];
141-
const kp2 = keypoints[j];
142-
143-
ctx.beginPath();
144-
ctx.moveTo(kp1.x, kp1.y);
145-
ctx.lineTo(kp2.x, kp2.y);
146-
ctx.stroke();
147-
}
139+
const kp1 = keypoints[i];
140+
const kp2 = keypoints[j];
141+
142+
ctx.beginPath();
143+
ctx.moveTo(kp1.x, kp1.y);
144+
ctx.lineTo(kp2.x, kp2.y);
145+
ctx.stroke();
148146
});
149147
}
150148

pose-detection/src/types.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ import * as tf from '@tensorflow/tfjs-core';
1818

1919
export enum SupportedModels {
2020
MoveNet = 'MoveNet',
21-
MediapipeBlazepose = 'MediapipeBlazepose',
21+
MediapipeBlazeposeFullBody = 'MediapipeBlazeposeFullBody',
22+
MediapipeBlazeposeUpperBody = 'MediapipeBlazeposeUpperBody',
2223
PoseNet = 'PoseNet'
2324
}
2425

pose-detection/src/util.ts

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ import {SupportedModels} from './types';
2020
export function getKeypointIndexBySide(model: SupportedModels):
2121
{left: number[], right: number[], middle: number[]} {
2222
switch (model) {
23-
case SupportedModels.MediapipeBlazepose:
24-
return constants.BLAZEPOSE_KEYPOINTS_BY_SIDE;
23+
case SupportedModels.MediapipeBlazeposeUpperBody:
24+
return constants.BLAZEPOSE_KEYPOINTS_BY_SIDE_UPPER_BODY;
25+
case SupportedModels.MediapipeBlazeposeFullBody:
26+
return constants.BLAZEPOSE_KEYPOINTS_BY_SIDE_FULL_BODY;
2527
case SupportedModels.PoseNet:
2628
case SupportedModels.MoveNet:
2729
return constants.COCO_KEYPOINTS_BY_SIDE;
@@ -31,8 +33,10 @@ export function getKeypointIndexBySide(model: SupportedModels):
3133
}
3234
export function getAdjacentPairs(model: SupportedModels): number[][] {
3335
switch (model) {
34-
case SupportedModels.MediapipeBlazepose:
35-
return constants.BLAZEPOSE_CONNECTED_KEYPOINTS_PAIRS;
36+
case SupportedModels.MediapipeBlazeposeUpperBody:
37+
return constants.BLAZEPOSE_CONNECTED_KEYPOINTS_PAIRS_UPPER_BODY;
38+
case SupportedModels.MediapipeBlazeposeFullBody:
39+
return constants.BLAZEPOSE_CONNECTED_KEYPOINTS_PAIRS_FULL_BODY;
3640
case SupportedModels.PoseNet:
3741
case SupportedModels.MoveNet:
3842
return constants.COCO_CONNECTED_KEYPOINTS_PAIRS;

0 commit comments

Comments
 (0)