Skip to content

Commit 1c8166d

Browse files
committed
refactor: estimate model insequence since file lock takes more time
1 parent 1a5673b commit 1c8166d

File tree

7 files changed

+53
-41
lines changed

7 files changed

+53
-41
lines changed

vox_box/downloader/downloaders.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,28 +20,48 @@
2020
logger = logging.getLogger(__name__)
2121

2222

23-
def download_model(
23+
def download_file(
2424
huggingface_repo_id: Optional[str] = None,
2525
huggingface_filename: Optional[str] = None,
2626
model_scope_model_id: Optional[str] = None,
2727
model_scope_file_path: Optional[str] = None,
2828
cache_dir: Optional[str] = None,
2929
huggingface_token: Optional[str] = None,
3030
) -> str:
31+
result_path = None
32+
key = None
33+
3134
if huggingface_repo_id is not None:
32-
return HfDownloader.download(
35+
key = (
36+
f"huggingface:{huggingface_repo_id}"
37+
if huggingface_filename is None
38+
else f"huggingface:{huggingface_repo_id}:{huggingface_filename}"
39+
)
40+
logger.debug(f"Downloading {key}")
41+
42+
result_path = HfDownloader.download(
3343
repo_id=huggingface_repo_id,
3444
filename=huggingface_filename,
3545
token=huggingface_token,
3646
cache_dir=os.path.join(cache_dir, "huggingface"),
3747
)
3848
elif model_scope_model_id is not None:
39-
return ModelScopeDownloader.download(
49+
key = (
50+
f"modelscope:{model_scope_model_id}"
51+
if model_scope_file_path is None
52+
else f"modelscope:{model_scope_model_id}:{model_scope_file_path}"
53+
)
54+
logger.debug(f"Downloading {key}")
55+
56+
result_path = ModelScopeDownloader.download(
4057
model_id=model_scope_model_id,
4158
file_path=model_scope_file_path,
4259
cache_dir=os.path.join(cache_dir, "model_scope"),
4360
)
4461

62+
logger.debug(f"Downloaded {key}")
63+
return result_path
64+
4565

4666
def get_file_size(
4767
huggingface_repo_id: Optional[str] = None,
@@ -150,8 +170,6 @@ def download_file(
150170
if len(matching_files) == 0:
151171
raise ValueError(f"No file found in {repo_id} that match {filename}")
152172

153-
logger.info(f"Downloading model {repo_id}/{filename}")
154-
155173
subfolder, first_filename = (
156174
str(Path(matching_files[0]).parent),
157175
Path(matching_files[0]).name,
@@ -193,7 +211,6 @@ def _inner_hf_hub_download(repo_file: str):
193211
else:
194212
model_path = os.path.join(local_dir, first_filename)
195213

196-
logger.info(f"Downloaded model {repo_id}/{filename}")
197214
return model_path
198215

199216
def __call__(self):
@@ -242,7 +259,6 @@ def download(
242259
name = name.replace(".", "___")
243260
lock_filename = os.path.join(cache_dir, group_or_owner, f"{name}.lock")
244261

245-
logger.info("Retriving file lock")
246262
with FileLock(lock_filename):
247263
if file_path is not None:
248264
matching_files = match_model_scope_file_paths(model_id, file_path)

vox_box/estimator/bark.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
from typing import Dict
55
from vox_box.config.config import BackendEnum, Config, TaskTypeEnum
6-
from vox_box.downloader.downloaders import download_model
6+
from vox_box.downloader.downloaders import download_file
77
from vox_box.estimator.base import Estimator
88
from vox_box.utils.model import create_model_dict
99

@@ -18,10 +18,8 @@ def __init__(
1818
self._cfg = cfg
1919
self._required_files = [
2020
"config.json",
21-
"speaker_embeddings_path.json",
2221
]
2322
self._config_json = None
24-
self._speaker_json = None
2523

2624
def model_info(self) -> Dict:
2725
model = (
@@ -61,17 +59,13 @@ def _check_local_model(self, base_dir: str) -> bool:
6159
if architectures is not None and "BarkModel" in architectures:
6260
supported = True
6361

64-
speaker_path = os.path.join(base_dir, "speaker_embeddings_path.json")
65-
with open(speaker_path, "r", encoding="utf-8") as f:
66-
self._speaker_json = json.load(f)
67-
6862
return supported
6963

7064
def _check_remote_model(self) -> bool:
7165
downloaded_files = []
7266
for f in self._required_files:
7367
try:
74-
downloaded_file_path = download_model(
68+
downloaded_file_path = download_file(
7569
huggingface_repo_id=self._cfg.huggingface_repo_id,
7670
huggingface_filename=f,
7771
model_scope_model_id=self._cfg.model_scope_model_id,
@@ -80,7 +74,7 @@ def _check_remote_model(self) -> bool:
8074
)
8175
downloaded_files.append(downloaded_file_path)
8276
except Exception as e:
83-
logger.error(f"Failed to download {f}, {e}")
77+
logger.debug(f"File {f} does not exist, {e}")
8478
continue
8579

8680
if len(downloaded_files) != 0:

vox_box/estimator/cosyvoice.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
from typing import Dict
44

5-
from vox_box.downloader.downloaders import download_model
5+
from vox_box.downloader.downloaders import download_file
66
from vox_box.estimator.base import Estimator
77

88
from vox_box.config.config import BackendEnum, Config, TaskTypeEnum
@@ -57,15 +57,15 @@ def _check_remote_model(self) -> bool:
5757
downloaded_files = []
5858
for f in self._required_files:
5959
try:
60-
download_file_path = download_model(
60+
download_file_path = download_file(
6161
huggingface_repo_id=self._cfg.huggingface_repo_id,
6262
huggingface_filename=f,
6363
model_scope_model_id=self._cfg.model_scope_model_id,
6464
model_scope_file_path=f,
6565
cache_dir=self._cfg.cache_dir,
6666
)
6767
except Exception as e:
68-
logger.error(f"Failed to download {f}, {e}")
68+
logger.debug(f"File {f} does not exist, {e}")
6969
continue
7070
downloaded_files.append(download_file_path)
7171

vox_box/estimator/estimate.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from typing import Dict, List
23
from vox_box.config.config import Config
34
from vox_box.estimator.bark import Bark
@@ -6,14 +7,15 @@
67
from vox_box.estimator.faster_whisper import FasterWhisper
78
from vox_box.estimator.funasr import FunASR
89
from vox_box.utils.model import create_model_dict
9-
from concurrent.futures import ThreadPoolExecutor, as_completed
10+
11+
logger = logging.getLogger(__name__)
1012

1113

1214
def estimate_model(cfg: Config) -> Dict:
1315
estimators: List[Estimator] = [
16+
CosyVoice(cfg),
1417
FasterWhisper(cfg),
1518
FunASR(cfg),
16-
CosyVoice(cfg),
1719
Bark(cfg),
1820
]
1921

@@ -23,17 +25,9 @@ def estimate_model(cfg: Config) -> Dict:
2325
supported=False,
2426
)
2527

26-
def get_model_info(estimator: Estimator) -> Dict:
27-
return estimator.model_info()
28-
29-
with ThreadPoolExecutor() as executor:
30-
futures = {executor.submit(get_model_info, e): e for e in estimators}
31-
for future in as_completed(futures):
32-
result = future.result()
33-
if result["supported"]:
34-
for f in futures:
35-
if not f.done():
36-
f.cancel()
37-
return result
28+
for estimator in estimators:
29+
model_info = estimator.model_info()
30+
if model_info["supported"]:
31+
return model_info
3832

3933
return model_info

vox_box/estimator/faster_whisper.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
from typing import Dict, List
55
from vox_box.config.config import BackendEnum, Config, TaskTypeEnum
6-
from vox_box.downloader.downloaders import download_model
6+
from vox_box.downloader.downloaders import download_file
77
from vox_box.downloader.hub import match_files
88
from vox_box.estimator.base import Estimator
99
from vox_box.utils.model import create_model_dict
@@ -107,22 +107,22 @@ def _check_remote_model(self) -> bool: # noqa: C901
107107
if "model.bin" not in matching_files:
108108
return False
109109
except Exception as e:
110-
logger.error(f"Failed to download model file for estimating, {e}")
110+
logger.debug(f"File model.bin does not exist, {e}")
111111
return False
112112

113113
downloaded_files = []
114114
download_files = ["tokenizer.json", "preprocessor_config.json"]
115115
for f in download_files:
116116
try:
117-
downloaded_file_path = download_model(
117+
downloaded_file_path = download_file(
118118
huggingface_repo_id=self._cfg.huggingface_repo_id,
119119
huggingface_filename=f,
120120
model_scope_model_id=self._cfg.model_scope_model_id,
121121
model_scope_file_path=f,
122122
cache_dir=self._cfg.cache_dir,
123123
)
124124
except Exception as e:
125-
logger.error(f"Failed to download {f} for model estimate, {e}")
125+
logger.debug(f"File {f} does not exist, {e}")
126126
continue
127127

128128
downloaded_files.append(downloaded_file_path)

vox_box/estimator/funasr.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import yaml
77
from vox_box.config.config import BackendEnum, Config, TaskTypeEnum
8-
from vox_box.downloader.downloaders import download_model
8+
from vox_box.downloader.downloaders import download_file
99
from vox_box.estimator.base import Estimator
1010
from vox_box.utils.model import create_model_dict
1111

@@ -91,15 +91,15 @@ def _check_remote_model(self) -> Tuple[bool, str]:
9191
downloaded_files = []
9292
for f in self._optional_files:
9393
try:
94-
download_file_path = download_model(
94+
download_file_path = download_file(
9595
huggingface_repo_id=self._cfg.huggingface_repo_id,
9696
huggingface_filename=f,
9797
model_scope_model_id=self._cfg.model_scope_model_id,
9898
model_scope_file_path=f,
9999
cache_dir=self._cfg.cache_dir,
100100
)
101101
except Exception as e:
102-
logger.error(f"Failed to download {f} for model estimate, {e}")
102+
logger.debug(f"File {f} does not exist, {e}")
103103
continue
104104
downloaded_files.append(download_file_path)
105105

vox_box/server/model.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from typing import Union
23
from vox_box.backends.stt.base import STTBackend
34
from vox_box.backends.stt.faster_whisper import FasterWhisper
@@ -11,12 +12,17 @@
1112

1213
_instance = None
1314

15+
logger = logging.getLogger(__name__)
16+
1417

1518
class ModelInstance:
1619
def __init__(self, cfg: Config):
1720
self._cfg = cfg
1821
self._backend_framework = None
22+
23+
logger.info("Estimating model")
1924
self._estimate = estimate_model(cfg)
25+
logger.info("Finished estimating model")
2026
if (
2127
self._estimate is None
2228
or not self._estimate.get("supported", False)
@@ -30,7 +36,8 @@ def __init__(self, cfg: Config):
3036
or self._cfg.model_scope_model_id is not None
3137
):
3238
try:
33-
mode_path = downloaders.download_model(
39+
logger.info("Downloading model")
40+
mode_path = downloaders.download_file(
3441
huggingface_repo_id=self._cfg.huggingface_repo_id,
3542
model_scope_model_id=self._cfg.model_scope_model_id,
3643
cache_dir=self._cfg.cache_dir,
@@ -54,6 +61,7 @@ def run(self):
5461

5562
if _instance is None:
5663
try:
64+
logger.info("Loading model")
5765
_instance = self._backend_framework.load()
5866
except Exception as e:
5967
raise Exception(f"Faild to load model, {e}")

0 commit comments

Comments
 (0)