
Commit 331a065

[speech-command] Add includeEmbeddings; Allow recognize() to draw from WebAudio directly (#99)

- Allow `recognize()` to take no argument, in which case the method draws a frame of audio directly from WebAudio.
- Add the config field `includeEmbedding` to the `startStreaming()` and `recognize()` methods.

1 parent 9fbbea6 commit 331a065
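In practice, the two additions look like this — a minimal usage sketch, assuming the published `@tensorflow-models/speech-commands` package and its `create('BROWSER_FFT')` factory (neither is part of this diff):

import * as speechCommands from '@tensorflow-models/speech-commands';

async function demo() {
  const recognizer = speechCommands.create('BROWSER_FFT');
  await recognizer.ensureModelLoaded();

  // New in this commit: a no-argument recognize() draws one frame of
  // audio directly from WebAudio (via the new recognizeOnline() path).
  const online = await recognizer.recognize();
  console.log(online.scores);  // Float32Array of per-word probabilities.

  // New in this commit: includeEmbedding additionally returns the
  // activation of the model's second-last dense layer.
  const result = await recognizer.recognize(null, {includeEmbedding: true});
  console.log(result.embedding.shape);  // [1, embeddingDims]
}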

File tree

3 files changed: +252 -20 lines changed

speech-commands/src/browser_fft_recognizer.ts

Lines changed: 110 additions & 11 deletions
@@ -20,7 +20,7 @@ import * as tf from '@tensorflow/tfjs';
 // tslint:disable:max-line-length
 import {BrowserFftFeatureExtractor, SpectrogramCallback} from './browser_fft_extractor';
 import {loadMetadataJson} from './browser_fft_utils';
-import {RecognizerCallback, RecognizerParams, SpectrogramData, SpeechCommandRecognizer, SpeechCommandRecognizerResult, StreamingRecognitionConfig, TransferLearnConfig, TransferSpeechCommandRecognizer} from './types';
+import {RecognizeConfig, RecognizerCallback, RecognizerParams, SpectrogramData, SpeechCommandRecognizer, SpeechCommandRecognizerResult, StreamingRecognitionConfig, TransferLearnConfig, TransferSpeechCommandRecognizer} from './types';
 import {version} from './version';

 // tslint:enable:max-line-length
@@ -47,6 +47,7 @@ export class BrowserFftSpeechCommandRecognizer implements
   private readonly DEFAULT_SUPPRESSION_TIME_MILLIS = 1000;

   model: tf.Model;
+  modelWithEmbeddingOutput: tf.Model;
   readonly vocabulary: string;
   readonly parameters: RecognizerParams;
   protected words: string[];
@@ -149,15 +150,24 @@ export class BrowserFftSpeechCommandRecognizer implements
     if (config == null) {
       config = {};
     }
-    const probabilityThreshold =
+    let probabilityThreshold =
        config.probabilityThreshold == null ? 0 : config.probabilityThreshold;
+    if (config.includeEmbedding) {
+      // Override probability threshold to 0 if includeEmbedding is true.
+      probabilityThreshold = 0;
+    }
    tf.util.assert(
        probabilityThreshold >= 0 && probabilityThreshold <= 1,
        `Invalid probabilityThreshold value: ${probabilityThreshold}`);
-    const invokeCallbackOnNoiseAndUnknown =
+    let invokeCallbackOnNoiseAndUnknown =
        config.invokeCallbackOnNoiseAndUnknown == null ?
        false :
        config.invokeCallbackOnNoiseAndUnknown;
+    if (config.includeEmbedding) {
+      // Override invokeCallbackOnNoiseAndUnknown to true if
+      // includeEmbedding is true.
+      invokeCallbackOnNoiseAndUnknown = true;
+    }

    if (config.suppressionTimeMillis < 0) {
      throw new Error(
@@ -174,7 +184,18 @@ export class BrowserFftSpeechCommandRecognizer implements
        Math.round(this.FFT_SIZE * (1 - overlapFactor));

    const spectrogramCallback: SpectrogramCallback = async (x: tf.Tensor) => {
-      const y = tf.tidy(() => this.model.predict(x) as tf.Tensor);
+      await this.ensureModelWithEmbeddingOutputCreated();
+
+      let y: tf.Tensor;
+      let embedding: tf.Tensor;
+      if (config.includeEmbedding) {
+        await this.ensureModelWithEmbeddingOutputCreated();
+        [y, embedding] =
+            this.modelWithEmbeddingOutput.predict(x) as tf.Tensor[];
+      } else {
+        y = this.model.predict(x) as tf.Tensor;
+      }
+
      const scores = await y.data() as Float32Array;
      const maxIndexTensor = y.argMax(-1);
      const maxIndex = (await maxIndexTensor.data())[0];
@@ -201,7 +222,7 @@ export class BrowserFftSpeechCommandRecognizer implements
        }
      }
      if (wordDetected) {
-        callback({scores, spectrogram});
+        callback({scores, spectrogram, embedding});
      }
      // Trigger suppression only if the word is neither unknown nor
      // background noise.
@@ -289,6 +310,39 @@ export class BrowserFftSpeechCommandRecognizer implements
    this.parameters.spectrogramDurationMillis = numFrames * frameDurationMillis;
  }

+  /**
+   * Construct a two-output model that includes the following outputs:
+   *
+   * 1. The same softmax probability output as the original model's output
+   * 2. The embedding, i.e., activation from the second-last dense layer of
+   *    the original model.
+   */
+  protected async ensureModelWithEmbeddingOutputCreated() {
+    if (this.modelWithEmbeddingOutput != null) {
+      return;
+    }
+    await this.ensureModelLoaded();
+
+    // Find the second-last dense layer of the original model.
+    let secondLastDenseLayer: tf.layers.Layer;
+    for (let i = this.model.layers.length - 2; i >= 0; --i) {
+      if (this.model.layers[i].getClassName() === 'Dense') {
+        secondLastDenseLayer = this.model.layers[i];
+        break;
+      }
+    }
+    if (secondLastDenseLayer == null) {
+      throw new Error(
+          'Failed to find second-last dense layer in the original model.');
+    }
+    this.modelWithEmbeddingOutput = tf.model({
+      inputs: this.model.inputs,
+      outputs: [
+        this.model.outputs[0], secondLastDenseLayer.output as tf.SymbolicTensor
+      ]
+    });
+  }
+
  private warmUpModel() {
    tf.tidy(() => {
      const x = tf.zeros([1].concat(this.nonBatchInputShape));
@@ -370,15 +424,27 @@ export class BrowserFftSpeechCommandRecognizer implements
   *   - If a `Float32Array`, must have a length divisible by the number
   *     of elements per spectrogram, i.e.,
   *     (# of spectrogram columns) * (# of frequency-domain points per column).
+   * @param config Optional configuration object.
   * @returns Result of the recognition, with the following field:
   *   scores:
   *   - A `Float32Array` if there is only one input example.
   *   - An `Array` of `Float32Array`, if there are multiple input examples.
   */
-  async recognize(input: tf.Tensor|
-                  Float32Array): Promise<SpeechCommandRecognizerResult> {
+  async recognize(input?: tf.Tensor|Float32Array, config?: RecognizeConfig):
+      Promise<SpeechCommandRecognizerResult> {
+    if (config == null) {
+      config = {};
+    }
+
    await this.ensureModelLoaded();

+    if (input == null) {
+      // If `input` is not provided, draw audio data from WebAudio and use it
+      // for recognition.
+      const spectrogramData = await this.recognizeOnline();
+      input = spectrogramData.data;
+    }
+
    let numExamples: number;
    let inputTensor: tf.Tensor;
    let outTensor: tf.Tensor;
@@ -403,16 +469,49 @@ export class BrowserFftSpeechCommandRecognizer implements
      ].concat(this.nonBatchInputShape) as [number, number, number, number]);
    }

-    outTensor = this.model.predict(inputTensor) as tf.Tensor;
+    const output: SpeechCommandRecognizerResult = {scores: null};
+    if (config.includeEmbedding) {
+      // Optional inclusion of embedding (internal activation).
+      await this.ensureModelWithEmbeddingOutputCreated();
+      const outAndEmbedding =
+          this.modelWithEmbeddingOutput.predict(inputTensor) as tf.Tensor[];
+      outTensor = outAndEmbedding[0];
+      output.embedding = outAndEmbedding[1];
+    } else {
+      outTensor = this.model.predict(inputTensor) as tf.Tensor;
+    }
+
    if (numExamples === 1) {
-      return {scores: await outTensor.data() as Float32Array};
+      output.scores = await outTensor.data() as Float32Array;
    } else {
      const unstacked = tf.unstack(outTensor) as tf.Tensor[];
      const scorePromises = unstacked.map(item => item.data());
-      const scores = await Promise.all(scorePromises) as Float32Array[];
+      output.scores = await Promise.all(scorePromises) as Float32Array[];
      tf.dispose(unstacked);
-      return {scores};
    }
+    return output;
+  }
+
+  protected async recognizeOnline(): Promise<SpectrogramData> {
+    return new Promise<SpectrogramData>((resolve, reject) => {
+      const spectrogramCallback: SpectrogramCallback = async (x: tf.Tensor) => {
+        resolve({
+          data: await x.data() as Float32Array,
+          frameSize: this.nonBatchInputShape[1],
+        });
+        return false;
+      };
+      this.audioDataExtractor = new BrowserFftFeatureExtractor({
+        sampleRateHz: this.parameters.sampleRateHz,
+        columnBufferLength: this.parameters.columnBufferLength,
+        columnHopLength: this.parameters.columnBufferLength,
+        numFramesPerSpectrogram: this.nonBatchInputShape[0],
+        columnTruncateLength: this.nonBatchInputShape[1],
+        suppressionTimeMillis: 0,
+        spectrogramCallback
+      });
+      this.audioDataExtractor.start();
+    });
+  }

  createTransfer(name: string): TransferSpeechCommandRecognizer {
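The centerpiece of the changes above is `ensureModelWithEmbeddingOutputCreated()`: a second `tf.model` that shares the original weights but exposes the second-last dense layer's activation as an extra output. Below is a standalone sketch of the same technique; the input shape and the layer sizes (4 hidden units, 17 classes, echoing the test fakes) are illustrative assumptions, not taken from this diff:

import * as tf from '@tensorflow/tfjs';

// A stand-in classifier: flatten -> dense(4, hidden) -> dense(17, softmax).
const model = tf.sequential({
  layers: [
    tf.layers.flatten({inputShape: [43, 232, 1]}),
    tf.layers.dense({units: 4, activation: 'relu'}),
    tf.layers.dense({units: 17, activation: 'softmax'})
  ]
});

// Walk backwards from the layer before the output to find the last
// hidden dense layer, as ensureModelWithEmbeddingOutputCreated() does.
let secondLastDenseLayer: tf.layers.Layer|null = null;
for (let i = model.layers.length - 2; i >= 0; --i) {
  if (model.layers[i].getClassName() === 'Dense') {
    secondLastDenseLayer = model.layers[i];
    break;
  }
}
if (secondLastDenseLayer == null) {
  throw new Error('No hidden dense layer found.');
}

// A functional model reusing the same weights, with two outputs:
// the softmax scores and the hidden activation (the "embedding").
const modelWithEmbedding = tf.model({
  inputs: model.inputs,
  outputs:
      [model.outputs[0], secondLastDenseLayer.output as tf.SymbolicTensor]
});

const [scores, embedding] =
    modelWithEmbedding.predict(tf.zeros([1, 43, 232, 1])) as tf.Tensor[];
console.log(scores.shape);     // [1, 17]
console.log(embedding.shape);  // [1, 4]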

speech-commands/src/browser_fft_recognizer_test.ts

Lines changed: 91 additions & 1 deletion
@@ -218,6 +218,25 @@ describeWithFlags('Browser FFT recognizer', tf.test_util.NODE_ENVS, () => {
    }
  });

+  it('Offline recognize call: includeEmbedding', async () => {
+    setUpFakes();
+
+    // A batch of examples.
+    const numExamples = 3;
+    const spectrogram =
+        tf.zeros([numExamples, fakeNumFrames, fakeColumnTruncateLength, 1]);
+    const recognizer = new BrowserFftSpeechCommandRecognizer();
+    const output =
+        await recognizer.recognize(spectrogram, {includeEmbedding: true});
+    expect(Array.isArray(output.scores)).toEqual(true);
+    expect(output.scores.length).toEqual(3);
+    for (let i = 0; i < 3; ++i) {
+      expect((output.scores[i] as Float32Array).length).toEqual(17);
+    }
+    expect(output.embedding.rank).toEqual(2);
+    expect(output.embedding.shape[0]).toEqual(numExamples);
+  });
+
  it('Offline recognize fails due to incorrect shape', async () => {
    setUpFakes();

@@ -346,12 +365,58 @@ describeWithFlags('Browser FFT recognizer', tf.test_util.NODE_ENVS, () => {
      // spectrogram is not provided by default.
      expect(result.spectrogram).toBeUndefined();

+      // Embedding should not be included by default.
+      expect(result.embedding).toBeUndefined();
+
      if (++numCallbacksCompleted >= numCallbacksToComplete) {
-        recognizer.stopStreaming().then(done);
+        await recognizer.stopStreaming();
+        done();
      }
    }, {overlapFactor: 0, invokeCallbackOnNoiseAndUnknown: true});
  });

+  it('streaming: overlapFactor = 0, includeEmbedding', async done => {
+    setUpFakes();
+    const recognizer = new BrowserFftSpeechCommandRecognizer();
+
+    const numCallbacksToComplete = 2;
+    let numCallbacksCompleted = 0;
+    const tensorCounts: number[] = [];
+    const callbackTimestamps: number[] = [];
+    recognizer.startStreaming(async (result: SpeechCommandRecognizerResult) => {
+      expect((result.scores as Float32Array).length).toEqual(fakeWords.length);
+
+      callbackTimestamps.push(tf.util.now());
+      if (callbackTimestamps.length > 1) {
+        expect(
+            callbackTimestamps[callbackTimestamps.length - 1] -
+            callbackTimestamps[callbackTimestamps.length - 2])
+            .toBeGreaterThanOrEqual(
+                recognizer.params().spectrogramDurationMillis);
+      }
+
+      tensorCounts.push(tf.memory().numTensors);
+
+      // spectrogram is not provided by default.
+      expect(result.spectrogram).toBeUndefined();
+
+      // Embedding is included because includeEmbedding is true.
+      expect(result.embedding.rank).toEqual(2);
+      expect(result.embedding.shape[0]).toEqual(1);
+      // The number of units of the hidden dense layer.
+      expect(result.embedding.shape[1]).toEqual(4);
+
+      if (++numCallbacksCompleted >= numCallbacksToComplete) {
+        await recognizer.stopStreaming();
+        done();
+      }
+    }, {
+      overlapFactor: 0,
+      invokeCallbackOnNoiseAndUnknown: true,
+      includeEmbedding: true
+    });
+  });
+
  it('streaming: overlapFactor = 0.5, includeSpectrogram', async done => {
    setUpFakes();
    const recognizer = new BrowserFftSpeechCommandRecognizer();
@@ -482,6 +547,31 @@ describeWithFlags('Browser FFT recognizer', tf.test_util.NODE_ENVS, () => {
    expect(recognizer.isStreaming()).toEqual(false);
  });

+  it('Online recognize() call succeeds', async () => {
+    setUpFakes();
+    const recognizer = new BrowserFftSpeechCommandRecognizer();
+
+    for (let i = 0; i < 2; ++i) {
+      // No-arg call: online recognition.
+      const output = await recognizer.recognize();
+      expect(output.scores.length).toEqual(fakeWords.length);
+      expect(output.embedding).toBeUndefined();
+    }
+  });
+
+  it('Online recognize() call with includeEmbedding succeeds', async () => {
+    setUpFakes();
+    const recognizer = new BrowserFftSpeechCommandRecognizer();
+
+    for (let i = 0; i < 2; ++i) {
+      // No-arg call: online recognition.
+      const output = await recognizer.recognize(null, {includeEmbedding: true});
+      expect(output.scores.length).toEqual(fakeWords.length);
+      expect(output.embedding.rank).toEqual(2);
+      expect(output.embedding.shape[0]).toEqual(1);
+    }
+  });
+
  it('collectTransferLearningExample default transfer model', async () => {
    setUpFakes();
    const base = new BrowserFftSpeechCommandRecognizer();