# Semantic Segmentation in the Browser: DeepLab v3 Model

## This model is a work-in-progress and has not been released yet. We will update this README when the model is released and usable.

This package contains a standalone implementation of the DeepLab inference pipeline, as well as a [demo](./demo), for running semantic segmentation using TensorFlow.js.

## Usage

In the first step of semantic segmentation, an image is fed through a pre-trained model [based](https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md) on MobileNet-v2. Three types of pre-trained weights are available, trained on the [Pascal](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html), [Cityscapes](https://www.cityscapes-dataset.com) and [ADE20K](https://groups.csail.mit.edu/vision/datasets/ADE20K/) datasets.

To get started, pick the model name from `pascal`, `cityscapes` and `ade20k`, and decide whether you want your model quantized to 1 or 2 bytes (set the `quantizationBytes` option to 4 to disable quantization). Then initialize the model as follows:

```typescript
import * as tf from '@tensorflow/tfjs';
import * as deeplab from '@tensorflow-models/deeplab';

const loadModel = async () => {
  const modelName = 'pascal';   // set to your preferred model, out of
                                // `pascal`, `cityscapes` and `ade20k`
  const quantizationBytes = 2;  // either 1, 2 or 4
  return await deeplab.load({base: modelName, quantizationBytes});
};

// A dummy 227x500 RGB input; replace with a real image tensor.
const input = tf.zeros([227, 500, 3]);
// ...

loadModel()
    .then((model) => model.segment(input))
    .then(
        ({legend}) =>
            console.log(`The predicted classes are ${JSON.stringify(legend)}`));
```

By default, calling `load` initializes the PASCAL variant of the model quantized to 2 bytes.
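
That is, a bare call is a shorthand for the explicit config (a minimal sketch, assuming `load` falls back to these defaults when called without an argument):

```typescript
import * as deeplab from '@tensorflow-models/deeplab';

// Assuming the defaults described above, this is equivalent to
// deeplab.load({base: 'pascal', quantizationBytes: 2}).
deeplab.load().then((model) => console.log('Loaded the default model'));
```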

If you would rather load custom weights, you can pass the URL in the config instead:

```typescript
import * as deeplab from '@tensorflow-models/deeplab';

const loadModel = async () => {
  // #TODO(tfjs): Replace this URL after you host the model
  const url = 'https://storage.googleapis.com/gsoc-tfjs/models/deeplab/quantized/1/pascal/model.json';
  return await deeplab.load({modelUrl: url});
};

loadModel().then(() => console.log(`Loaded the model successfully!`));
```

This will initialize and return the `SemanticSegmentation` model.

You can set the `base` attribute in the argument to `pascal`, `cityscapes` or `ade20k` to use the corresponding colormap and labelling scheme. Otherwise, you would have to provide those yourself during segmentation.

If you require more careful control over the initialization and behavior of the model (e.g. you want to use your own labelling scheme and colormap), use the `SemanticSegmentation` class, passing a pre-loaded `GraphModel` in the constructor:

```typescript
import * as tfconv from '@tensorflow/tfjs-converter';
import * as deeplab from '@tensorflow-models/deeplab';

const loadModel = async () => {
  const base = 'pascal';        // set to your preferred model, out of
                                // `pascal`, `cityscapes` and `ade20k`
  const quantizationBytes = 2;  // either 1, 2 or 4
  // use the getURL utility function to get the URL to the pre-trained weights
  const modelUrl = deeplab.getURL(base, quantizationBytes);
  const rawModel = await tfconv.loadGraphModel(modelUrl);
  return new deeplab.SemanticSegmentation(rawModel);
};

loadModel().then(() => console.log(`Loaded the model successfully!`));
```

Use the `getColormap(base)` and `getLabels(base)` utility functions to fetch the default colormap and labelling scheme:

```typescript
import {getLabels, getColormap} from '@tensorflow-models/deeplab';

const model = 'ade20k';
const colormap = getColormap(model);
const labels = getLabels(model);
```

### Segmenting an Image

The `segment` method of the `SemanticSegmentation` object covers most use cases.

Each model recognizes a different set of object classes in an image:

- [PASCAL](./src/config.ts#L60)
- [CityScapes](./src/config.ts#L66)
- [ADE20K](./src/config.ts#L72)

#### `model.segment(image, config?)` inputs

- **image** :: `ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement | tf.Tensor3D`

  The image to segment

- **config.canvas** (optional) :: `HTMLCanvasElement`

  Pass an optional canvas element as `canvas` to draw the output (see the sketch after this list)

- **config.colormap** (optional) :: `[number, number, number][]`

  The array of RGB colors corresponding to labels

- **config.labels** (optional) :: `string[]`

  The array of names corresponding to labels

  By [default](./src/index.ts#L81), `colormap` and `labels` are set according to the `base` model attribute passed during initialization.
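
As a sketch of how these options fit together, the call below draws the output onto a canvas using a custom colormap and labelling scheme. The three-class scheme is purely illustrative, and `model` and `image` are assumed to come from the snippets above:

```typescript
// Hypothetical three-class scheme, for illustration only.
const colormap: [number, number, number][] =
    [[0, 0, 0], [255, 0, 0], [0, 255, 0]];
const labels = ['background', 'cat', 'dog'];
const canvas = document.createElement('canvas');

model.segment(image, {canvas, colormap, labels})
    .then(({legend}) => console.log(JSON.stringify(legend)));
```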

#### `model.segment(image, config?)` outputs

The output is a promise of a `DeepLabOutput` object, with four attributes:

- **legend** :: `{ [name: string]: [number, number, number] }`

  The legend is a dictionary of objects recognized in the image and their colors in RGB format.

- **height** :: `number`

  The height of the returned segmentation map

- **width** :: `number`

  The width of the returned segmentation map

- **segmentationMap** :: `Uint8ClampedArray`

  The colored segmentation map as a `Uint8ClampedArray`, which can be [fed](https://developer.mozilla.org/en-US/docs/Web/API/Canvas_API/Tutorial/Pixel_manipulation_with_canvas) into `ImageData` and mapped to a canvas.
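
For example, if you did not pass a canvas in the config, you can paint the map onto one yourself. A minimal sketch, assuming `output` is the resolved result of `model.segment`:

```typescript
const renderSegmentationMap = (
    output:
        {height: number, width: number, segmentationMap: Uint8ClampedArray},
    canvas: HTMLCanvasElement) => {
  const {height, width, segmentationMap} = output;
  canvas.height = height;
  canvas.width = width;
  // Wrap the RGBA bytes in an ImageData and paint them onto the canvas.
  const segmentationMapData = new ImageData(segmentationMap, width, height);
  canvas.getContext('2d')!.putImageData(segmentationMapData, 0, 0);
};
```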

#### `model.segment(image, config?)` example

```typescript
const classify = async (image) => {
  return await model.segment(image);
};
```

**Note**: *For more granular control, consider the `predict` and `toSegmentationImage` methods described below.*

### Producing a Semantic Segmentation Map

To segment an arbitrary image and generate a two-dimensional tensor with class labels assigned to each cell of the grid overlaid on the image (with the maximum number of cells on the side fixed to 513), use the `predict` method of the `SemanticSegmentation` object.

#### `model.predict(image)` input

- **image** :: `ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement | tf.Tensor3D`

  The image to segment

#### `model.predict(image)` output

- **rawSegmentationMap** :: `tf.Tensor2D`

  The segmentation map of the image

#### `model.predict(image)` example

```javascript
const getSemanticSegmentationMap = (image) => {
  return model.predict(image);
};
```

### Translating a Segmentation Map into a Color-Labelled Image

To transform the segmentation map into a colored image, use the `toSegmentationImage` function.

#### `toSegmentationImage(colormap, labels, segmentationMap, canvas?)` inputs

- **colormap** :: `[number, number, number][]`

  The array of RGB colors corresponding to labels

- **labels** :: `string[]`

  The array of names corresponding to labels

- **segmentationMap** :: `tf.Tensor2D`

  The segmentation map of the image

- **canvas** (optional) :: `HTMLCanvasElement`

  Pass an optional canvas element as `canvas` to draw the output

#### `toSegmentationImage(colormap, labels, segmentationMap, canvas?)` outputs

A promise resolving to the `SegmentationData` object, which contains two attributes:

- **legend** :: `{ [name: string]: [number, number, number] }`

  The legend is a dictionary of objects recognized in the image and their colors.

- **segmentationMap** :: `Uint8ClampedArray`

  The colored segmentation map as a `Uint8ClampedArray`, which can be [fed](https://developer.mozilla.org/en-US/docs/Web/API/Canvas_API/Tutorial/Pixel_manipulation_with_canvas) into `ImageData` and mapped to a canvas.

#### `toSegmentationImage(colormap, labels, segmentationMap, canvas?)` example

```javascript
import {getColormap, getLabels, toSegmentationImage} from '@tensorflow-models/deeplab';

const base = 'pascal';

const translateSegmentationMap = async (segmentationMap) => {
  return await toSegmentationImage(
      getColormap(base), getLabels(base), segmentationMap);
};
```
| 208 | + |
| 209 | +## Contributing to the Demo |
| 210 | + |
| 211 | +Please see the demo [documentation](./demo/README.md). |
| 212 | + |
| 213 | +## Technical Details |
| 214 | + |
| 215 | +This model is based on the TensorFlow [implementation](https://github.com/tensorflow/models/tree/master/research/deeplab) of DeepLab v3. You might want to inspect the [conversion script](./scripts/convert_deeplab.sh), or download original pre-trained weights [here](https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md). To convert the weights locally, run the script as follows, replacing `dist` with the target directory: |
| 216 | + |
| 217 | +```bash |
| 218 | +./scripts/convert_deeplab.sh --target_dir ./scripts/dist |
| 219 | +``` |
| 220 | + |
| 221 | +Run the usage helper to learn more about the options: |
| 222 | + |
| 223 | +```bash |
| 224 | +./scripts/convert_deeplab.sh -h |
| 225 | +``` |