|
19 | 19 | import * as tfconv from '@tensorflow/tfjs-converter'; |
20 | 20 | import * as tf from '@tensorflow/tfjs-core'; |
21 | 21 |
|
| 22 | +import {BaseModel} from './base_model'; |
22 | 23 | import {decodeOnlyPartSegmentation, decodePartSegmentation, toMaskTensor} from './decode_part_map'; |
23 | | -import {MobileNet, MobileNetMultiplier} from './mobilenet'; |
| 24 | +import {MobileNet} from './mobilenet'; |
24 | 25 | import {decodePersonInstanceMasks, decodePersonInstancePartMasks} from './multi_person/decode_instance_masks'; |
25 | 26 | import {decodeMultiplePoses} from './multi_person/decode_multiple_poses'; |
26 | 27 | import {ResNet} from './resnet'; |
27 | 28 | import {mobileNetSavedModel, resNet50SavedModel} from './saved_models'; |
28 | | -import {decodeSinglePose} from './sinlge_person/decode_single_pose'; |
| 29 | +import {decodeSinglePose} from './single_person/decode_single_pose'; |
29 | 30 | import {BodyPixArchitecture, BodyPixInput, BodyPixInternalResolution, BodyPixMultiplier, BodyPixOutputStride, BodyPixQuantBytes, Padding, PartSegmentation, PersonSegmentation} from './types'; |
30 | 31 | import {getInputSize, padAndResizeTo, scaleAndCropToInputTensorShape, scaleAndFlipPoses, toTensorBuffers3D, toValidInternalResolutionNumber} from './util'; |
31 | 32 |
|
32 | | - |
33 | 33 | const APPLY_SIGMOID_ACTIVATION = true; |
34 | 34 |
|
35 | | -/** |
36 | | - * BodyPix supports using various convolutional neural network models |
37 | | - * (e.g. ResNet and MobileNetV1) as its underlying base model. |
38 | | - * The following BaseModel interface defines a unified interface for |
39 | | - * creating such BodyPix base models. Currently both MobileNet (in |
40 | | - * ./mobilenet.ts) and ResNet (in ./resnet.ts) implement the BaseModel |
41 | | - * interface. New base models that conform to the BaseModel interface can be |
42 | | - * added to BodyPix. |
43 | | - */ |
44 | | -export interface BaseModel { |
45 | | - // The output stride of the base model. |
46 | | - readonly outputStride: BodyPixOutputStride; |
47 | | - |
48 | | - /** |
49 | | - * Predicts intermediate Tensor representations. |
50 | | - * |
51 | | - * @param input The input RGB image of the base model. |
52 | | - * A Tensor of shape: [`inputResolution`, `inputResolution`, 3]. |
53 | | - * |
54 | | - * @return A dictionary of base model's intermediate predictions. |
55 | | - * The returned dictionary should contain the following elements: |
56 | | - * - heatmapScores: A Tensor3D that represents the keypoint heatmap scores. |
57 | | - * - offsets: A Tensor3D that represents the offsets. |
58 | | - * - displacementFwd: A Tensor3D that represents the forward displacement. |
59 | | - * - displacementBwd: A Tensor3D that represents the backward displacement. |
60 | | - * - segmentation: A Tensor3D that represents the segmentation of all people. |
61 | | - * - longOffsets: A Tensor3D that represents the long offsets used for |
62 | | - * instance grouping. |
63 | | - * - partHeatmaps: A Tensor3D that represents the body part segmentation. |
64 | | - */ |
65 | | - predict(input: tf.Tensor3D): {[key: string]: tf.Tensor3D}; |
66 | | - /** |
67 | | - * Releases the CPU and GPU memory allocated by the model. |
68 | | - */ |
69 | | - dispose(): void; |
70 | | -} |
71 | | - |
72 | 35 | /** |
73 | 36 | * BodyPix model loading is configurable using the following config dictionary. |
74 | 37 | * |
@@ -101,7 +64,7 @@ export interface BaseModel { |
101 | 64 | export interface ModelConfig { |
102 | 65 | architecture: BodyPixArchitecture; |
103 | 66 | outputStride: BodyPixOutputStride; |
104 | | - multiplier?: MobileNetMultiplier; |
| 67 | + multiplier?: BodyPixMultiplier; |
105 | 68 | modelUrl?: string; |
106 | 69 | quantBytes?: BodyPixQuantBytes; |
107 | 70 | } |
@@ -602,15 +565,15 @@ export class BodyPix { |
602 | 565 | }; |
603 | 566 | }); |
604 | 567 |
|
605 | | - const [scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer] = |
| 568 | + const [scoresBuf, offsetsBuf, displacementsFwdBuf, displacementsBwdBuf] = |
606 | 569 | await toTensorBuffers3D([ |
607 | 570 | heatmapScoresRaw, offsetsRaw, displacementFwdRaw, displacementBwdRaw |
608 | 571 | ]); |
609 | 572 |
|
610 | | - let poses = await decodeMultiplePoses( |
611 | | - scoresBuffer, offsetsBuffer, displacementsFwdBuffer, |
612 | | - displacementsBwdBuffer, this.baseModel.outputStride, |
613 | | - config.maxDetections, config.scoreThreshold, config.nmsRadius); |
| 573 | + let poses = decodeMultiplePoses( |
| 574 | + scoresBuf, offsetsBuf, displacementsFwdBuf, displacementsBwdBuf, |
| 575 | + this.baseModel.outputStride, config.maxDetections, |
| 576 | + config.scoreThreshold, config.nmsRadius); |
614 | 577 |
|
615 | 578 | poses = scaleAndFlipPoses( |
616 | 579 | poses, [height, width], |
@@ -849,15 +812,15 @@ export class BodyPix { |
849 | 812 | }; |
850 | 813 | }); |
851 | 814 |
|
852 | | - const [scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer] = |
| 815 | + const [scoresBuf, offsetsBuf, displacementsFwdBuf, displacementsBwdBuf] = |
853 | 816 | await toTensorBuffers3D([ |
854 | 817 | heatmapScoresRaw, offsetsRaw, displacementFwdRaw, displacementBwdRaw |
855 | 818 | ]); |
856 | 819 |
|
857 | | - let poses = await decodeMultiplePoses( |
858 | | - scoresBuffer, offsetsBuffer, displacementsFwdBuffer, |
859 | | - displacementsBwdBuffer, this.baseModel.outputStride, |
860 | | - config.maxDetections, config.scoreThreshold, config.nmsRadius); |
| 820 | + let poses = decodeMultiplePoses( |
| 821 | + scoresBuf, offsetsBuf, displacementsFwdBuf, displacementsBwdBuf, |
| 822 | + this.baseModel.outputStride, config.maxDetections, |
| 823 | + config.scoreThreshold, config.nmsRadius); |
861 | 824 |
|
862 | 825 | poses = scaleAndFlipPoses( |
863 | 826 | poses, [height, width], |
|
0 commit comments