Skip to content

Commit 841afd4

Browse files
authored
[pose-detection]Initial blazepose checkin. (#616)
FEATURE
1 parent ad95c70 commit 841afd4

33 files changed

+1944
-39
lines changed

pose-detection/demo/src/camera.js

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,17 +89,25 @@ export class Camera {
8989
this.ctx.clearRect(0, 0, this.video.videoWidth, this.video.videoHeight);
9090
}
9191

92-
drawResult(pose) {
93-
this.drawKeypoints(pose.keypoints);
92+
drawResult(pose, shouldScale = false) {
93+
this.drawKeypoints(pose.keypoints, shouldScale);
9494
}
9595

96-
drawKeypoints(keypoints) {
96+
/**
97+
* Draw the keypoints on the video.
98+
* @param keypoints A list of keypoints, may be normalized.
99+
* @param shouldScale If the keypoints are normalized, shouldScale should be
100+
* set to true.
101+
*/
102+
drawKeypoints(keypoints, shouldScale) {
103+
const scaleX = shouldScale ? this.video.videoWidth : 1;
104+
const scaleY = shouldScale ? this.video.videoHeight : 1;
97105
this.ctx.fillStyle = 'red';
98106
this.ctx.strokeStyle = 'white';
99107
this.ctx.lineWidth = 4;
100108
keypoints.forEach(keypoint => {
101109
const circle = new Path2D();
102-
circle.arc(keypoint.x, keypoint.y, 4, 0, 2 * Math.PI);
110+
circle.arc(keypoint.x * scaleX, keypoint.y * scaleY, 4, 0, 2 * Math.PI);
103111
this.ctx.fill(circle);
104112
this.ctx.stroke(circle);
105113
});

pose-detection/demo/src/index.js

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,22 @@ import {setupStats} from './stats_panel';
2727

2828
let detector, camera, stats;
2929

30+
async function createDetector(model) {
31+
switch (model) {
32+
case posedetection.SupportedModels.PoseNet:
33+
return posedetection.createDetector(STATE.model.model, {
34+
quantBytes: 4,
35+
architecture: 'MobileNetV1',
36+
outputStride: 16,
37+
inputResolution: {width: 500, height: 500},
38+
multiplier: 0.75
39+
});
40+
case posedetection.SupportedModels.MediapipeBlazepose:
41+
return posedetection.createDetector(
42+
STATE.model.model, {quantBytes: 4, upperBodyOnly: false});
43+
}
44+
}
45+
3046
async function checkGuiUpdate() {
3147
if (STATE.changeToTargetFPS || STATE.changeToSizeOption) {
3248
if (STATE.changeToTargetFPS) {
@@ -42,6 +58,14 @@ async function checkGuiUpdate() {
4258
camera = await Camera.setupCamera(STATE.camera);
4359
}
4460

61+
if (STATE.changeToModel) {
62+
STATE.model.model = STATE.changeToModel;
63+
STATE.changeToModel = null;
64+
65+
detector.dispose();
66+
detector = await createDetector(STATE.model.model);
67+
}
68+
4569
await tf.nextFrame();
4670
}
4771

@@ -50,12 +74,15 @@ async function renderResult() {
5074
camera.lastVideoTime = camera.video.currentTime;
5175

5276
const poses = await detector.estimatePoses(
53-
video, {maxPoses: 1, flipHorizontal: false});
77+
camera.video, {maxPoses: 1, flipHorizontal: false});
5478

5579
camera.drawCtx();
5680

5781
if (poses.length > 0) {
58-
camera.drawResult(poses[0]);
82+
const shouldScale = STATE.model.model ===
83+
posedetection.SupportedModels.MediapipeBlazepose;
84+
85+
camera.drawResult(poses[0], shouldScale);
5986
}
6087
}
6188
}
@@ -78,14 +105,7 @@ async function app() {
78105
stats = setupStats();
79106
camera = await Camera.setupCamera(STATE.camera);
80107

81-
detector = await posedetection.createDetector(
82-
posedetection.SupportedModels.PoseNet, {
83-
quantBytes: 4,
84-
architecture: 'MobileNetV1',
85-
outputStride: 16,
86-
inputResolution: {width: 500, height: 500},
87-
multiplier: 0.75
88-
});
108+
detector = await createDetector(STATE.model.model);
89109

90110
renderPrediction();
91111
};

pose-detection/demo/src/option_panel.js

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
* limitations under the License.
1515
* =============================================================================
1616
*/
17+
import * as posedetection from '@tensorflow-models/posedetection';
18+
1719
import {STATE, VIDEO_SIZE} from './params';
1820

1921
export function setupDatGui() {
@@ -32,5 +34,14 @@ export function setupDatGui() {
3234
});
3335
cameraFolder.open();
3436

37+
// The model folder contains options for model settings.
38+
const modelFolder = gui.addFolder('Model');
39+
const modelController = modelFolder.add(
40+
STATE.model, 'model', Object.values(posedetection.SupportedModels));
41+
modelController.onChange(model => {
42+
STATE.changeToModel = model;
43+
});
44+
modelFolder.open();
45+
3546
return gui;
3647
}

pose-detection/demo/src/params.js

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@
1414
* limitations under the License.
1515
* =============================================================================
1616
*/
17+
import * as posedetection from '@tensorflow-models/posedetection';
18+
1719
export const DEFAULT_LINE_WIDTH = 4;
1820

1921
export const VIDEO_SIZE = {
2022
'640 X 480': {width: 640, height: 480},
2123
'640 X 360': {width: 640, height: 360}
2224
};
2325
export const STATE = {
24-
camera: {targetFPS: 60, sizeOption: '640 X 480'}
26+
camera: {targetFPS: 60, sizeOption: '640 X 480'},
27+
model: {model: posedetection.SupportedModels.PoseNet}
2528
};
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/**
2+
* @license
3+
* Copyright 2021 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tf from '@tensorflow/tfjs-core';
19+
// tslint:disable-next-line: no-imports-from-dist
20+
import {ALL_ENVS, describeWithFlags} from '@tensorflow/tfjs-core/dist/jasmine_util';
21+
22+
import * as poseDetection from '../index';
23+
24+
describeWithFlags('Blazepose', ALL_ENVS, () => {
25+
let detector: poseDetection.PoseDetector;
26+
let startTensors: number;
27+
28+
beforeEach(async () => {
29+
startTensors = tf.memory().numTensors;
30+
31+
// Note: this makes a network request for model assets.
32+
const modelConfig: poseDetection.BlazeposeModelConfig = {
33+
quantBytes: 4,
34+
upperBodyOnly: false
35+
};
36+
detector = await poseDetection.createDetector(
37+
poseDetection.SupportedModels.MediapipeBlazepose, modelConfig);
38+
});
39+
40+
it('estimatePoses does not leak memory', async () => {
41+
const input: tf.Tensor3D = tf.zeros([128, 128, 3]);
42+
43+
const beforeTensors = tf.memory().numTensors;
44+
45+
await detector.estimatePoses(input);
46+
47+
expect(tf.memory().numTensors).toEqual(beforeTensors);
48+
49+
detector.dispose();
50+
input.dispose();
51+
52+
expect(tf.memory().numTensors).toEqual(startTensors);
53+
});
54+
});
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/**
2+
* @license
3+
* Copyright 2021 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {ImageSize} from '../../calculators/interfaces/common_interfaces';
19+
import {Rect} from '../../calculators/interfaces/shape_interfaces';
20+
import {computeRotation} from './detection_to_rect';
21+
import {DetectionToRectConfig} from './interfaces/config_interfaces';
22+
import {Detection} from './interfaces/shape_interfaces';
23+
24+
// ref:
25+
// https://github.com/google/mediapipe/blob/master/mediapipe/calculators/util/alignment_points_to_rects_calculator.cc
26+
export function calculateAlignmentPointsRects(
27+
detection: Detection, imageSize: ImageSize,
28+
config: DetectionToRectConfig): Rect {
29+
const startKeypoint = config.rotationVectorStartKeypointIndex;
30+
const endKeypoint = config.rotationVectorEndKeypointIndex;
31+
32+
33+
const locationData = detection.locationData;
34+
const xCenter =
35+
locationData.relativeKeypoints[startKeypoint].x * imageSize.width;
36+
const yCenter =
37+
locationData.relativeKeypoints[startKeypoint].y * imageSize.height;
38+
39+
const xScale =
40+
locationData.relativeKeypoints[endKeypoint].x * imageSize.width;
41+
const yScale =
42+
locationData.relativeKeypoints[endKeypoint].y * imageSize.height;
43+
44+
// Bounding box size as double distance from center to scale point.
45+
const boxSize = Math.sqrt(
46+
(xScale - xCenter) * (xScale - xCenter) +
47+
(yScale - yCenter) * (yScale - yCenter)) *
48+
2;
49+
50+
const rotation = computeRotation(detection, imageSize, config);
51+
52+
// Set resulting bounding box.
53+
return {
54+
xCenter: xCenter / imageSize.width,
55+
yCenter: yCenter / imageSize.height,
56+
width: boxSize / imageSize.width,
57+
height: boxSize / imageSize.height,
58+
rotation
59+
};
60+
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/**
2+
* @license
3+
* Copyright 2021 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {Rect} from '../../calculators/interfaces/shape_interfaces';
19+
import {Keypoint} from '../../types';
20+
21+
/**
22+
* Projects normalized landmarks in a rectangle to its original coordinates. The
23+
* rectangle must also be in normalized coordinates.
24+
* @param landmarks A normalized Landmark list representing landmarks in a
25+
* normalized rectangle.
26+
* @param inputRect A normalized rectangle.
27+
* @param config Config object has one field ignoreRotation, default to false.
28+
*/
29+
// ref:
30+
// https://github.com/google/mediapipe/blob/master/mediapipe/calculators/util/landmark_projection_calculator.cc
31+
export function calculateLandmarkProjection(
32+
landmarks: Keypoint[], inputRect: Rect,
33+
config: {ignoreRotation: boolean} = {
34+
ignoreRotation: false
35+
}) {
36+
const outputLandmarks = [];
37+
for (const landmark of landmarks) {
38+
const x = landmark.x - 0.5;
39+
const y = landmark.y - 0.5;
40+
const angle = config.ignoreRotation ? 0 : inputRect.rotation;
41+
let newX = Math.cos(angle) * x - Math.sin(angle) * y;
42+
let newY = Math.sin(angle) * x + Math.cos(angle) * y;
43+
44+
newX = newX * inputRect.width + inputRect.xCenter;
45+
newY = newY * inputRect.height + inputRect.yCenter;
46+
47+
const newZ = landmark.z * inputRect.width; // Scale Z coordinate as x.
48+
49+
const newLandmark = {...landmark};
50+
51+
newLandmark.x = newX;
52+
newLandmark.y = newY;
53+
newLandmark.z = newZ;
54+
55+
outputLandmarks.push(newLandmark);
56+
}
57+
58+
return outputLandmarks;
59+
}

0 commit comments

Comments
 (0)