Skip to content

Commit 9381dd8

Browse files
authored
Improvements to the speech recognition model (#101)
Several improvements/fixes to the speech recognition model: - Stop the audio extractor immediately after calling `recognizeOnline()`. Otherwise the extractor is in an infinite loop of processing data. - When the extractor is stopped, close the audio stream by calling `this.stream.getTracks()[0].stop();`. Otherwise the browser shows that the microphone is always on (red dot on the browser tab). - Return the original spectrogram data to the user (unnormalized). This makes a large difference in data quality for applications that decide to use only the last K frames out of the 43 frames. Those K frames are not impacted by the mean and stdev of all of the 43 frames. Before, my audio app would react with a delay of ~1sec after I made a sound, even when using only the last 3 frames (70ms of data). The delay was because of the silence that preceded my sound, which shifted the mean and the stdev. - Speed up data collection by 25% by using a queue instead of a circular buffer, which allows us to use the fast built-in `TypedArray.set()` when creating the extractor. Now calling listen with overlapFactor of 0.999 results in ~40 callbacks/sec - before it was ~32. - Update the patch version. Now version is `0.2.1`
1 parent b9aaade commit 9381dd8

9 files changed

+78
-128
lines changed

speech-commands/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@tensorflow-models/speech-commands",
3-
"version": "0.2.0",
3+
"version": "0.2.1",
44
"description": "Speech-command recognizer in TensorFlow.js",
55
"main": "dist/index.js",
66
"unpkg": "dist/speech-commands.min.js",

speech-commands/src/browser_fft_extractor.ts

Lines changed: 29 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
*/
2121

2222
import * as tf from '@tensorflow/tfjs';
23-
import {getAudioContextConstructor, getAudioMediaStream, normalize} from './browser_fft_utils';
23+
import {getAudioContextConstructor, getAudioMediaStream} from './browser_fft_utils';
2424
import {FeatureExtractor, RecognizerParams} from './types';
2525

2626
export type SpectrogramCallback = (x: tf.Tensor) => Promise<boolean>;
@@ -77,7 +77,7 @@ export interface BrowserFftFeatureExtractorConfig extends RecognizerParams {
7777
*/
7878
export class BrowserFftFeatureExtractor implements FeatureExtractor {
7979
// Number of frames (i.e., columns) per spectrogram used for classification.
80-
readonly numFramesPerSpectrogram: number;
80+
readonly numFrames: number;
8181

8282
// Audio sampling rate in Hz.
8383
readonly sampleRateHz: number;
@@ -92,22 +92,16 @@ export class BrowserFftFeatureExtractor implements FeatureExtractor {
9292
// consecutive spectrograms and the length of each individual spectrogram.
9393
readonly overlapFactor: number;
9494

95-
protected readonly spectrogramCallback: SpectrogramCallback;
95+
private readonly spectrogramCallback: SpectrogramCallback;
9696

9797
private stream: MediaStream;
9898
// tslint:disable-next-line:no-any
9999
private audioContextConstructor: any;
100100
private audioContext: AudioContext;
101101
private analyser: AnalyserNode;
102-
103102
private tracker: Tracker;
104-
105-
private readonly ROTATING_BUFFER_SIZE_MULTIPLIER = 2;
106103
private freqData: Float32Array;
107-
private rotatingBufferNumFrames: number;
108-
private rotatingBuffer: Float32Array;
109-
110-
private frameCount: number;
104+
private freqDataQueue: Float32Array[];
111105
// tslint:disable-next-line:no-any
112106
private frameIntervalTask: any;
113107
private frameDurationMillis: number;
@@ -144,7 +138,7 @@ export class BrowserFftFeatureExtractor implements FeatureExtractor {
144138
this.suppressionTimeMillis = config.suppressionTimeMillis;
145139

146140
this.spectrogramCallback = config.spectrogramCallback;
147-
this.numFramesPerSpectrogram = config.numFramesPerSpectrogram;
141+
this.numFrames = config.numFramesPerSpectrogram;
148142
this.sampleRateHz = config.sampleRateHz || 44100;
149143
this.fftSize = config.fftSize || 1024;
150144
this.frameDurationMillis = this.fftSize / this.sampleRateHz * 1e3;
@@ -165,7 +159,7 @@ export class BrowserFftFeatureExtractor implements FeatureExtractor {
165159
this.audioContextConstructor = getAudioContextConstructor();
166160
}
167161

168-
async start(samples?: Float32Array): Promise<Float32Array[]|void> {
162+
async start(): Promise<Float32Array[]|void> {
169163
if (this.frameIntervalTask != null) {
170164
throw new Error(
171165
'Cannot start already-started BrowserFftFeatureExtractor');
@@ -184,18 +178,11 @@ export class BrowserFftFeatureExtractor implements FeatureExtractor {
184178
this.analyser.fftSize = this.fftSize * 2;
185179
this.analyser.smoothingTimeConstant = 0.0;
186180
streamSource.connect(this.analyser);
187-
181+
// Reset the queue.
182+
this.freqDataQueue = [];
188183
this.freqData = new Float32Array(this.fftSize);
189-
this.rotatingBufferNumFrames =
190-
this.numFramesPerSpectrogram * this.ROTATING_BUFFER_SIZE_MULTIPLIER;
191-
const rotatingBufferSize =
192-
this.columnTruncateLength * this.rotatingBufferNumFrames;
193-
this.rotatingBuffer = new Float32Array(rotatingBufferSize);
194-
195-
this.frameCount = 0;
196-
197-
const period = Math.max(
198-
1, Math.round(this.numFramesPerSpectrogram * (1 - this.overlapFactor)));
184+
const period =
185+
Math.max(1, Math.round(this.numFrames * (1 - this.overlapFactor)));
199186
this.tracker = new Tracker(
200187
period,
201188
Math.round(this.suppressionTimeMillis / this.frameDurationMillis));
@@ -209,20 +196,16 @@ export class BrowserFftFeatureExtractor implements FeatureExtractor {
209196
return;
210197
}
211198

212-
const freqDataSlice = this.freqData.slice(0, this.columnTruncateLength);
213-
const bufferPos = this.frameCount % this.rotatingBufferNumFrames;
214-
this.rotatingBuffer.set(
215-
freqDataSlice, bufferPos * this.columnTruncateLength);
216-
this.frameCount++;
217-
199+
this.freqDataQueue.push(this.freqData.slice(0, this.columnTruncateLength));
200+
if (this.freqDataQueue.length > this.numFrames) {
201+
// Drop the oldest frame (least recent).
202+
this.freqDataQueue.shift();
203+
}
218204
const shouldFire = this.tracker.tick();
219205
if (shouldFire) {
220-
const freqData = getFrequencyDataFromRotatingBuffer(
221-
this.rotatingBuffer, this.numFramesPerSpectrogram,
222-
this.columnTruncateLength,
223-
this.frameCount - this.numFramesPerSpectrogram);
206+
const freqData = flattenQueue(this.freqDataQueue);
224207
const inputTensor = getInputTensorFromFrequencyData(
225-
freqData, this.numFramesPerSpectrogram, this.columnTruncateLength);
208+
freqData, [1, this.numFrames, this.columnTruncateLength, 1]);
226209
const shouldRest = await this.spectrogramCallback(inputTensor);
227210
if (shouldRest) {
228211
this.tracker.suppress();
@@ -240,6 +223,9 @@ export class BrowserFftFeatureExtractor implements FeatureExtractor {
240223
this.frameIntervalTask = null;
241224
this.analyser.disconnect();
242225
this.audioContext.close();
226+
if (this.stream != null && this.stream.getTracks().length > 0) {
227+
this.stream.getTracks()[0].stop();
228+
}
243229
}
244230

245231
setConfig(params: RecognizerParams) {
@@ -255,39 +241,19 @@ export class BrowserFftFeatureExtractor implements FeatureExtractor {
255241
}
256242
}
257243

258-
export function getFrequencyDataFromRotatingBuffer(
259-
rotatingBuffer: Float32Array, numFrames: number, fftLength: number,
260-
frameCount: number): Float32Array {
261-
const size = numFrames * fftLength;
262-
const freqData = new Float32Array(size);
263-
264-
const rotatingBufferSize = rotatingBuffer.length;
265-
const rotatingBufferNumFrames = rotatingBufferSize / fftLength;
266-
while (frameCount < 0) {
267-
frameCount += rotatingBufferNumFrames;
268-
}
269-
const indexBegin = (frameCount % rotatingBufferNumFrames) * fftLength;
270-
const indexEnd = indexBegin + size;
271-
272-
for (let i = indexBegin; i < indexEnd; ++i) {
273-
freqData[i - indexBegin] = rotatingBuffer[i % rotatingBufferSize];
274-
}
244+
export function flattenQueue(queue: Float32Array[]): Float32Array {
245+
const frameSize = queue[0].length;
246+
const freqData = new Float32Array(queue.length * frameSize);
247+
queue.forEach((data, i) => freqData.set(data, i * frameSize));
275248
return freqData;
276249
}
277250

278251
export function getInputTensorFromFrequencyData(
279-
freqData: Float32Array, numFrames: number, fftLength: number,
280-
toNormalize = true): tf.Tensor {
281-
return tf.tidy(() => {
282-
const size = freqData.length;
283-
const tensorBuffer = tf.buffer([size]);
284-
for (let i = 0; i < freqData.length; ++i) {
285-
tensorBuffer.set(freqData[i], i);
286-
}
287-
const output =
288-
tensorBuffer.toTensor().reshape([1, numFrames, fftLength, 1]);
289-
return toNormalize ? normalize(output) : output;
290-
});
252+
freqData: Float32Array, shape: number[]): tf.Tensor {
253+
const vals = new Float32Array(tf.util.sizeFromShape(shape));
254+
// If the data is less than the output shape, the rest is padded with zeros.
255+
vals.set(freqData, vals.length - freqData.length);
256+
return tf.tensor(vals, shape);
291257
}
292258

293259
/**

speech-commands/src/browser_fft_extractor_test.ts

Lines changed: 18 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -17,62 +17,38 @@
1717

1818
import * as tf from '@tensorflow/tfjs';
1919
import {describeWithFlags} from '@tensorflow/tfjs-core/dist/jasmine_util';
20-
import {BrowserFftFeatureExtractor, getFrequencyDataFromRotatingBuffer, getInputTensorFromFrequencyData} from './browser_fft_extractor';
20+
import {BrowserFftFeatureExtractor, flattenQueue, getInputTensorFromFrequencyData} from './browser_fft_extractor';
2121
import * as BrowserFftUtils from './browser_fft_utils';
2222
import {FakeAudioContext, FakeAudioMediaStream} from './browser_test_utils';
2323

2424
const testEnvs = tf.test_util.NODE_ENVS;
2525

26-
describeWithFlags('getFrequencyDataFromRotatingBuffer', testEnvs, () => {
27-
it('getFrequencyDataFromRotatingBuffer', () => {
28-
const rotBuffer = new Float32Array([1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]);
29-
const numFrames = 3;
30-
const fftLength = 2;
31-
expect(
32-
getFrequencyDataFromRotatingBuffer(rotBuffer, numFrames, fftLength, 0))
33-
.toEqual(new Float32Array([1, 1, 2, 2, 3, 3]));
34-
35-
expect(
36-
getFrequencyDataFromRotatingBuffer(rotBuffer, numFrames, fftLength, 1))
37-
.toEqual(new Float32Array([2, 2, 3, 3, 4, 4]));
38-
expect(
39-
getFrequencyDataFromRotatingBuffer(rotBuffer, numFrames, fftLength, 3))
40-
.toEqual(new Float32Array([4, 4, 5, 5, 6, 6]));
41-
expect(
42-
getFrequencyDataFromRotatingBuffer(rotBuffer, numFrames, fftLength, 4))
43-
.toEqual(new Float32Array([5, 5, 6, 6, 1, 1]));
44-
expect(
45-
getFrequencyDataFromRotatingBuffer(rotBuffer, numFrames, fftLength, 6))
46-
.toEqual(new Float32Array([1, 1, 2, 2, 3, 3]));
26+
describeWithFlags('flattenQueue', testEnvs, () => {
27+
it('3 frames, 2 values each', () => {
28+
const queue = [[1, 1], [2, 2], [3, 3]].map(x => new Float32Array(x));
29+
expect(flattenQueue(queue)).toEqual(new Float32Array([1, 1, 2, 2, 3, 3]));
30+
});
31+
32+
it('2 frames, 2 values each', () => {
33+
const queue = [[1, 1], [2, 2]].map(x => new Float32Array(x));
34+
expect(flattenQueue(queue)).toEqual(new Float32Array([1, 1, 2, 2]));
35+
});
36+
37+
it('1 frame, 2 values each', () => {
38+
const queue = [[1, 1]].map(x => new Float32Array(x));
39+
expect(flattenQueue(queue)).toEqual(new Float32Array([1, 1]));
4740
});
4841
});
4942

5043
describeWithFlags('getInputTensorFromFrequencyData', testEnvs, () => {
51-
it('Unnormalized', () => {
44+
it('6 frames, 2 vals each', () => {
5245
const freqData = new Float32Array([1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]);
5346
const numFrames = 6;
5447
const fftSize = 2;
5548
const tensor =
56-
getInputTensorFromFrequencyData(freqData, numFrames, fftSize, false);
49+
getInputTensorFromFrequencyData(freqData, [1, numFrames, fftSize, 1]);
5750
tf.test_util.expectArraysClose(tensor, tf.tensor4d(freqData, [1, 6, 2, 1]));
5851
});
59-
60-
it('Normalized', () => {
61-
const freqData = new Float32Array([1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]);
62-
const numFrames = 6;
63-
const fftSize = 2;
64-
const tensor =
65-
getInputTensorFromFrequencyData(freqData, numFrames, fftSize);
66-
tf.test_util.expectArraysClose(
67-
tensor,
68-
tf.tensor4d(
69-
[
70-
-1.4638501, -1.4638501, -0.8783101, -0.8783101, -0.29277,
71-
-0.29277, 0.29277, 0.29277, 0.8783101, 0.8783101, 1.4638501,
72-
1.4638501
73-
],
74-
[1, 6, 2, 1]));
75-
});
7652
});
7753

7854
describeWithFlags('BrowserFftFeatureExtractor', testEnvs, () => {
@@ -95,7 +71,7 @@ describeWithFlags('BrowserFftFeatureExtractor', testEnvs, () => {
9571
});
9672

9773
expect(extractor.fftSize).toEqual(1024);
98-
expect(extractor.numFramesPerSpectrogram).toEqual(43);
74+
expect(extractor.numFrames).toEqual(43);
9975
expect(extractor.columnTruncateLength).toEqual(225);
10076
expect(extractor.overlapFactor).toBeCloseTo(0);
10177
});

speech-commands/src/browser_fft_recognizer.ts

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import * as tf from '@tensorflow/tfjs';
1919
import {BrowserFftFeatureExtractor, SpectrogramCallback} from './browser_fft_extractor';
20-
import {loadMetadataJson} from './browser_fft_utils';
20+
import {loadMetadataJson, normalize} from './browser_fft_utils';
2121
import {RecognizeConfig, RecognizerCallback, RecognizerParams, SpectrogramData, SpeechCommandRecognizer, SpeechCommandRecognizerResult, StreamingRecognitionConfig, TransferLearnConfig, TransferSpeechCommandRecognizer} from './types';
2222
import {version} from './version';
2323

@@ -41,7 +41,7 @@ export class BrowserFftSpeechCommandRecognizer implements
4141

4242
readonly MODEL_URL_PREFIX =
4343
`https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v${
44-
getMajorAndMinorVersion(version)}/browser_fft`;
44+
getMajorAndMinorVersion(version)}/browser_fft`;
4545

4646
private readonly SAMPLE_RATE_HZ = 44100;
4747
private readonly FFT_SIZE = 1024;
@@ -137,8 +137,9 @@ export class BrowserFftSpeechCommandRecognizer implements
137137
* @throws Error, if streaming recognition is already started or
138138
* if `config` contains invalid values.
139139
*/
140-
async listen(callback: RecognizerCallback,
141-
config?: StreamingRecognitionConfig): Promise<void> {
140+
async listen(
141+
callback: RecognizerCallback,
142+
config?: StreamingRecognitionConfig): Promise<void> {
142143
if (streaming) {
143144
throw new Error(
144145
'Cannot start streaming again when streaming is ongoing.');
@@ -183,21 +184,22 @@ export class BrowserFftSpeechCommandRecognizer implements
183184
const spectrogramCallback: SpectrogramCallback = async (x: tf.Tensor) => {
184185
await this.ensureModelWithEmbeddingOutputCreated();
185186

187+
const normalizedX = normalize(x);
186188
let y: tf.Tensor;
187189
let embedding: tf.Tensor;
188190
if (config.includeEmbedding) {
189191
await this.ensureModelWithEmbeddingOutputCreated();
190192
[y, embedding] =
191-
this.modelWithEmbeddingOutput.predict(x) as tf.Tensor[];
193+
this.modelWithEmbeddingOutput.predict(normalizedX) as tf.Tensor[];
192194
} else {
193-
y = this.model.predict(x) as tf.Tensor;
195+
y = this.model.predict(normalizedX) as tf.Tensor;
194196
}
195197

196198
const scores = await y.data() as Float32Array;
197199
const maxIndexTensor = y.argMax(-1);
198200
const maxIndex = (await maxIndexTensor.data())[0];
199201
const maxScore = Math.max(...scores);
200-
tf.dispose([y, maxIndexTensor]);
202+
tf.dispose([y, maxIndexTensor, normalizedX]);
201203

202204
if (maxScore < probabilityThreshold) {
203205
return false;
@@ -486,22 +488,25 @@ export class BrowserFftSpeechCommandRecognizer implements
486488

487489
if (config.includeSpectrogram) {
488490
output.spectrogram = {
489-
data: (input instanceof tf.Tensor ?
490-
await input.data() : input) as Float32Array,
491+
data: (input instanceof tf.Tensor ? await input.data() : input) as
492+
Float32Array,
491493
frameSize: this.nonBatchInputShape[1],
492494
};
493495
}
494496

495497
return output;
496498
}
497499

498-
protected async recognizeOnline(): Promise<SpectrogramData> {
500+
private async recognizeOnline(): Promise<SpectrogramData> {
499501
return new Promise<SpectrogramData>((resolve, reject) => {
500502
const spectrogramCallback: SpectrogramCallback = async (x: tf.Tensor) => {
503+
const normalizedX = normalize(x);
504+
await this.audioDataExtractor.stop();
501505
resolve({
502-
data: await x.data() as Float32Array,
506+
data: await normalizedX.data() as Float32Array,
503507
frameSize: this.nonBatchInputShape[1],
504508
});
509+
normalizedX.dispose();
505510
return false;
506511
};
507512
this.audioDataExtractor = new BrowserFftFeatureExtractor({
@@ -611,15 +616,17 @@ class TransferBrowserFftSpeechCommandRecognizer extends
611616
`learning example`);
612617

613618
streaming = true;
614-
return new Promise<SpectrogramData>((resolve, reject) => {
619+
return new Promise<SpectrogramData>(resolve => {
615620
const spectrogramCallback: SpectrogramCallback = async (x: tf.Tensor) => {
616621
if (this.transferExamples == null) {
617622
this.transferExamples = {};
618623
}
619624
if (this.transferExamples[word] == null) {
620625
this.transferExamples[word] = [];
621626
}
622-
this.transferExamples[word].push(x.clone());
627+
const normalizedX = normalize(x);
628+
this.transferExamples[word].push(normalizedX.clone());
629+
normalizedX.dispose();
623630
await this.audioDataExtractor.stop();
624631
streaming = false;
625632
this.collateTransferWords();

speech-commands/src/browser_fft_recognizer_test.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -775,8 +775,7 @@ describeWithFlags('Browser FFT recognizer', tf.test_util.NODE_ENVS, () => {
775775
setUpFakes();
776776
const base = new BrowserFftSpeechCommandRecognizer();
777777
await base.ensureModelLoaded();
778-
await base.listen(
779-
async (result: SpeechCommandRecognizerResult) => {});
778+
await base.listen(async (result: SpeechCommandRecognizerResult) => {});
780779
expect(base.isListening()).toEqual(true);
781780

782781
const transfer = base.createTransfer('xfer1');

speech-commands/src/browser_fft_utils.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,8 @@ export async function loadMetadataJson(url: string):
4040

4141
export function normalize(x: tf.Tensor): tf.Tensor {
4242
return tf.tidy(() => {
43-
const mean = tf.mean(x);
44-
const std = tf.sqrt(tf.mean(tf.square(tf.add(x, tf.neg(mean)))));
45-
return tf.div(tf.add(x, tf.neg(mean)), std);
43+
const {mean, variance} = tf.moments(x);
44+
return x.sub(mean).div(variance.sqrt());
4645
});
4746
}
4847

0 commit comments

Comments
 (0)