
Commit f585f9e

Add type hints and docstrings across modules
1 parent 28af5fa commit f585f9e

9 files changed (+839, -260 lines)

stellascript/audio/capture.py

Lines changed: 67 additions & 24 deletions
@@ -1,29 +1,75 @@
 # stellascript/audio/capture.py
 
-import pyaudio
+"""
+Handles audio capture from the microphone using PyAudio.
+"""
+
+import threading
 from contextlib import contextmanager
+from typing import Callable, Generator, Optional
+
+import pyaudio
+
 from ..logging_config import get_logger
 
 logger = get_logger(__name__)
 
+
 class AudioCapture:
-    def __init__(self, format, channels, rate, chunk):
-        self.format_str = format
-        self.format = self._get_pyaudio_format(format)
-        self.channels = channels
-        self.rate = rate
-        self.chunk = chunk
-        self.pyaudio_instance = None
-        self.stream = None
-
-    def _get_pyaudio_format(self, format_str):
+    """
+    A class to manage audio recording from the microphone.
+
+    This class provides a context manager to handle the lifecycle of a PyAudio
+    stream, ensuring that resources are properly opened and closed.
+    """
+
+    def __init__(self, format: str, channels: int, rate: int, chunk: int) -> None:
+        """
+        Initializes the AudioCapture instance.
+
+        Args:
+            format (str): The audio format string (e.g., "paFloat32").
+            channels (int): The number of audio channels.
+            rate (int): The sampling rate in Hz.
+            chunk (int): The number of frames per buffer.
+        """
+        self.format_str: str = format
+        self.format: int = self._get_pyaudio_format(format)
+        self.channels: int = channels
+        self.rate: int = rate
+        self.chunk: int = chunk
+        self.pyaudio_instance: Optional[pyaudio.PyAudio] = None
+        self.stream: Optional[pyaudio.Stream] = None
+
+    def _get_pyaudio_format(self, format_str: str) -> int:
+        """
+        Converts a format string to a PyAudio format constant.
+
+        Args:
+            format_str (str): The string representation of the format.
+
+        Returns:
+            int: The corresponding PyAudio format constant.
+
+        Raises:
+            ValueError: If the format string is not supported.
+        """
         if format_str == "paFloat32":
             return pyaudio.paFloat32
         # Add other formats if needed
         raise ValueError(f"Unsupported audio format: {format_str}")
 
     @contextmanager
-    def audio_stream(self, callback):
+    def audio_stream(self, callback: Callable) -> Generator[Optional[pyaudio.Stream], None, None]:
+        """
+        A context manager for opening and managing a PyAudio stream.
+
+        Args:
+            callback (Callable): The callback function to process audio chunks.
+
+        Yields:
+            Optional[pyaudio.Stream]: The PyAudio stream object.
+        """
         self.pyaudio_instance = pyaudio.PyAudio()
         try:
             self.stream = self.pyaudio_instance.open(
@@ -40,33 +86,30 @@ def audio_stream(self, callback):
             if self.stream:
                 try:
                     if self.stream.is_active():
-                        # Utiliser stop_stream avec gestion du timeout
-                        import threading
-
-                        # Pass stream object as an argument to make it explicit for Pylance
-                        def force_stop(stream_to_stop):
+                        # Use stop_stream with timeout management
+                        def force_stop(stream_to_stop: pyaudio.Stream) -> None:
                             try:
                                 if stream_to_stop:
                                     stream_to_stop.stop_stream()
                             except Exception:
                                 pass
-
-                        # Lancer l'arrêt dans un thread avec timeout
+
+                        # Run the stop in a thread with a timeout
                         stop_thread = threading.Thread(target=force_stop, args=(self.stream,), daemon=True)
                         stop_thread.start()
-                        stop_thread.join(timeout=0.2)  # Attendre max 200ms
-
-                        # Si le thread n'a pas fini, on continue quand même
+                        stop_thread.join(timeout=0.2)  # Wait max 200ms
+
+                        # If the thread is still running, continue anyway
                         if stop_thread.is_alive():
                             logger.warning("Stream stop timed out, continuing anyway")
                 except Exception:
                     pass
-
+
             try:
                 self.stream.close()
             except Exception:
                 pass
-
+
             if self.pyaudio_instance:
                 try:
                     self.pyaudio_instance.terminate()
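For orientation, here is a minimal usage sketch of the new typed API. It is not part of this commit, and it assumes the callback passed to audio_stream() is wired to PyAudio's standard stream_callback protocol and that the constants come from stellascript/config.py.

# Hypothetical usage sketch -- not part of this commit.
import time

import pyaudio

from stellascript.audio.capture import AudioCapture
from stellascript.config import CHANNELS, CHUNK, FORMAT, RATE


def on_chunk(in_data, frame_count, time_info, status):
    # Receive raw float32 bytes here (e.g. push them onto a queue for transcription).
    return (None, pyaudio.paContinue)


capture = AudioCapture(format=FORMAT, channels=CHANNELS, rate=RATE, chunk=CHUNK)
with capture.audio_stream(on_chunk) as stream:
    while stream is not None and stream.is_active():
        time.sleep(0.1)  # PyAudio invokes on_chunk on its own thread

Note that the teardown path in the diff above stops the stream from a daemon thread with a 200 ms join timeout, so exiting this with-block cannot hang on a misbehaving audio backend.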

stellascript/audio/enhancement.py

Lines changed: 43 additions & 10 deletions
@@ -1,24 +1,57 @@
 # stellascript/audio/enhancement.py
 
+"""
+Handles audio enhancement using various methods like DeepFilterNet and Demucs.
+"""
+
 import warnings
+from typing import Any, Optional
+
 import numpy as np
 import torch
 import torchaudio
+
 from ..logging_config import get_logger
 
 logger = get_logger(__name__)
 
+
 class AudioEnhancer:
-    def __init__(self, enhancement_method, device, rate):
-        self.enhancement_method = enhancement_method
-        self.device = device
-        self.rate = rate
-        self.demucs_model = None
-        self.df_model = None
-        self.df_state = None
-
-    def apply(self, audio_data, is_live=False):
-        """Apply selected audio enhancement method."""
+    """
+    A class to apply audio enhancement techniques to audio data.
+
+    This class supports multiple enhancement methods and handles the loading
+    of the necessary models.
+    """
+
+    def __init__(self, enhancement_method: str, device: torch.device, rate: int) -> None:
+        """
+        Initializes the AudioEnhancer.
+
+        Args:
+            enhancement_method (str): The enhancement method to use ('none',
+                'deepfilternet', 'demucs').
+            device (torch.device): The device to run the models on (CPU or CUDA).
+            rate (int): The sample rate of the input audio.
+        """
+        self.enhancement_method: str = enhancement_method
+        self.device: torch.device = device
+        self.rate: int = rate
+        self.demucs_model: Optional[Any] = None
+        self.df_model: Optional[Any] = None
+        self.df_state: Optional[Any] = None
+
+    def apply(self, audio_data: np.ndarray, is_live: bool = False) -> np.ndarray:
+        """
+        Apply the selected audio enhancement method.
+
+        Args:
+            audio_data (np.ndarray): The input audio data as a NumPy array.
+            is_live (bool): Flag indicating if the processing is for a live stream.
+
+        Returns:
+            np.ndarray: The enhanced audio data.
+        """
         if self.enhancement_method == "none":
             return audio_data
 
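A small usage sketch of the typed apply() signature (not part of this commit; module paths follow the file headers above). With enhancement_method "none" the diff shows the input array is returned unchanged, which the assert below relies on; 'deepfilternet' and 'demucs' additionally require their models to be available.

# Hypothetical usage sketch -- not part of this commit.
import numpy as np
import torch

from stellascript.audio.enhancement import AudioEnhancer
from stellascript.config import RATE

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
enhancer = AudioEnhancer(enhancement_method="none", device=device, rate=RATE)

buffer = np.zeros(RATE, dtype=np.float32)        # one second of silence at 16 kHz
enhanced = enhancer.apply(buffer, is_live=False)
assert enhanced is buffer                        # "none" is a pass-through per the diff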

stellascript/cli.py

Lines changed: 31 additions & 4 deletions
@@ -7,8 +7,18 @@
 
 logger = get_logger(__name__)
 
-def parse_args():
-    """Parses command-line arguments."""
+def parse_args() -> argparse.Namespace:
+    """
+    Parses command-line arguments for the Stellascript application.
+
+    This function sets up an ArgumentParser to handle various command-line options
+    for transcription, including language, model selection, input file,
+    diarization, and audio enhancement. It also includes argument validation
+    to ensure compatibility between different options.
+
+    Returns:
+        argparse.Namespace: An object containing the parsed command-line arguments.
+    """
     parser = argparse.ArgumentParser(
         description="Transcribe audio live from microphone or from a file."
     )
@@ -92,8 +102,25 @@ def parse_args():
     validate_args(args, parser)
     return args
 
-def validate_args(args, parser):
-    """Validates parsed arguments."""
+
+def validate_args(args: argparse.Namespace, parser: argparse.ArgumentParser) -> None:
+    """
+    Validates the parsed command-line arguments to ensure they are consistent.
+
+    This function checks for various invalid combinations of arguments, such as:
+    - Using speaker count constraints in live mode.
+    - Incompatible diarization and transcription modes.
+    - Misuse of the similarity threshold with certain diarization methods.
+    - Conflicting arguments for speaker count and similarity threshold.
+
+    Args:
+        args (argparse.Namespace): The parsed command-line arguments.
+        parser (argparse.ArgumentParser): The argument parser, used to report errors.
+
+    Raises:
+        SystemExit: If an invalid combination of arguments is found, the program
+            exits with an error message.
+    """
     if (args.min_speakers is not None or args.max_speakers is not None) and not args.file:
         parser.error("--min-speakers and --max-speakers can only be used in file mode (--file).")
 
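The validation here is plain argparse: cross-argument checks are routed through parser.error(), which prints the message and exits (hence the documented SystemExit). The sketch below mirrors the one check visible in this hunk; the add_argument definitions are assumptions, since the full option list lies outside the diff.

# Illustrative sketch of the validate_args pattern; flag definitions are assumed.
import argparse

parser = argparse.ArgumentParser(
    description="Transcribe audio live from microphone or from a file."
)
parser.add_argument("--file", default=None)
parser.add_argument("--min-speakers", type=int, default=None)
parser.add_argument("--max-speakers", type=int, default=None)

args = parser.parse_args(["--min-speakers", "2"])  # live mode: no --file given

# Same cross-argument check as in the hunk above; parser.error() raises SystemExit(2).
if (args.min_speakers is not None or args.max_speakers is not None) and not args.file:
    parser.error("--min-speakers and --max-speakers can only be used in file mode (--file).")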

stellascript/config.py

Lines changed: 69 additions & 18 deletions
@@ -1,38 +1,89 @@
 # stellascript/config.py
 
+"""
+Configuration settings for the Stellascript application.
+
+This module defines various constants that control the behavior of the audio
+processing, transcription, and diarization pipeline. These settings are
+organized into sections for clarity and can be tuned to optimize performance
+for different use cases.
+
+Attributes:
+    FORMAT (str): The audio format used for recording, corresponding to PyAudio's
+        `paFloat32`.
+    CHANNELS (int): The number of audio channels (1 for mono).
+    RATE (int): The sampling rate in Hz (16000 Hz is standard for speech).
+    CHUNK (int): The number of samples per buffer, used for VAD processing.
+
+    TRANSCRIPTION_MAX_BUFFER_DURATION (float): The maximum duration of the audio
+        buffer for transcription in
+        seconds.
+    SUBTITLE_MAX_BUFFER_DURATION (float): The maximum duration of the audio
+        buffer for subtitle generation in
+        seconds.
+    VAD_SPEECH_THRESHOLD (float): The sensitivity threshold for the Voice
+        Activity Detection (VAD).
+    VAD_SILENCE_DURATION_S (float): The duration of silence in seconds that
+        triggers a segment split.
+    VAD_MIN_SPEECH_DURATION_S (float): The minimum duration of speech in seconds
+        to be considered a valid segment.
+
+    SUBTITLE_MAX_LENGTH (int): The maximum number of characters per subtitle line.
+    SUBTITLE_MAX_DURATION_S (float): The maximum duration of a single subtitle
+        line in seconds.
+    SUBTITLE_MAX_SILENCE_S (float): The maximum duration of silence to tolerate
+        before creating a new subtitle line.
+
+    MAX_MERGE_GAP_S (float): The maximum gap of silence in seconds between two
+        speech segments to be merged into one.
+
+    TARGET_CHUNK_DURATION_S (float): The target duration for audio chunks when
+        processing a file.
+    MAX_CHUNK_DURATION_S (float): The maximum allowed duration for an audio chunk.
+    MIN_SILENCE_GAP_S (float): The minimum duration of silence to be considered a
+        gap for chunking.
+
+    TRANSCRIPTION_PADDING_S (float): The duration of silence padding added to
+        audio segments before transcription.
+
+    MODELS (list[str]): A list of available Whisper models for transcription.
+"""
+
+from typing import List
+
 # Audio Configuration
-FORMAT = "paFloat32"  # Corresponds to pyaudio.paFloat32
-CHANNELS = 1
-RATE = 16000
-CHUNK = 512  # For VAD, 512 samples = 32ms at 16kHz
+FORMAT: str = "paFloat32"  # Corresponds to pyaudio.paFloat32
+CHANNELS: int = 1
+RATE: int = 16000
+CHUNK: int = 512  # For VAD, 512 samples = 32ms at 16kHz
 
 # Transcription Mode Buffering
-TRANSCRIPTION_MAX_BUFFER_DURATION = 75.0  # 1min15s
+TRANSCRIPTION_MAX_BUFFER_DURATION: float = 75.0  # 1min15s
 
 # Subtitle Mode Buffering & VAD
-SUBTITLE_MAX_BUFFER_DURATION = 15.0  # 15s for real-time response
-VAD_SPEECH_THRESHOLD = 0.4  # Lower threshold for higher sensitivity
-VAD_SILENCE_DURATION_S = 0.3  # Shorter silence duration to split segments
-VAD_MIN_SPEECH_DURATION_S = 0.2
+SUBTITLE_MAX_BUFFER_DURATION: float = 15.0  # 15s for real-time response
+VAD_SPEECH_THRESHOLD: float = 0.4  # Lower threshold for higher sensitivity
+VAD_SILENCE_DURATION_S: float = 0.3  # Shorter silence duration to split segments
+VAD_MIN_SPEECH_DURATION_S: float = 0.2
 
 # Subtitle Generation
-SUBTITLE_MAX_LENGTH = 80  # Max characters per subtitle line
-SUBTITLE_MAX_DURATION_S = 15.0  # Max duration of a single subtitle line
-SUBTITLE_MAX_SILENCE_S = 0.5  # Max silence to tolerate before creating a new line
+SUBTITLE_MAX_LENGTH: int = 80  # Max characters per subtitle line
+SUBTITLE_MAX_DURATION_S: float = 15.0  # Max duration of a single subtitle line
+SUBTITLE_MAX_SILENCE_S: float = 0.5  # Max silence to tolerate before creating a new line
 
 # Speaker Diarization
-MAX_MERGE_GAP_S = 5.0  # Max silence between segments to merge
+MAX_MERGE_GAP_S: float = 5.0  # Max silence between segments to merge
 
 # File Transcription Chunking
-TARGET_CHUNK_DURATION_S = 90.0
-MAX_CHUNK_DURATION_S = 120.0
-MIN_SILENCE_GAP_S = 0.5
+TARGET_CHUNK_DURATION_S: float = 90.0
+MAX_CHUNK_DURATION_S: float = 120.0
+MIN_SILENCE_GAP_S: float = 0.5
 
 # Transcription Padding
-TRANSCRIPTION_PADDING_S = 1.5  # 1.5s of silence padding
+TRANSCRIPTION_PADDING_S: float = 1.5  # 1.5s of silence padding
 
 # List of available Whisper models
-MODELS = [
+MODELS: List[str] = [
     "tiny.en", "tiny", "base.en", "base", "small.en", "small",
     "medium.en", "medium", "large-v1", "large-v2", "large-v3", "large",
     "distil-large-v2", "distil-medium.en", "distil-small.en"
