Skip to content

Commit bd69a40

Browse files
committed
fix: set cuda visible devices for cosyvoice
1 parent 10c3246 commit bd69a40

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

vox_box/backends/tts/cosyvoice.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,27 @@ def __init__(
4545
self._model_dict = {}
4646
self._is_cosyvoice_v2 = False
4747

48+
self._parse_and_set_cuda_visible_devices()
49+
4850
cosyvoice_yaml_path = os.path.join(self._cfg.model, "cosyvoice.yaml")
4951
if os.path.exists(cosyvoice_yaml_path):
5052
with open(cosyvoice_yaml_path, "r", encoding="utf-8") as f:
5153
content = f.read()
5254
if re.search(r"Qwen2", content, re.IGNORECASE):
5355
self._is_cosyvoice_v2 = True
5456

57+
def _parse_and_set_cuda_visible_devices(self):
58+
"""
59+
Parse CUDA device in format cuda:1 and set CUDA_VISIBLE_DEVICES accordingly.
60+
"""
61+
device = self._cfg.device
62+
if device.startswith("cuda:"):
63+
device_index = device.split(":")[1]
64+
if device_index.isdigit():
65+
os.environ["CUDA_VISIBLE_DEVICES"] = device_index
66+
else:
67+
raise ValueError(f"Invalid CUDA device index: {device_index}")
68+
5569
def load(self):
5670
for path in paths_to_insert:
5771
sys.path.insert(0, path)
@@ -66,9 +80,7 @@ def load(self):
6680

6781
# CosyVoice2 does not have builtin spk2info.pt
6882
if not self._model.frontend.spk2info:
69-
self._model.frontend.spk2info = torch.load(
70-
builtin_spk2info_path, map_location=self._cfg.device
71-
)
83+
self._model.frontend.spk2info = torch.load(builtin_spk2info_path)
7284
else:
7385
from cosyvoice.cli.cosyvoice import CosyVoice as CosyVoiceModel
7486

0 commit comments

Comments
 (0)