1
- """
2
- Tests for the AbstractGraph.
3
- """
4
-
5
- from unittest .mock import patch
6
-
7
1
import pytest
2
+
8
3
from langchain_aws import ChatBedrock
9
4
from langchain_ollama import ChatOllama
10
5
from langchain_openai import AzureChatOpenAI , ChatOpenAI
11
-
12
6
from scrapegraphai .graphs import AbstractGraph , BaseGraph
13
7
from scrapegraphai .models import DeepSeek , OneApi
14
8
from scrapegraphai .nodes import FetchNode , ParseNode
9
+ from unittest .mock import Mock , patch
15
10
11
+ """
12
+ Tests for the AbstractGraph.
13
+ """
16
14
17
15
class TestGraph (AbstractGraph ):
18
16
def __init__ (self , prompt : str , config : dict ):
@@ -50,7 +48,6 @@ def run(self) -> str:
50
48
51
49
return self .final_state .get ("answer" , "No answer found." )
52
50
53
-
54
51
class TestAbstractGraph :
55
52
@pytest .mark .parametrize (
56
53
"llm_config, expected_model" ,
@@ -161,3 +158,45 @@ async def test_run_safe_async(self):
161
158
result = await graph .run_safe_async ()
162
159
assert result == "Async result"
163
160
mock_run .assert_called_once ()
161
+
162
+ def test_create_llm_with_custom_model_instance (self ):
163
+ """
164
+ Test that the _create_llm method correctly uses a custom model instance
165
+ when provided in the configuration.
166
+ """
167
+ mock_model = Mock ()
168
+ mock_model .model_name = "custom-model"
169
+
170
+ config = {
171
+ "llm" : {
172
+ "model_instance" : mock_model ,
173
+ "model_tokens" : 1000 ,
174
+ "model" : "custom/model"
175
+ }
176
+ }
177
+
178
+ graph = TestGraph ("Test prompt" , config )
179
+
180
+ assert graph .llm_model == mock_model
181
+ assert graph .model_token == 1000
182
+
183
+ def test_set_common_params (self ):
184
+ """
185
+ Test that the set_common_params method correctly updates the configuration
186
+ of all nodes in the graph.
187
+ """
188
+ # Create a mock graph with mock nodes
189
+ mock_graph = Mock ()
190
+ mock_node1 = Mock ()
191
+ mock_node2 = Mock ()
192
+ mock_graph .nodes = [mock_node1 , mock_node2 ]
193
+
194
+ # Create a TestGraph instance with the mock graph
195
+ with patch ('scrapegraphai.graphs.abstract_graph.AbstractGraph._create_graph' , return_value = mock_graph ):
196
+ graph = TestGraph ("Test prompt" , {"llm" : {"model" : "openai/gpt-3.5-turbo" , "openai_api_key" : "sk-test" }})
197
+
198
+ # Call set_common_params with test parameters
199
+ test_params = {"param1" : "value1" , "param2" : "value2" }
200
+ graph .set_common_params (test_params )
201
+
202
+ # Assert that update_config was called on each node with the correct parameters
0 commit comments