@@ -20,7 +20,7 @@ import * as tf from '@tensorflow/tfjs';
 // tslint:disable:max-line-length
 import {BrowserFftFeatureExtractor, SpectrogramCallback} from './browser_fft_extractor';
 import {loadMetadataJson} from './browser_fft_utils';
-import {RecognizerCallback, RecognizerParams, SpectrogramData, SpeechCommandRecognizer, SpeechCommandRecognizerResult, StreamingRecognitionConfig, TransferLearnConfig, TransferSpeechCommandRecognizer} from './types';
+import {RecognizeConfig, RecognizerCallback, RecognizerParams, SpectrogramData, SpeechCommandRecognizer, SpeechCommandRecognizerResult, StreamingRecognitionConfig, TransferLearnConfig, TransferSpeechCommandRecognizer} from './types';
 import {version} from './version';
 
 // tslint:enable:max-line-length
@@ -47,6 +47,7 @@ export class BrowserFftSpeechCommandRecognizer implements
   private readonly DEFAULT_SUPPRESSION_TIME_MILLIS = 1000;
 
   model: tf.Model;
+  modelWithEmbeddingOutput: tf.Model;
   readonly vocabulary: string;
   readonly parameters: RecognizerParams;
   protected words: string[];
@@ -149,15 +150,24 @@ export class BrowserFftSpeechCommandRecognizer implements
     if (config == null) {
       config = {};
     }
-    const probabilityThreshold =
+    let probabilityThreshold =
         config.probabilityThreshold == null ? 0 : config.probabilityThreshold;
+    if (config.includeEmbedding) {
+      // Override probabilityThreshold to 0 if includeEmbedding is true.
+      probabilityThreshold = 0;
+    }
     tf.util.assert(
         probabilityThreshold >= 0 && probabilityThreshold <= 1,
         `Invalid probabilityThreshold value: ${probabilityThreshold}`);
-    const invokeCallbackOnNoiseAndUnknown =
+    let invokeCallbackOnNoiseAndUnknown =
         config.invokeCallbackOnNoiseAndUnknown == null ?
             false :
             config.invokeCallbackOnNoiseAndUnknown;
+    if (config.includeEmbedding) {
+      // Override invokeCallbackOnNoiseAndUnknown to true if includeEmbedding
+      // is true.
+      invokeCallbackOnNoiseAndUnknown = true;
+    }
 
     if (config.suppressionTimeMillis < 0) {
       throw new Error(
@@ -174,7 +184,18 @@ export class BrowserFftSpeechCommandRecognizer implements
         Math.round(this.FFT_SIZE * (1 - overlapFactor));
 
     const spectrogramCallback: SpectrogramCallback = async (x: tf.Tensor) => {
-      const y = tf.tidy(() => this.model.predict(x) as tf.Tensor);
+      await this.ensureModelWithEmbeddingOutputCreated();
+
+      let y: tf.Tensor;
+      let embedding: tf.Tensor;
+      if (config.includeEmbedding) {
+        await this.ensureModelWithEmbeddingOutputCreated();
+        [y, embedding] =
+            this.modelWithEmbeddingOutput.predict(x) as tf.Tensor[];
+      } else {
+        y = this.model.predict(x) as tf.Tensor;
+      }
+
       const scores = await y.data() as Float32Array;
       const maxIndexTensor = y.argMax(-1);
       const maxIndex = (await maxIndexTensor.data())[0];
@@ -201,7 +222,7 @@ export class BrowserFftSpeechCommandRecognizer implements
         }
       }
       if (wordDetected) {
-        callback({scores, spectrogram});
+        callback({scores, spectrogram, embedding});
       }
       // Trigger suppression only if the word is neither unknown nor
       // background noise.
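For orientation, here is a minimal usage sketch of the streaming path these hunks modify. It assumes the package's `create()` factory and the `startStreaming(callback, config)` method that this config-handling code lives in; neither entry point appears in the diff itself, so treat those names as assumptions:

```ts
import * as SpeechCommands from '@tensorflow-models/speech-commands';

async function streamWithEmbedding() {
  // Assumed entry points: create() and startStreaming() are taken from the
  // package's public API at the time of this change, not from the diff above.
  const recognizer = SpeechCommands.create('BROWSER_FFT');
  await recognizer.ensureModelLoaded();

  recognizer.startStreaming(async result => {
    // With includeEmbedding: true, the result carries the activation of the
    // model's second-last dense layer alongside the usual scores.
    console.log(result.scores);
    console.log(result.embedding);
  }, {includeEmbedding: true});
  // Per the hunks above, includeEmbedding forces probabilityThreshold to 0
  // and invokeCallbackOnNoiseAndUnknown to true, so the callback fires on
  // every recognition frame.
}
```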
@@ -289,6 +310,39 @@ export class BrowserFftSpeechCommandRecognizer implements
     this.parameters.spectrogramDurationMillis = numFrames * frameDurationMillis;
   }
 
+  /**
+   * Construct a two-output model that includes the following outputs:
+   *
+   * 1. The same softmax probability output as the original model's output
+   * 2. The embedding, i.e., activation from the second-last dense layer of
+   *    the original model.
+   */
+  protected async ensureModelWithEmbeddingOutputCreated() {
+    if (this.modelWithEmbeddingOutput != null) {
+      return;
+    }
+    await this.ensureModelLoaded();
+
+    // Find the second-last dense layer of the original model.
+    let secondLastDenseLayer: tf.layers.Layer;
+    for (let i = this.model.layers.length - 2; i >= 0; --i) {
+      if (this.model.layers[i].getClassName() === 'Dense') {
+        secondLastDenseLayer = this.model.layers[i];
+        break;
+      }
+    }
+    if (secondLastDenseLayer == null) {
+      throw new Error(
+          'Failed to find second last dense layer in the original model.');
+    }
+    this.modelWithEmbeddingOutput = tf.model({
+      inputs: this.model.inputs,
+      outputs: [
+        this.model.outputs[0], secondLastDenseLayer.output as tf.SymbolicTensor
+      ]
+    });
+  }
+
   private warmUpModel() {
     tf.tidy(() => {
       const x = tf.zeros([1].concat(this.nonBatchInputShape));
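The embedding output added here is intended for downstream use such as comparing utterances. As a minimal, self-contained sketch of one such use (plain tfjs ops, nothing specific to this package), cosine similarity between two embedding tensors of shape `[1, dims]`:

```ts
import * as tf from '@tensorflow/tfjs';

// Cosine similarity between two embeddings as returned in the `embedding`
// field. Values close to 1 mean the two utterances produce similar
// activations at the penultimate dense layer.
function cosineSimilarity(a: tf.Tensor, b: tf.Tensor): number {
  return tf.tidy(() => {
    const aNorm = a.div(a.norm());
    const bNorm = b.div(b.norm());
    return aNorm.mul(bNorm).sum().dataSync()[0];
  });
}
```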
@@ -370,15 +424,27 @@ export class BrowserFftSpeechCommandRecognizer implements
    * - If a `Float32Array`, must have a length divisible by the number
    *   of elements per spectrogram, i.e.,
    *   (# of spectrogram columns) * (# of frequency-domain points per column).
+   * @param config Optional configuration object.
    * @returns Result of the recognition, with the following field:
    *   scores:
    *   - A `Float32Array` if there is only one input example.
    *   - An `Array` of `Float32Array`, if there are multiple input examples.
    */
-  async recognize(input: tf.Tensor|
-                  Float32Array): Promise<SpeechCommandRecognizerResult> {
+  async recognize(input?: tf.Tensor|Float32Array, config?: RecognizeConfig):
+      Promise<SpeechCommandRecognizerResult> {
+    if (config == null) {
+      config = {};
+    }
+
     await this.ensureModelLoaded();
 
+    if (input == null) {
+      // If `input` is not provided, draw audio data from WebAudio and use it
+      // for recognition.
+      const spectrogramData = await this.recognizeOnline();
+      input = spectrogramData.data;
+    }
+
     let numExamples: number;
     let inputTensor: tf.Tensor;
     let outTensor: tf.Tensor;
@@ -403,16 +469,49 @@ export class BrowserFftSpeechCommandRecognizer implements
         ].concat(this.nonBatchInputShape) as [number, number, number, number]);
     }
 
-    outTensor = this.model.predict(inputTensor) as tf.Tensor;
+    const output: SpeechCommandRecognizerResult = {scores: null};
+    if (config.includeEmbedding) {
+      // Optional inclusion of embedding (internal activation).
+      await this.ensureModelWithEmbeddingOutputCreated();
+      const outAndEmbedding =
+          this.modelWithEmbeddingOutput.predict(inputTensor) as tf.Tensor[];
+      outTensor = outAndEmbedding[0];
+      output.embedding = outAndEmbedding[1];
+    } else {
+      outTensor = this.model.predict(inputTensor) as tf.Tensor;
+    }
+
     if (numExamples === 1) {
-      return {scores: await outTensor.data() as Float32Array};
+      output.scores = await outTensor.data() as Float32Array;
     } else {
       const unstacked = tf.unstack(outTensor) as tf.Tensor[];
       const scorePromises = unstacked.map(item => item.data());
-      const scores = await Promise.all(scorePromises) as Float32Array[];
+      output.scores = await Promise.all(scorePromises) as Float32Array[];
       tf.dispose(unstacked);
-      return {scores};
     }
+    return output;
+  }
+
+  protected async recognizeOnline(): Promise<SpectrogramData> {
+    return new Promise<SpectrogramData>((resolve, reject) => {
+      const spectrogramCallback: SpectrogramCallback = async (x: tf.Tensor) => {
+        resolve({
+          data: await x.data() as Float32Array,
+          frameSize: this.nonBatchInputShape[1],
+        });
+        return false;
+      };
+      this.audioDataExtractor = new BrowserFftFeatureExtractor({
+        sampleRateHz: this.parameters.sampleRateHz,
+        columnBufferLength: this.parameters.columnBufferLength,
+        columnHopLength: this.parameters.columnBufferLength,
+        numFramesPerSpectrogram: this.nonBatchInputShape[0],
+        columnTruncateLength: this.nonBatchInputShape[1],
+        suppressionTimeMillis: 0,
+        spectrogramCallback
+      });
+      this.audioDataExtractor.start();
+    });
   }
 
   createTransfer(name: string): TransferSpeechCommandRecognizer {
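Finally, a sketch of the reworked `recognize()` under both of its modes. The `RecognizeConfig` shape is inferred from this diff (`includeEmbedding` is the only field exercised), and the `recognizer` instance and offline spectrogram are assumed inputs:

```ts
import {SpeechCommandRecognizer} from './types';

async function demoRecognize(
    recognizer: SpeechCommandRecognizer, spectrogram: Float32Array) {
  // Offline mode: pass a spectrogram whose length is a multiple of
  // (# of spectrogram columns) * (# of frequency points per column),
  // and also request the embedding.
  const offline =
      await recognizer.recognize(spectrogram, {includeEmbedding: true});
  console.log(offline.scores);     // Float32Array of class probabilities.
  console.log(offline.embedding);  // Tensor from the penultimate dense layer.

  // Online mode: omit `input`. recognize() now captures a single spectrogram
  // from WebAudio via recognizeOnline() and runs inference on it.
  const online = await recognizer.recognize();
  console.log(online.scores);
}
```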