20
20
*/
21
21
22
22
import * as tf from '@tensorflow/tfjs' ;
23
- import { getAudioContextConstructor , getAudioMediaStream , normalize } from './browser_fft_utils' ;
23
+ import { getAudioContextConstructor , getAudioMediaStream } from './browser_fft_utils' ;
24
24
import { FeatureExtractor , RecognizerParams } from './types' ;
25
25
26
26
export type SpectrogramCallback = ( x : tf . Tensor ) => Promise < boolean > ;
@@ -77,7 +77,7 @@ export interface BrowserFftFeatureExtractorConfig extends RecognizerParams {
77
77
*/
78
78
export class BrowserFftFeatureExtractor implements FeatureExtractor {
79
79
// Number of frames (i.e., columns) per spectrogram used for classification.
80
- readonly numFramesPerSpectrogram : number ;
80
+ readonly numFrames : number ;
81
81
82
82
// Audio sampling rate in Hz.
83
83
readonly sampleRateHz : number ;
@@ -92,22 +92,16 @@ export class BrowserFftFeatureExtractor implements FeatureExtractor {
92
92
// consecutive spectrograms and the length of each individual spectrogram.
93
93
readonly overlapFactor : number ;
94
94
95
- protected readonly spectrogramCallback : SpectrogramCallback ;
95
+ private readonly spectrogramCallback : SpectrogramCallback ;
96
96
97
97
private stream : MediaStream ;
98
98
// tslint:disable-next-line:no-any
99
99
private audioContextConstructor : any ;
100
100
private audioContext : AudioContext ;
101
101
private analyser : AnalyserNode ;
102
-
103
102
private tracker : Tracker ;
104
-
105
- private readonly ROTATING_BUFFER_SIZE_MULTIPLIER = 2 ;
106
103
private freqData : Float32Array ;
107
- private rotatingBufferNumFrames : number ;
108
- private rotatingBuffer : Float32Array ;
109
-
110
- private frameCount : number ;
104
+ private freqDataQueue : Float32Array [ ] ;
111
105
// tslint:disable-next-line:no-any
112
106
private frameIntervalTask : any ;
113
107
private frameDurationMillis : number ;
@@ -144,7 +138,7 @@ export class BrowserFftFeatureExtractor implements FeatureExtractor {
144
138
this . suppressionTimeMillis = config . suppressionTimeMillis ;
145
139
146
140
this . spectrogramCallback = config . spectrogramCallback ;
147
- this . numFramesPerSpectrogram = config . numFramesPerSpectrogram ;
141
+ this . numFrames = config . numFramesPerSpectrogram ;
148
142
this . sampleRateHz = config . sampleRateHz || 44100 ;
149
143
this . fftSize = config . fftSize || 1024 ;
150
144
this . frameDurationMillis = this . fftSize / this . sampleRateHz * 1e3 ;
@@ -165,7 +159,7 @@ export class BrowserFftFeatureExtractor implements FeatureExtractor {
165
159
this . audioContextConstructor = getAudioContextConstructor ( ) ;
166
160
}
167
161
168
- async start ( samples ?: Float32Array ) : Promise < Float32Array [ ] | void > {
162
+ async start ( ) : Promise < Float32Array [ ] | void > {
169
163
if ( this . frameIntervalTask != null ) {
170
164
throw new Error (
171
165
'Cannot start already-started BrowserFftFeatureExtractor' ) ;
@@ -184,18 +178,11 @@ export class BrowserFftFeatureExtractor implements FeatureExtractor {
184
178
this . analyser . fftSize = this . fftSize * 2 ;
185
179
this . analyser . smoothingTimeConstant = 0.0 ;
186
180
streamSource . connect ( this . analyser ) ;
187
-
181
+ // Reset the queue.
182
+ this . freqDataQueue = [ ] ;
188
183
this . freqData = new Float32Array ( this . fftSize ) ;
189
- this . rotatingBufferNumFrames =
190
- this . numFramesPerSpectrogram * this . ROTATING_BUFFER_SIZE_MULTIPLIER ;
191
- const rotatingBufferSize =
192
- this . columnTruncateLength * this . rotatingBufferNumFrames ;
193
- this . rotatingBuffer = new Float32Array ( rotatingBufferSize ) ;
194
-
195
- this . frameCount = 0 ;
196
-
197
- const period = Math . max (
198
- 1 , Math . round ( this . numFramesPerSpectrogram * ( 1 - this . overlapFactor ) ) ) ;
184
+ const period =
185
+ Math . max ( 1 , Math . round ( this . numFrames * ( 1 - this . overlapFactor ) ) ) ;
199
186
this . tracker = new Tracker (
200
187
period ,
201
188
Math . round ( this . suppressionTimeMillis / this . frameDurationMillis ) ) ;
@@ -209,20 +196,16 @@ export class BrowserFftFeatureExtractor implements FeatureExtractor {
209
196
return ;
210
197
}
211
198
212
- const freqDataSlice = this . freqData . slice ( 0 , this . columnTruncateLength ) ;
213
- const bufferPos = this . frameCount % this . rotatingBufferNumFrames ;
214
- this . rotatingBuffer . set (
215
- freqDataSlice , bufferPos * this . columnTruncateLength ) ;
216
- this . frameCount ++ ;
217
-
199
+ this . freqDataQueue . push ( this . freqData . slice ( 0 , this . columnTruncateLength ) ) ;
200
+ if ( this . freqDataQueue . length > this . numFrames ) {
201
+ // Drop the oldest frame (least recent).
202
+ this . freqDataQueue . shift ( ) ;
203
+ }
218
204
const shouldFire = this . tracker . tick ( ) ;
219
205
if ( shouldFire ) {
220
- const freqData = getFrequencyDataFromRotatingBuffer (
221
- this . rotatingBuffer , this . numFramesPerSpectrogram ,
222
- this . columnTruncateLength ,
223
- this . frameCount - this . numFramesPerSpectrogram ) ;
206
+ const freqData = flattenQueue ( this . freqDataQueue ) ;
224
207
const inputTensor = getInputTensorFromFrequencyData (
225
- freqData , this . numFramesPerSpectrogram , this . columnTruncateLength ) ;
208
+ freqData , [ 1 , this . numFrames , this . columnTruncateLength , 1 ] ) ;
226
209
const shouldRest = await this . spectrogramCallback ( inputTensor ) ;
227
210
if ( shouldRest ) {
228
211
this . tracker . suppress ( ) ;
@@ -240,6 +223,9 @@ export class BrowserFftFeatureExtractor implements FeatureExtractor {
240
223
this . frameIntervalTask = null ;
241
224
this . analyser . disconnect ( ) ;
242
225
this . audioContext . close ( ) ;
226
+ if ( this . stream != null && this . stream . getTracks ( ) . length > 0 ) {
227
+ this . stream . getTracks ( ) [ 0 ] . stop ( ) ;
228
+ }
243
229
}
244
230
245
231
setConfig ( params : RecognizerParams ) {
@@ -255,39 +241,19 @@ export class BrowserFftFeatureExtractor implements FeatureExtractor {
255
241
}
256
242
}
257
243
258
- export function getFrequencyDataFromRotatingBuffer (
259
- rotatingBuffer : Float32Array , numFrames : number , fftLength : number ,
260
- frameCount : number ) : Float32Array {
261
- const size = numFrames * fftLength ;
262
- const freqData = new Float32Array ( size ) ;
263
-
264
- const rotatingBufferSize = rotatingBuffer . length ;
265
- const rotatingBufferNumFrames = rotatingBufferSize / fftLength ;
266
- while ( frameCount < 0 ) {
267
- frameCount += rotatingBufferNumFrames ;
268
- }
269
- const indexBegin = ( frameCount % rotatingBufferNumFrames ) * fftLength ;
270
- const indexEnd = indexBegin + size ;
271
-
272
- for ( let i = indexBegin ; i < indexEnd ; ++ i ) {
273
- freqData [ i - indexBegin ] = rotatingBuffer [ i % rotatingBufferSize ] ;
274
- }
244
+ export function flattenQueue ( queue : Float32Array [ ] ) : Float32Array {
245
+ const frameSize = queue [ 0 ] . length ;
246
+ const freqData = new Float32Array ( queue . length * frameSize ) ;
247
+ queue . forEach ( ( data , i ) => freqData . set ( data , i * frameSize ) ) ;
275
248
return freqData ;
276
249
}
277
250
278
251
export function getInputTensorFromFrequencyData (
279
- freqData : Float32Array , numFrames : number , fftLength : number ,
280
- toNormalize = true ) : tf . Tensor {
281
- return tf . tidy ( ( ) => {
282
- const size = freqData . length ;
283
- const tensorBuffer = tf . buffer ( [ size ] ) ;
284
- for ( let i = 0 ; i < freqData . length ; ++ i ) {
285
- tensorBuffer . set ( freqData [ i ] , i ) ;
286
- }
287
- const output =
288
- tensorBuffer . toTensor ( ) . reshape ( [ 1 , numFrames , fftLength , 1 ] ) ;
289
- return toNormalize ? normalize ( output ) : output ;
290
- } ) ;
252
+ freqData : Float32Array , shape : number [ ] ) : tf . Tensor {
253
+ const vals = new Float32Array ( tf . util . sizeFromShape ( shape ) ) ;
254
+ // If the data is less than the output shape, the rest is padded with zeros.
255
+ vals . set ( freqData , vals . length - freqData . length ) ;
256
+ return tf . tensor ( vals , shape ) ;
291
257
}
292
258
293
259
/**
0 commit comments