Skip to content

Commit 4b5041d

Browse files
authored
Introducing KeypointTracker, which associated people between frames using keypoint similarity. (#769)
* Skeleton for abstract base class Tracker * Creating updateTracks() which cleans up track store based on track age. Also, updating property names to be mixed case instead of using underscores * Basic formatting updates. * Adding lowerCamelCase for all variables and a couple more formatting updates * Creating tracker_utils.ts to validate tracker config. * Making computeSimilarity() abstract to allow for batch implementations. Also fixing some comments and validation checks * Updating flow so that tracks are filtered by age prior to assignment with new detections * Updating the documentation for maxTracks to indicate that this should be set higher than maxPoses * Moving track filtering before the computation of similarity matrices * Implementing assignTracks(), which is a greedy assigner. * Merge remote-tracking branch 'upstream/master' into tracker, and add new tracking changes. * Fixing linting issues. * Updating assignTracks so that unmatched detections are kept track of * Merge remote-tracking branch 'upstream/master' into tracker * Introducing KeypointTracker, which tracks people from frame-to-frame based on keypoint similarity. * Extending keypoint tracking test to show the enforcement of maxTracks. * Fixing linting issues. * Small surgical fixes. * Switching condition in computeSimilarity() so that both poses and tracks must be present for similarity matrix computation. * Making KeypointTracker area() and oks() methods private. * Formatting update. * Comment formatting update.
1 parent 529fe18 commit 4b5041d

File tree

5 files changed

+398
-3
lines changed

5 files changed

+398
-3
lines changed

pose-detection/src/calculators/interfaces/config_interfaces.ts

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,25 @@ export interface TrackerConfig {
8383
// re-identifications).
8484
minSimilarity: number; // New poses will only be linked with tracks if the
8585
// similarity score exceeds this threshold.
86+
trackerParams: KeypointTrackerConfig; // Config for tracker. Note that as
87+
// more trackers are implemented, this
88+
// should become a union of all tracker
89+
// types.
90+
}
91+
// A tracker that links detections (i.e. poses) and tracks based on keypoint
92+
// similarity.
93+
export interface KeypointTrackerConfig {
94+
keypointConfidenceThreshold: number; // The minimum keypoint confidence
95+
// threshold. A keypoint is only
96+
// compared in the OKS calculation if
97+
// both the new detected keypoint and
98+
// the corresponding track keypoint have
99+
// confidences above this threshold.
100+
101+
keypointFalloff: number[]; // Per-keypoint falloff in OKS calculation.
102+
minNumberOfKeypoints: number; // The minimum number of keypoints that are
103+
// necessary for computing OKS. If the number
104+
// of confident keypoints (between a pose and
105+
// track) are under this value, an OKS of 0.0
106+
// will be given.
86107
}
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/**
2+
* @license
3+
* Copyright 2021 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {Keypoint, Pose} from '../types';
19+
import {Tracker} from './tracker';
20+
import {Track} from './interfaces/common_interfaces';
21+
import {TrackerConfig} from './interfaces/config_interfaces';
22+
23+
/**
24+
* KeypointTracker, which tracks poses based on keypoint similarity. This
25+
* tracker assumes that keypoints are provided in normalized image
26+
* coordinates.
27+
*/
28+
export class KeypointTracker extends Tracker {
29+
private readonly keypointThreshold: number;
30+
private readonly keypointFalloff: number[];
31+
private readonly minNumKeyoints: number;
32+
33+
constructor(config: TrackerConfig) {
34+
super(config);
35+
//TODO(ronnyvotel): validate.
36+
this.keypointThreshold = config.trackerParams.keypointConfidenceThreshold;
37+
this.keypointFalloff = config.trackerParams.keypointFalloff;
38+
this.minNumKeyoints = config.trackerParams.minNumberOfKeypoints;
39+
}
40+
41+
/**
42+
* Computes similarity based on Object Keypoint Similarity (OKS). It's
43+
* assumed that the keypoints within each `Pose` are in normalized image
44+
* coordinates. See `Tracker` for more details.
45+
*/
46+
computeSimilarity(poses: Pose[]): number[][] {
47+
if (poses.length === 0 || this.tracks.length === 0) {
48+
return [[]];
49+
}
50+
51+
const simMatrix = [];
52+
for (const pose of poses) {
53+
const row = [];
54+
for (const track of this.tracks) {
55+
row.push(this.oks(pose, track));
56+
}
57+
simMatrix.push(row);
58+
}
59+
return simMatrix;
60+
}
61+
62+
/**
63+
* Computes the Object Keypoint Similarity (OKS) between a pose and track.
64+
* This is similar in spirit to the calculation used by COCO keypoint eval:
65+
* https://cocodataset.org/#keypoints-eval
66+
* In this case, OKS is calculated as:
67+
* (1/sum_i d(c_i, c_ti)) * \sum_i exp(-d_i^2/(2*a_ti*x_i^2))*d(c_i, c_ti)
68+
* where
69+
* d(x, y) is an indicator function which only produces 1 if x and y
70+
* exceed a given threshold (i.e. keypointThreshold), otherwise 0.
71+
* c_i is the confidence of keypoint i from the new pose
72+
* c_ti is the confidence of keypoint i from the track
73+
* d_i is the Euclidean distance between the pose and track keypoint
74+
* a_ti is the area of the track object (the box covering the keypoints)
75+
* x_i is a constant that controls falloff in a Gaussian distribution,
76+
* computed as 2*keypointFalloff[i].
77+
* @param pose A `Pose`.
78+
* @param track A `Track`.
79+
* @returns The OKS score between the pose and the track. This number is
80+
* between 0 and 1, and larger values indicate more keypoint similarity.
81+
*/
82+
private oks(pose: Pose, track: Track): number {
83+
const boxArea = this.area(track.keypoints) + 1e-6;
84+
let oksTotal = 0;
85+
let numValidKeypoints = 0;
86+
for (let i = 0; i < pose.keypoints.length; ++i) {
87+
const poseKpt = pose.keypoints[i];
88+
const trackKpt = track.keypoints[i];
89+
if (poseKpt.score < this.keypointThreshold ||
90+
trackKpt.score < this.keypointThreshold) {
91+
continue;
92+
}
93+
numValidKeypoints += 1;
94+
const dSquared =
95+
Math.pow(poseKpt.x - trackKpt.x, 2) +
96+
Math.pow(poseKpt.y - trackKpt.y, 2);
97+
const x = 2 * this.keypointFalloff[i];
98+
oksTotal += Math.exp(-1 * dSquared / (2 * boxArea * x**2));
99+
}
100+
if (numValidKeypoints < this.minNumKeyoints) {
101+
return 0.0;
102+
}
103+
return oksTotal / numValidKeypoints;
104+
}
105+
106+
/**
107+
* Computes the area of a bounding box that tightly covers keypoints.
108+
* @param Keypoint[] An array of `Keypoint`s.
109+
* @returns The area of the object.
110+
*/
111+
private area(keypoints: Keypoint[]): number {
112+
const validKeypoint = keypoints.filter(
113+
kpt => kpt.score > this.keypointThreshold);
114+
const minX = Math.min(1.0, ...validKeypoint.map(kpt => kpt.x));
115+
const maxX = Math.max(0.0, ...validKeypoint.map(kpt => kpt.x));
116+
const minY = Math.min(1.0, ...validKeypoint.map(kpt => kpt.y));
117+
const maxY = Math.max(0.0, ...validKeypoint.map(kpt => kpt.y));
118+
return (maxX - minX) * (maxY - minY);
119+
}
120+
}
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
/**
2+
* @license
3+
* Copyright 2021 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
import {Keypoint, Pose} from '../types';
18+
import {KeypointTracker} from './keypoint_tracker';
19+
import {Track} from './interfaces/common_interfaces';
20+
import {TrackerConfig} from './interfaces/config_interfaces';
21+
22+
describe('Keypoint tracker', () => {
23+
const trackerConfig: TrackerConfig = {
24+
maxTracks: 4,
25+
maxAge: 1000,
26+
minSimilarity: 0.5,
27+
trackerParams: {
28+
keypointConfidenceThreshold: 0.2,
29+
keypointFalloff: [0.1, 0.1, 0.1, 0.1],
30+
minNumberOfKeypoints: 2
31+
}
32+
};
33+
34+
it('Instantiate tracker', () => {
35+
const kptTracker = new KeypointTracker(trackerConfig);
36+
expect(kptTracker instanceof KeypointTracker).toBe(true);
37+
});
38+
39+
it('Compute OKS', () => {
40+
const kptTracker = new KeypointTracker(trackerConfig);
41+
const pose: Pose = {
42+
keypoints: [
43+
{x: 0.2, y: 0.2, score: 1.0},
44+
{x: 0.4, y: 0.4, score: 0.8},
45+
{x: 0.6, y: 0.6, score: 0.1}, // Low confidence.
46+
{x: 0.8, y: 0.7, score: 0.8}
47+
]};
48+
const track: Track = {
49+
id: 0,
50+
lastTimestamp: 1000,
51+
keypoints: [
52+
{x: 0.2, y: 0.2, score: 1.0},
53+
{x: 0.4, y: 0.4, score: 0.8},
54+
{x: 0.6, y: 0.6, score: 0.9},
55+
{x: 0.8, y: 0.8, score: 0.8}
56+
]};
57+
const oks = kptTracker['oks'](pose, track);
58+
59+
const boxArea = (0.8 - 0.2) * (0.8 - 0.2);
60+
const x = 2*trackerConfig.trackerParams.keypointFalloff[3];
61+
const d = 0.1;
62+
const expectedOks =
63+
(1 + 1 + Math.exp(-1*d**2/(2*boxArea*x**2))) / 3;
64+
expect(oks).toBeCloseTo(expectedOks, 6);
65+
});
66+
67+
it('Compute OKS returns 0.0 with less than 2 valid keypoints', () => {
68+
const kptTracker = new KeypointTracker(trackerConfig);
69+
const pose: Pose = {
70+
keypoints: [
71+
{x: 0.2, y: 0.2, score: 1.0},
72+
{x: 0.4, y: 0.4, score: 0.1}, // Low confidence.
73+
{x: 0.6, y: 0.6, score: 0.9},
74+
{x: 0.8, y: 0.8, score: 0.8}
75+
]};
76+
const track: Track = {
77+
id: 0,
78+
lastTimestamp: 1000,
79+
keypoints: [
80+
{x: 0.2, y: 0.2, score: 1.0},
81+
{x: 0.4, y: 0.4, score: 0.8},
82+
{x: 0.6, y: 0.6, score: 0.1}, // Low confidence.
83+
{x: 0.8, y: 0.8, score: 0.0} // Low confidence.
84+
]};
85+
const oks = kptTracker['oks'](pose, track);
86+
expect(oks).toBeCloseTo(0.0, 6);
87+
});
88+
89+
it('Compute area', () => {
90+
const kptTracker = new KeypointTracker(trackerConfig);
91+
const keypoints: Keypoint[] = [
92+
{x: 0.1, y: 0.2, score: 1.0},
93+
{x: 0.3, y: 0.4, score: 0.9},
94+
{x: 0.4, y: 0.6, score: 0.9},
95+
{x: 0.7, y: 0.8, score: 0.1} // Low confidence.
96+
];
97+
const area = kptTracker['area'](keypoints);
98+
99+
const expectedArea = (0.4 - 0.1) * (0.6 - 0.2);
100+
expect(area).toBeCloseTo(expectedArea, 6);
101+
});
102+
103+
it('Apply tracker', () => {
104+
// Timestamp: 0. Pose becomes the only track.
105+
const kptTracker = new KeypointTracker(trackerConfig);
106+
let tracks: Track[];
107+
let poses: Pose[] = [
108+
{keypoints: [ // Becomes id = 1.
109+
{x: 0.2, y: 0.2, score: 1.0},
110+
{x: 0.4, y: 0.4, score: 0.8},
111+
{x: 0.6, y: 0.6, score: 0.9},
112+
{x: 0.8, y: 0.8, score: 0.0} // Low confidence.
113+
]}
114+
];
115+
poses = kptTracker.apply(poses, 0);
116+
tracks = kptTracker.getTracks();
117+
expect(poses.length).toEqual(1);
118+
expect(poses[0].id).toEqual(1);
119+
expect(tracks.length).toEqual(1);
120+
expect(tracks[0].id).toEqual(1);
121+
expect(tracks[0].lastTimestamp).toEqual(0);
122+
123+
// Timestamp: 100. First pose is linked with track 1. Second pose spawns a
124+
// new track (id = 2).
125+
poses = [
126+
{keypoints: [ // Links with id = 1.
127+
{x: 0.2, y: 0.2, score: 1.0},
128+
{x: 0.4, y: 0.4, score: 0.8},
129+
{x: 0.6, y: 0.6, score: 0.9},
130+
{x: 0.8, y: 0.8, score: 0.8}
131+
]},
132+
{keypoints: [ // Becomes id = 2.
133+
{x: 0.8, y: 0.8, score: 0.8},
134+
{x: 0.6, y: 0.6, score: 0.3},
135+
{x: 0.4, y: 0.4, score: 0.1}, // Low confidence.
136+
{x: 0.2, y: 0.2, score: 0.8}
137+
]}
138+
];
139+
poses = kptTracker.apply(poses, 100);
140+
tracks = kptTracker.getTracks();
141+
expect(poses.length).toEqual(2);
142+
expect(poses[0].id).toEqual(1);
143+
expect(poses[1].id).toEqual(2);
144+
expect(tracks.length).toEqual(2);
145+
expect(tracks[0].id).toEqual(1);
146+
expect(tracks[0].lastTimestamp).toEqual(100);
147+
expect(tracks[1].id).toEqual(2);
148+
expect(tracks[1].lastTimestamp).toEqual(100);
149+
150+
// Timestamp: 900. First pose is linked with track 2. Second pose spawns a
151+
// new track (id = 3).
152+
poses = [
153+
{keypoints: [ // Links with id = 2.
154+
{x: 0.6, y: 0.7, score: 0.7},
155+
{x: 0.5, y: 0.6, score: 0.7},
156+
{x: 0.0, y: 0.0, score: 0.1}, // Low confidence.
157+
{x: 0.2, y: 0.1, score: 1.0}
158+
]},
159+
{keypoints: [ // Becomes id = 3.
160+
{x: 0.5, y: 0.1, score: 0.6},
161+
{x: 0.9, y: 0.3, score: 0.6},
162+
{x: 0.1, y: 1.0, score: 0.9},
163+
{x: 0.4, y: 0.4, score: 0.1} // Low confidence.
164+
]},
165+
];
166+
poses = kptTracker.apply(poses, 900);
167+
tracks = kptTracker.getTracks();
168+
expect(poses.length).toEqual(2);
169+
expect(poses[0].id).toEqual(2);
170+
expect(poses[1].id).toEqual(3);
171+
expect(tracks.length).toEqual(3);
172+
expect(tracks[0].id).toEqual(2);
173+
expect(tracks[0].lastTimestamp).toEqual(900);
174+
expect(tracks[1].id).toEqual(3);
175+
expect(tracks[1].lastTimestamp).toEqual(900);
176+
expect(tracks[2].id).toEqual(1);
177+
expect(tracks[2].lastTimestamp).toEqual(100);
178+
179+
// Timestamp: 1200. First pose spawns a new track (id = 4), even though it
180+
// has the same keypoints as track 1. This is because the age exceeds 1000
181+
// msec. The second pose links with id 2. The third pose spawns a new
182+
// track (id = 5).
183+
poses = [
184+
{keypoints: [ // Becomes id = 4.
185+
{x: 0.2, y: 0.2, score: 1.0},
186+
{x: 0.4, y: 0.4, score: 0.8},
187+
{x: 0.6, y: 0.6, score: 0.9},
188+
{x: 0.8, y: 0.8, score: 0.8}
189+
]},
190+
{keypoints: [ // Links with id = 2.
191+
{x: 0.55, y: 0.7, score: 0.7},
192+
{x: 0.5, y: 0.6, score: 0.9},
193+
{x: 1.0, y: 1.0, score: 0.1}, // Low confidence.
194+
{x: 0.8, y: 0.1, score: 0.0} // Low confidence.
195+
]},
196+
{keypoints: [ // Becomes id = 5.
197+
{x: 0.1, y: 0.1, score: 0.1}, // Low confidence.
198+
{x: 0.2, y: 0.2, score: 0.9},
199+
{x: 0.3, y: 0.3, score: 0.7},
200+
{x: 0.4, y: 0.4, score: 0.8}
201+
]},
202+
];
203+
poses = kptTracker.apply(poses, 1200);
204+
tracks = kptTracker.getTracks();
205+
expect(poses.length).toEqual(3);
206+
expect(poses[0].id).toEqual(4);
207+
expect(poses[1].id).toEqual(2);
208+
expect(tracks.length).toEqual(4);
209+
expect(tracks[0].id).toEqual(2);
210+
expect(tracks[0].lastTimestamp).toEqual(1200);
211+
expect(tracks[1].id).toEqual(4);
212+
expect(tracks[1].lastTimestamp).toEqual(1200);
213+
expect(tracks[2].id).toEqual(5);
214+
expect(tracks[2].lastTimestamp).toEqual(1200);
215+
expect(tracks[3].id).toEqual(3);
216+
expect(tracks[3].lastTimestamp).toEqual(900);
217+
218+
// Timestamp: 1300. First pose spawns a new track (id = 6). Since maxTracks
219+
// is 4, the oldest track (id = 3) is removed.
220+
poses = [
221+
{keypoints: [ // Becomes id = 6.
222+
{x: 0.1, y: 0.8, score: 1.0},
223+
{x: 0.2, y: 0.9, score: 0.6},
224+
{x: 0.2, y: 0.9, score: 0.5},
225+
{x: 0.8, y: 0.2, score: 0.4}
226+
]},
227+
];
228+
poses = kptTracker.apply(poses, 1300);
229+
tracks = kptTracker.getTracks();
230+
expect(poses.length).toEqual(1);
231+
expect(poses[0].id).toEqual(6);
232+
expect(tracks.length).toEqual(4);
233+
expect(tracks[0].id).toEqual(6);
234+
expect(tracks[0].lastTimestamp).toEqual(1300);
235+
expect(tracks[1].id).toEqual(2);
236+
expect(tracks[1].lastTimestamp).toEqual(1200);
237+
expect(tracks[2].id).toEqual(4);
238+
expect(tracks[2].lastTimestamp).toEqual(1200);
239+
expect(tracks[3].id).toEqual(5);
240+
expect(tracks[3].lastTimestamp).toEqual(1200);
241+
});
242+
});

0 commit comments

Comments
 (0)