Skip to content

Commit 9fa8faa

Browse files
authored
Expand user directory for basepath in extra_models_paths.yaml (#4857)
* Expand user path. * Add test. * Add unit test for expanding base path. * Simplify unit test. * Remove comment. * Remove comment. * Checkpoints. * Refactor.
1 parent 9a7444e commit 9fa8faa

File tree

4 files changed

+97
-24
lines changed

4 files changed

+97
-24
lines changed

main.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def execute_script(script_path):
6363
import gc
6464

6565
import logging
66+
from utils import extra_config
6667

6768
if os.name == "nt":
6869
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
@@ -85,7 +86,6 @@ def execute_script(script_path):
8586
pass
8687

8788
import comfy.utils
88-
import yaml
8989

9090
import execution
9191
import server
@@ -180,27 +180,6 @@ def cleanup_temp():
180180
shutil.rmtree(temp_dir, ignore_errors=True)
181181

182182

183-
def load_extra_path_config(yaml_path):
184-
with open(yaml_path, 'r') as stream:
185-
config = yaml.safe_load(stream)
186-
for c in config:
187-
conf = config[c]
188-
if conf is None:
189-
continue
190-
base_path = None
191-
if "base_path" in conf:
192-
base_path = conf.pop("base_path")
193-
for x in conf:
194-
for y in conf[x].split("\n"):
195-
if len(y) == 0:
196-
continue
197-
full_path = y
198-
if base_path is not None:
199-
full_path = os.path.join(base_path, full_path)
200-
logging.info("Adding extra search path {} {}".format(x, full_path))
201-
folder_paths.add_model_folder_path(x, full_path)
202-
203-
204183
if __name__ == "__main__":
205184
if args.temp_directory:
206185
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
@@ -222,11 +201,11 @@ def load_extra_path_config(yaml_path):
222201

223202
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
224203
if os.path.isfile(extra_model_paths_config_path):
225-
load_extra_path_config(extra_model_paths_config_path)
204+
extra_config.load_extra_path_config(extra_model_paths_config_path)
226205

227206
if args.extra_model_paths_config:
228207
for config_path in itertools.chain(*args.extra_model_paths_config):
229-
load_extra_path_config(config_path)
208+
extra_config.load_extra_path_config(config_path)
230209

231210
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
232211

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import pytest
2+
import yaml
3+
import os
4+
from unittest.mock import Mock, patch, mock_open
5+
6+
from utils.extra_config import load_extra_path_config
7+
import folder_paths
8+
9+
@pytest.fixture
10+
def mock_yaml_content():
11+
return {
12+
'test_config': {
13+
'base_path': '~/App/',
14+
'checkpoints': 'subfolder1',
15+
}
16+
}
17+
18+
@pytest.fixture
19+
def mock_expanded_home():
20+
return '/home/user'
21+
22+
@pytest.fixture
23+
def mock_add_model_folder_path():
24+
return Mock()
25+
26+
@pytest.fixture
27+
def mock_expanduser(mock_expanded_home):
28+
def _expanduser(path):
29+
if path.startswith('~/'):
30+
return os.path.join(mock_expanded_home, path[2:])
31+
return path
32+
return _expanduser
33+
34+
@pytest.fixture
35+
def mock_yaml_safe_load(mock_yaml_content):
36+
return Mock(return_value=mock_yaml_content)
37+
38+
@patch('builtins.open', new_callable=mock_open, read_data="dummy file content")
39+
def test_load_extra_model_paths_expands_userpath(
40+
mock_file,
41+
monkeypatch,
42+
mock_add_model_folder_path,
43+
mock_expanduser,
44+
mock_yaml_safe_load,
45+
mock_expanded_home
46+
):
47+
# Attach mocks used by load_extra_path_config
48+
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
49+
monkeypatch.setattr(os.path, 'expanduser', mock_expanduser)
50+
monkeypatch.setattr(yaml, 'safe_load', mock_yaml_safe_load)
51+
52+
dummy_yaml_file_name = 'dummy_path.yaml'
53+
load_extra_path_config(dummy_yaml_file_name)
54+
55+
expected_calls = [
56+
('checkpoints', os.path.join(mock_expanded_home, 'App', 'subfolder1')),
57+
]
58+
59+
assert mock_add_model_folder_path.call_count == len(expected_calls)
60+
61+
# Check if add_model_folder_path was called with the correct arguments
62+
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
63+
assert actual_call.args == expected_call
64+
65+
# Check if yaml.safe_load was called
66+
mock_yaml_safe_load.assert_called_once()
67+
68+
# Check if open was called with the correct file path
69+
mock_file.assert_called_once_with(dummy_yaml_file_name, 'r')

utils/__init__.py

Whitespace-only changes.

utils/extra_config.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import os
2+
import yaml
3+
import folder_paths
4+
import logging
5+
6+
def load_extra_path_config(yaml_path):
7+
with open(yaml_path, 'r') as stream:
8+
config = yaml.safe_load(stream)
9+
for c in config:
10+
conf = config[c]
11+
if conf is None:
12+
continue
13+
base_path = None
14+
if "base_path" in conf:
15+
base_path = conf.pop("base_path")
16+
base_path = os.path.expanduser(base_path)
17+
for x in conf:
18+
for y in conf[x].split("\n"):
19+
if len(y) == 0:
20+
continue
21+
full_path = y
22+
if base_path is not None:
23+
full_path = os.path.join(base_path, full_path)
24+
logging.info("Adding extra search path {} {}".format(x, full_path))
25+
folder_paths.add_model_folder_path(x, full_path)

0 commit comments

Comments
 (0)