Skip to content

Commit 6b458da

Browse files
authored
[pose-detection]Use 3d model for BlazePose. (#762)
FEATURE
1 parent 5598824 commit 6b458da

File tree

15 files changed

+9878
-883
lines changed

15 files changed

+9878
-883
lines changed

pose-detection/demos/live_video/index.html

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,9 @@
2727
position: relative;
2828
margin: 0;
2929
}
30-
canvas {
31-
position: absolute;
32-
top: 0;
33-
left: 0;
30+
#canvas-wrapper,
31+
#scatter-gl-container {
32+
position: relative;
3433
}
3534
</style>
3635
</head>
@@ -49,6 +48,7 @@
4948
">
5049
</video>
5150
</div>
51+
<div id="scatter-gl-container"></div>
5252
</div>
5353
</div>
5454
</body>

pose-detection/demos/live_video/package-lock.json

Lines changed: 8855 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pose-detection/demos/live_video/package.json

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99
"node": ">=8.9.0"
1010
},
1111
"dependencies": {
12+
"@mediapipe/pose": "~0.4.0",
1213
"@tensorflow-models/pose-detection": "file:../../dist",
13-
"@tensorflow/tfjs-backend-wasm": "^3.6.0",
14-
"@tensorflow/tfjs-backend-webgl": "^3.6.0",
15-
"@tensorflow/tfjs-converter": "^3.6.0",
16-
"@tensorflow/tfjs-core": "^3.6.0",
17-
"@mediapipe/pose": "~0.3.0"
14+
"@tensorflow/tfjs-backend-wasm": "^3.8.0",
15+
"@tensorflow/tfjs-backend-webgl": "^3.8.0",
16+
"@tensorflow/tfjs-converter": "^3.8.0",
17+
"@tensorflow/tfjs-core": "^3.8.0",
18+
"scatter-gl": "0.0.8"
1819
},
1920
"scripts": {
2021
"watch": "cross-env NODE_ENV=development parcel index.html --no-hmr --open",

pose-detection/demos/live_video/src/camera.js

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,26 @@
1515
* =============================================================================
1616
*/
1717
import * as posedetection from '@tensorflow-models/pose-detection';
18+
import * as scatter from 'scatter-gl';
1819

1920
import * as params from './params';
2021
import {isMobile} from './util';
2122

23+
// These anchor points allow the pose pointcloud to resize according to its
24+
// position in the input.
25+
const ANCHOR_POINTS = [[0, 0, 0], [0, 1, 0], [-1, 0, 0], [-1, -1, 0]];
2226
export class Camera {
2327
constructor() {
2428
this.video = document.getElementById('video');
2529
this.canvas = document.getElementById('output');
2630
this.ctx = this.canvas.getContext('2d');
31+
this.scatterGLEl = document.querySelector('#scatter-gl-container');
32+
this.scatterGL = new scatter.ScatterGL(this.scatterGLEl, {
33+
'rotateOnStart': true,
34+
'selectEnabled': false,
35+
'styles': {polyline: {defaultOpacity: 1, deselectedOpacity: 1}}
36+
});
37+
this.scatterGLHasInitialized = false;
2738
}
2839

2940
/**
@@ -81,6 +92,13 @@ export class Camera {
8192
camera.ctx.translate(camera.video.videoWidth, 0);
8293
camera.ctx.scale(-1, 1);
8394

95+
camera.scatterGLEl.style =
96+
`width: ${videoWidth}px; height: ${videoHeight}px;`;
97+
camera.scatterGL.resize();
98+
99+
camera.scatterGLEl.style.display =
100+
params.STATE.modelConfig.render3D ? 'inline-block' : 'none';
101+
84102
return camera;
85103
}
86104

@@ -112,6 +130,9 @@ export class Camera {
112130
this.drawKeypoints(pose.keypoints);
113131
this.drawSkeleton(pose.keypoints);
114132
}
133+
if (pose.keypoints3D != null && params.STATE.modelConfig.render3D) {
134+
this.drawKeypoints3D(pose.keypoints3D);
135+
}
115136
}
116137

117138
/**
@@ -121,7 +142,7 @@ export class Camera {
121142
drawKeypoints(keypoints) {
122143
const keypointInd =
123144
posedetection.util.getKeypointIndexBySide(params.STATE.model);
124-
this.ctx.fillStyle = 'White';
145+
this.ctx.fillStyle = 'Red';
125146
this.ctx.strokeStyle = 'White';
126147
this.ctx.lineWidth = params.DEFAULT_LINE_WIDTH;
127148

@@ -181,4 +202,41 @@ export class Camera {
181202
}
182203
});
183204
}
205+
206+
drawKeypoints3D(keypoints) {
207+
const scoreThreshold = params.STATE.modelConfig.scoreThreshold || 0;
208+
const pointsData =
209+
keypoints.map(keypoint => ([-keypoint.x, -keypoint.y, -keypoint.z]));
210+
211+
const dataset =
212+
new scatter.ScatterGL.Dataset([...pointsData, ...ANCHOR_POINTS]);
213+
214+
const keypointInd =
215+
posedetection.util.getKeypointIndexBySide(params.STATE.model);
216+
this.scatterGL.setPointColorer((i) => {
217+
if (keypoints[i] == null || keypoints[i].score < scoreThreshold) {
218+
// hide anchor points and low-confident points.
219+
return '#ffffff';
220+
}
221+
if (i === 0) {
222+
return '#ff0000' /* Red */;
223+
}
224+
if (keypointInd.left.indexOf(i) > -1) {
225+
return '#00ff00' /* Green */;
226+
}
227+
if (keypointInd.right.indexOf(i) > -1) {
228+
return '#ffa500' /* Orange */;
229+
}
230+
});
231+
232+
if (!this.scatterGLHasInitialized) {
233+
this.scatterGL.render(dataset);
234+
} else {
235+
this.scatterGL.updateDataset(dataset);
236+
}
237+
const connections = posedetection.util.getAdjacentPairs(params.STATE.model);
238+
const sequences = connections.map(pair => ({indices: pair}));
239+
this.scatterGL.setSequences(sequences);
240+
this.scatterGLHasInitialized = true;
241+
}
184242
}

pose-detection/demos/live_video/src/index.js

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ async function createDetector() {
5353
return posedetection.createDetector(STATE.model, {
5454
runtime,
5555
modelType: STATE.modelConfig.type,
56-
solutionPath:
57-
'https://cdn.jsdelivr.net/npm/@mediapipe/[email protected]'
56+
solutionPath: 'https://cdn.jsdelivr.net/npm/@mediapipe/[email protected]'
5857
});
5958
} else if (runtime === 'tfjs') {
6059
return posedetection.createDetector(

pose-detection/demos/live_video/src/option_panel.js

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,13 @@ function addBlazePoseControllers(modelConfigFolder, type) {
189189
});
190190

191191
modelConfigFolder.add(params.STATE.modelConfig, 'scoreThreshold', 0, 1);
192+
193+
const render3DController =
194+
modelConfigFolder.add(params.STATE.modelConfig, 'render3D');
195+
render3DController.onChange(render3D => {
196+
document.querySelector('#scatter-gl-container').style.display =
197+
render3D ? 'inline-block' : 'none';
198+
});
192199
}
193200

194201
/**

pose-detection/demos/live_video/src/params.js

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ export const STATE = {
3333
};
3434
export const BLAZEPOSE_CONFIG = {
3535
maxPoses: 1,
36-
scoreThreshold: 0.65
36+
type: 'full',
37+
scoreThreshold: 0.65,
38+
render3D: true
3739
};
3840
export const POSENET_CONFIG = {
3941
maxPoses: 1,
@@ -83,8 +85,7 @@ export const BACKEND_FLAGS_MAP = {
8385
export const MODEL_BACKEND_MAP = {
8486
[posedetection.SupportedModels.PoseNet]: ['tfjs-webgl'],
8587
[posedetection.SupportedModels.MoveNet]: ['tfjs-webgl', 'tfjs-wasm'],
86-
[posedetection.SupportedModels.BlazePose]:
87-
isiOS() ? ['tfjs-webgl'] : ['mediapipe-gpu', 'tfjs-webgl']
88+
[posedetection.SupportedModels.BlazePose]: ['mediapipe-gpu', 'tfjs-webgl']
8889
}
8990

9091
export const TUNABLE_FLAG_NAME_MAP = {

0 commit comments

Comments
 (0)