Skip to content

Commit 0c476a4

Browse files
committed
feat: add grok integration
1 parent ec957a5 commit 0c476a4

File tree

4 files changed

+35
-3
lines changed

4 files changed

+35
-3
lines changed

scrapegraphai/graphs/abstract_graph.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from pydantic import BaseModel
1414

1515
from ..helpers import models_tokens
16-
from ..models import CLoD, DeepSeek, OneApi
16+
from ..models import CLoD, DeepSeek, OneApi, XAI
1717
from ..utils.logging import set_verbosity_info, set_verbosity_warning
1818

1919

@@ -163,6 +163,7 @@ def _create_llm(self, llm_config: dict) -> object:
163163
"fireworks",
164164
"clod",
165165
"togetherai",
166+
"xai",
166167
}
167168

168169
if "/" in llm_params["model"]:
@@ -217,6 +218,7 @@ def _create_llm(self, llm_config: dict) -> object:
217218
"deepseek",
218219
"togetherai",
219220
"clod",
221+
"xai",
220222
}:
221223
if llm_params["model_provider"] == "bedrock":
222224
llm_params["model_kwargs"] = {
@@ -242,6 +244,9 @@ def _create_llm(self, llm_config: dict) -> object:
242244
elif model_provider == "oneapi":
243245
return OneApi(**llm_params)
244246

247+
elif model_provider == "xai":
248+
return XAI(**llm_params)
249+
245250
elif model_provider == "togetherai":
246251
try:
247252
from langchain_together import ChatTogether

scrapegraphai/helpers/models_tokens.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@
150150
"llama3-70b-8192": 8192,
151151
"mixtral-8x7b-32768": 32768,
152152
"gemma-7b-it": 8192,
153-
"claude-3-haiku-20240307'": 8192,
153+
"claude-3-haiku-20240307": 8192,
154154
},
155155
"toghetherai": {
156156
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": 128000,
@@ -303,4 +303,7 @@
303303
"grok-2-latest": 128000,
304304
},
305305
"togetherai": {"Meta-Llama-3.1-70B-Instruct-Turbo": 128000},
306+
"xai": {
307+
"grok-1": 8192
308+
},
306309
}

scrapegraphai/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@
77
from .oneapi import OneApi
88
from .openai_itt import OpenAIImageToText
99
from .openai_tts import OpenAITextToSpeech
10+
from .xai import XAI
1011

11-
__all__ = ["DeepSeek", "OneApi", "OpenAIImageToText", "OpenAITextToSpeech", "CLoD"]
12+
__all__ = ["DeepSeek", "OneApi", "OpenAIImageToText", "OpenAITextToSpeech", "CLoD", "XAI"]

scrapegraphai/models/xai.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""
2+
xAI Grok Module
3+
"""
4+
from langchain_groq import ChatGroq as LangchainChatGroq
5+
6+
class XAI(LangchainChatGroq):
7+
"""
8+
Wrapper for the ChatGroq class from langchain_groq, for use with xAI models.
9+
Handles API key mapping from generic 'api_key' to 'groq_api_key' and
10+
maps 'model' to 'model_name'.
11+
12+
Args:
13+
llm_config (dict): Configuration parameters for the language model.
14+
"""
15+
16+
def __init__(self, **llm_config):
17+
if "api_key" in llm_config and "groq_api_key" not in llm_config:
18+
llm_config["groq_api_key"] = llm_config.pop("api_key")
19+
20+
if "model" in llm_config and "model_name" not in llm_config:
21+
llm_config["model_name"] = llm_config.pop("model")
22+
23+
super().__init__(**llm_config)

0 commit comments

Comments
 (0)