from einops import rearrange
from torch import Tensor, nn

-from .utils import make_non_pad_mask, mask_to_bias, onnx2torch
+from .utils import make_non_pad_mask, mask_to_bias, onnx2torch, merge_tokenized_segments


@dataclass
@@ -236,7 +236,7 @@ def preprocess(self, x: Tensor) -> Tensor:

    @torch.inference_mode()
    def quantize(self, x: Tensor) -> Tensor:
-        embed = self.embed.t()
+        embed = self.embed.t().to(x.dtype)
        dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed +
                 embed.pow(2).sum(0, keepdim=True))
        embed_ind = dist.max(dim=-1).indices
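
For context on the .to(x.dtype) change: under mixed precision the fp32 codebook would otherwise collide with a half-precision input in the matmul. A minimal, self-contained sketch of the same nearest-neighbour lookup; nearest_code is an illustrative name, not part of this codebase:

import torch

def nearest_code(x: torch.Tensor, embed: torch.Tensor) -> torch.Tensor:
    # Cast the codebook to the query dtype so bf16/fp16 inputs do not
    # trip a dtype mismatch in the matmul below.
    embed = embed.to(x.dtype)
    # ||x - e||^2 = ||x||^2 - 2 x.e + ||e||^2; maximising the negated
    # distance selects the closest codebook entry for each row of x.
    dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed +
             embed.pow(2).sum(0, keepdim=True))
    return dist.max(dim=-1).indices

x = torch.randn(4, 8, dtype=torch.bfloat16)  # half-precision queries
codebook = torch.randn(8, 16)                # fp32 table, (dim, codebook_size)
print(nearest_code(x, codebook))             # 4 code indices
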
@@ -287,7 +287,7 @@ def codebook(self):

    @torch.inference_mode()
    def encode(self, x: Tensor) -> Tensor:
-        x = F.normalize(x, p=2, dim=-1)
+        x = F.normalize(x.float(), p=2, dim=-1)
        embed_in = self._codebook.encode(x)
        return embed_in

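
The companion change upcasts before normalising: the squared-sum inside an L2 norm is prone to overflow and precision loss in half precision, and the upcast also keeps the subsequent codebook lookup in fp32. A small illustration of the pattern (shapes and values are made up):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, dtype=torch.float16)
# Upcast first: for large activations the sum of squares inside the norm
# can exceed fp16 range, so the normalisation runs in float32.
x_unit = F.normalize(x.float(), p=2, dim=-1)
print(x_unit.norm(dim=-1))  # ~1.0 for every row
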
@@ -306,6 +306,7 @@ class S3Tokenizer(nn.Module):

    def __init__(self, name: str, config: ModelConfig = ModelConfig()):
        super().__init__()
+        self.name = name  # Store model name for token_rate determination
        self.config = config
        self.encoder = AudioEncoder(
            self.config.n_mels,
@@ -324,9 +325,209 @@ def forward(self, mel: Tensor, mel_len: Tensor) -> Tuple[Tensor, Tensor]:

    @torch.inference_mode()
    def quantize(self, mel: Tensor, mel_len: Tensor) -> Tuple[Tensor, Tensor]:
-        hidden, code_len = self.encoder(mel, mel_len)
-        code = self.quantizer.encode(hidden)
-        return code, code_len
+        """
+        Quantize mel spectrogram to tokens, with automatic long-audio handling.
+
+        Args:
+            mel: mel spectrogram tensor, shape (batch_size, n_mels, T)
+            mel_len: mel length tensor, shape (batch_size,)
+
+        Returns:
+            code: quantized tokens, shape (batch_size, T')
+            code_len: token length, shape (batch_size,)
+        """
+        # Check whether any audio in the batch exceeds 30 seconds.
+        # At 16 kHz with hop_length=160, 30 s = 30 * 16000 / 160 = 3000 frames.
+        max_frames = 3000
+
+        # Mask of samples that are long audio
+        long_audio_mask = mel_len > max_frames
+
+        if long_audio_mask.any():
+            # Batch contains long audio - route through windowed processing
+            return self._quantize_mixed_batch(mel, mel_len, long_audio_mask,
+                                              max_frames)
+        else:
+            # All short audio - use the original single-pass path
+            hidden, code_len = self.encoder(mel, mel_len)
+            code = self.quantizer.encode(hidden)
+            return code, code_len
+
+    @torch.inference_mode()
+    def _quantize_mixed_batch(self, mel: Tensor, mel_len: Tensor,
+                              long_audio_mask: Tensor,
+                              max_frames: int) -> Tuple[Tensor, Tensor]:
+        """
+        Handle a mixed batch of short and long audio in one unified encoder pass.
+
+        Args:
+            mel: mel spectrogram tensor, shape (batch_size, n_mels, T)
+            mel_len: mel length tensor, shape (batch_size,)
+            long_audio_mask: boolean mask for long audio, shape (batch_size,)
+            max_frames: maximum frames for short audio
+
+        Returns:
+            code: quantized tokens, shape (batch_size, T')
+            code_len: token length, shape (batch_size,)
+        """
+        batch_size = mel.size(0)
+
+        # Sliding-window parameters
+        sample_rate = 16000
+        hop_length = 160  # default hop length for the mel spectrogram
+        window_size = 30  # seconds
+        overlap = 4  # seconds
+
+        # Frame-based window geometry
+        frames_per_window = window_size * sample_rate // hop_length  # 3000 frames
+        frames_per_overlap = overlap * sample_rate // hop_length  # 400 frames
+        frames_per_stride = frames_per_window - frames_per_overlap  # 2600 frames
+
+        # Collect all segments to process (short audio and long-audio windows)
+        all_segments = []
+        all_segments_len = []
+        # Records which sample each segment belongs to and whether it is long audio
+        segment_info = []
+
+        # Slice every sample in the batch into segments
+        for batch_idx in range(batch_size):
+            audio_mel = mel[batch_idx]
+            audio_mel_len = mel_len[batch_idx]
+            is_long_audio = long_audio_mask[batch_idx].item()
+
+            if not is_long_audio:
+                # Short audio: process directly as a single segment
+                segment = audio_mel[:, :audio_mel_len]
+                seg_len = audio_mel_len.item()
+
+                # Pad to the window size if necessary
+                if seg_len < frames_per_window:
+                    pad_size = frames_per_window - seg_len
+                    segment = F.pad(segment, (0, pad_size))
+
+                all_segments.append(segment)
+                all_segments_len.append(
+                    torch.tensor(seg_len, device=mel.device))
+                segment_info.append({
+                    'batch_idx': batch_idx,
+                    'is_long_audio': False,
+                    'segment_idx': 0,
+                    'total_segments': 1
+                })
+            else:
+                # Long audio: split into overlapping windows
+                start = 0
+                segment_idx = 0
+                while start < audio_mel_len:
+                    end = min(start + frames_per_window, audio_mel_len)
+                    segment = audio_mel[:, start:end]
+
+                    seg_len = segment.size(1)
+                    # Pad if necessary
+                    if seg_len < frames_per_window:
+                        pad_size = frames_per_window - seg_len
+                        segment = F.pad(segment, (0, pad_size))
+
+                    all_segments.append(segment)
+                    all_segments_len.append(
+                        torch.tensor(seg_len, device=mel.device))
+                    segment_info.append({
+                        'batch_idx': batch_idx,
+                        'is_long_audio': True,
+                        'segment_idx': segment_idx,
+                        'total_segments': None  # filled in below
+                    })
+
+                    segment_idx += 1
+                    start += frames_per_stride
+
+                # Back-fill the total segment count for this sample
+                total_segments = segment_idx
+                for info in segment_info:
+                    if info['batch_idx'] == batch_idx and info['is_long_audio']:
+                        info['total_segments'] = total_segments
+
+        if not all_segments:
+            # Fallback: nothing to process
+            return torch.zeros(batch_size,
+                               0,
+                               dtype=torch.long,
+                               device=mel.device), torch.zeros(
+                                   batch_size,
+                                   dtype=torch.long,
+                                   device=mel.device)
+
+        # Encode and quantize all segments in a single batched pass
+        unified_batch_mel = torch.stack(all_segments)
+        unified_batch_lens = torch.stack(all_segments_len)
+
+        hidden, code_len = self.encoder(unified_batch_mel, unified_batch_lens)
+        codes = self.quantizer.encode(hidden)
+
+        # Regroup the per-segment codes by originating sample
+        results = {}  # batch_idx -> (code_tensor, code_len)
+
+        for seg_idx, info in enumerate(segment_info):
+            batch_idx = info['batch_idx']
+            is_long_audio = info['is_long_audio']
+            segment_idx = info['segment_idx']
+
+            # Codes of the current segment, trimmed to the valid length
+            segment_code = codes[
+                seg_idx, :code_len[seg_idx].item()].cpu().numpy().tolist()
+
+            if not is_long_audio:
+                # Short audio: use directly
+                code_tensor = torch.tensor(segment_code,
+                                           dtype=torch.long,
+                                           device=mel.device)
+                results[batch_idx] = (code_tensor, len(segment_code))
+            else:
+                # Long audio: collect all segments for merging
+                if batch_idx not in results:
+                    results[batch_idx] = []
+                results[batch_idx].append(segment_code)
+
+        # Merge the windowed token sequences of each long-audio sample
+        for batch_idx in range(batch_size):
+            if long_audio_mask[batch_idx].item():
+                audio_codes = results[batch_idx]
+
+                # The token rate depends on the model variant
+                if hasattr(self,
+                           'name') and self.name == "speech_tokenizer_v1":
+                    token_rate = 50
+                else:
+                    token_rate = 25
+
+                merged_codes = merge_tokenized_segments(audio_codes,
+                                                        overlap=overlap,
+                                                        token_rate=token_rate)
+
+                # Convert back to a tensor
+                merged_codes_tensor = torch.tensor(merged_codes,
+                                                   dtype=torch.long,
+                                                   device=mel.device)
+                results[batch_idx] = (merged_codes_tensor, len(merged_codes))
+
+        # Assemble the padded batch output
+        max_code_len = max(code_info[1] for code_info in results.values())
+
+        output_codes = torch.zeros(batch_size,
+                                   max_code_len,
+                                   dtype=torch.long,
+                                   device=mel.device)
+        output_codes_len = torch.zeros(batch_size,
+                                       dtype=torch.long,
+                                       device=mel.device)
+
+        for batch_idx, (code_tensor, code_len) in results.items():
+            output_codes[batch_idx, :code_len] = code_tensor
+            output_codes_len[batch_idx] = code_len
+
+        return output_codes, output_codes_len

    @property
    def device(self):
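
The diff imports merge_tokenized_segments from .utils, but its implementation is not part of this hunk. A minimal sketch of what the merge plausibly does, assuming each window after the first simply drops the tokens that re-encode the overlapped audio (overlap seconds times token_rate tokens); the function name here is a stand-in, not the actual utility:

from typing import List

def merge_tokenized_segments_sketch(segments: List[List[int]],
                                    overlap: int,
                                    token_rate: int) -> List[int]:
    # 4 s overlap at 25 tokens/s (50 for speech_tokenizer_v1) = 100 tokens
    overlap_tokens = overlap * token_rate
    merged: List[int] = list(segments[0])
    for seg in segments[1:]:
        # Drop the leading tokens covering audio already emitted by the
        # previous window, then append the remainder.
        merged.extend(seg[overlap_tokens:])
    return merged

# Worked example of the window geometry above: 70 s of audio = 7000 frames,
# giving windows that start at frames 0, 2600 and 5200 (stride 2600). At
# 25 tokens/s those windows yield 750 + 750 + 450 tokens; dropping 100
# overlap tokens from the second and third windows leaves
# 750 + 650 + 350 = 1750 tokens, i.e. exactly 70 s * 25 tokens/s.

Call sites need no changes: quantize routes any sample longer than 3000 mel frames through the windowed path and pads the merged token sequences back into a regular batch, so existing code that calls quantize(mel, mel_len) keeps working for inputs of any length.
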