@@ -45,13 +45,27 @@ def __init__(
45
45
self ._model_dict = {}
46
46
self ._is_cosyvoice_v2 = False
47
47
48
+ self ._parse_and_set_cuda_visible_devices ()
49
+
48
50
cosyvoice_yaml_path = os .path .join (self ._cfg .model , "cosyvoice.yaml" )
49
51
if os .path .exists (cosyvoice_yaml_path ):
50
52
with open (cosyvoice_yaml_path , "r" , encoding = "utf-8" ) as f :
51
53
content = f .read ()
52
54
if re .search (r"Qwen2" , content , re .IGNORECASE ):
53
55
self ._is_cosyvoice_v2 = True
54
56
57
+ def _parse_and_set_cuda_visible_devices (self ):
58
+ """
59
+ Parse CUDA device in format cuda:1 and set CUDA_VISIBLE_DEVICES accordingly.
60
+ """
61
+ device = self ._cfg .device
62
+ if device .startswith ("cuda:" ):
63
+ device_index = device .split (":" )[1 ]
64
+ if device_index .isdigit ():
65
+ os .environ ["CUDA_VISIBLE_DEVICES" ] = device_index
66
+ else :
67
+ raise ValueError (f"Invalid CUDA device index: { device_index } " )
68
+
55
69
def load (self ):
56
70
for path in paths_to_insert :
57
71
sys .path .insert (0 , path )
@@ -66,9 +80,7 @@ def load(self):
66
80
67
81
# CosyVoice2 does not have builtin spk2info.pt
68
82
if not self ._model .frontend .spk2info :
69
- self ._model .frontend .spk2info = torch .load (
70
- builtin_spk2info_path , map_location = self ._cfg .device
71
- )
83
+ self ._model .frontend .spk2info = torch .load (builtin_spk2info_path )
72
84
else :
73
85
from cosyvoice .cli .cosyvoice import CosyVoice as CosyVoiceModel
74
86
0 commit comments