14  | 14  | from langchain_core.callbacks import BaseCallbackHandler
15  | 15  | from langchain_core.messages.base import BaseMessage
16  | 16  | from langchain_core.outputs.llm_result import LLMResult
17  |     | -from langchain_openai import ChatOpenAI
18  | 17  | from loguru import logger
19  | 18  |
20  | 19  | from .env import env

@@ -46,7 +45,7 @@ def load_llm(
46  | 45  |     private: bool,
47  | 46  |     tags: List[str],
48  | 47  |     **extra_model_args,
49  |     | -) -> Union[ChatLiteLLM, ChatOpenAI, FakeListChatModel]:
    | 48  | +) -> Union[ChatLiteLLM, FakeListChatModel]:
50  | 49  |     """load language model"""
51  | 50  |     if extra_model_args is None:
52  | 51  |         extra_model_args = {}

@@ -145,50 +144,35 @@ def load_llm(
145 | 144 |
146 | 145 |     assert private == env.WDOC_PRIVATE_MODE
147 | 146 |
148 |     | -    if (not private) and (modelname.backend == "openai") and (api_base is None):
149 |     | -        max_tokens = get_model_max_tokens(modelname)
150 |     | -        logger.debug(f"Detected max token for model {modelname.original}: {max_tokens}")
151 |     | -        if "max_tokens" not in extra_model_args:
    | 147 | +    max_tokens = get_model_max_tokens(modelname)
    | 148 | +    logger.debug(f"Detected max token for model {modelname.original}: {max_tokens}")
    | 149 | +    if "max_tokens" not in extra_model_args:
    | 150 | +        # intentionaly limiting max tokens because it can cause bugs
    | 151 | +        if modelname.backend != "ollama":
152 | 152 |             extra_model_args["max_tokens"] = int(max_tokens * 0.9)
153 |     | -        logger.debug(f"Using ChatOpenAI backend for model {modelname.original}")
154 |     | -        llm = ChatOpenAI(
155 |     | -            model_name=modelname.model,
156 |     | -            cache=llm_cache,
157 |     | -            disable_streaming=True,  # Not needed and might break cache
158 |     | -            verbose=llm_verbosity,
159 |     | -            callbacks=[PriceCountingCallback(verbose=llm_verbosity)]
160 |     | -            + langfuse_callback_holder,  # use langchain's callback to langfuse
161 |     | -            **extra_model_args,
162 |     | -        )
163 |     | -    else:
164 |     | -        max_tokens = get_model_max_tokens(modelname)
165 |     | -        logger.debug(f"Detected max token for model {modelname.original}: {max_tokens}")
166 |     | -        if "max_tokens" not in extra_model_args:
167 |     | -            # intentionaly limiting max tokens because it can cause bugs
168 |     | -            if modelname.backend != "ollama":
    | 153 | +        else:
    | 154 | +            if max_tokens <= 10_000:
169 | 155 |                 extra_model_args["max_tokens"] = int(max_tokens * 0.9)
170 | 156 |             else:
171 |     | -                if max_tokens <= 10_000:
172 |     | -                    extra_model_args["max_tokens"] = int(max_tokens * 0.9)
173 |     | -                else:
174 |     | -                    logger.debug(
175 |     | -                        f"Detected an ollama model with large max_tokens ({max_tokens}), they usually overestimate their context window capabilities so we reduce it if the user does not specify a max_tokens kwarg"
176 |     | -                    )
177 |     | -                    extra_model_args["max_tokens"] = int(max(max_tokens * 0.2, 4096))
178 |     | -        logger.debug(f"Using ChatLiteLLM backend for model {modelname.original}")
179 |     | -        llm = ChatLiteLLM(
180 |     | -            model_name=modelname.original,
181 |     | -            disable_streaming=True,  # Not needed and might break cache
182 |     | -            api_base=api_base,
183 |     | -            cache=llm_cache,
184 |     | -            verbose=llm_verbosity,
185 |     | -            tags=tags,
186 |     | -            callbacks=[PriceCountingCallback(verbose=llm_verbosity)]
187 |     | -            + langfuse_callback_holder,
188 |     | -            user=env.WDOC_LITELLM_USER,
189 |     | -            **extra_model_args,
190 |     | -        )
191 |     | -        litellm.drop_params = True
    | 157 | +                logger.debug(
    | 158 | +                    f"Detected an ollama model with large max_tokens ({max_tokens}), they usually overestimate their context window capabilities so we reduce it if the user does not specify a max_tokens kwarg"
    | 159 | +                )
    | 160 | +                extra_model_args["max_tokens"] = int(max(max_tokens * 0.2, 4096))
    | 161 | +    logger.debug(f"Using ChatLiteLLM backend for model {modelname.original}")
    | 162 | +    llm = ChatLiteLLM(
    | 163 | +        model_name=modelname.original,
    | 164 | +        disable_streaming=True,  # Not needed and might break cache
    | 165 | +        api_base=api_base,
    | 166 | +        cache=llm_cache,
    | 167 | +        verbose=llm_verbosity,
    | 168 | +        tags=tags,
    | 169 | +        callbacks=[PriceCountingCallback(verbose=llm_verbosity)]
    | 170 | +        + langfuse_callback_holder,
    | 171 | +        user=env.WDOC_LITELLM_USER,
    | 172 | +        **extra_model_args,
    | 173 | +    )
    | 174 | +    litellm.drop_params = True
    | 175 | +
192 | 176 |     if private:
193 | 177 |         assert llm.api_base, "private is set but no api_base for llm were found"
194 | 178 |         assert (
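With this commit the dedicated ChatOpenAI branch is gone: judging by the new return annotation, every real model (anything other than FakeListChatModel) is now built through ChatLiteLLM, and the max_tokens guard that previously lived only in the non-OpenAI branch now applies to every backend. Below is a minimal sketch of that capping rule, pulled out of load_llm for illustration; the function name and signature are hypothetical, not part of wdoc.

```python
# Hedged sketch of the capping logic in the patch above (names are illustrative).
def cap_max_tokens(backend: str, max_tokens: int, extra_model_args: dict) -> dict:
    """Limit max_tokens unless the caller already set it explicitly."""
    if "max_tokens" in extra_model_args:
        return extra_model_args  # an explicit user value always wins
    if backend != "ollama":
        # keep ~10% headroom below the detected context window
        extra_model_args["max_tokens"] = int(max_tokens * 0.9)
    elif max_tokens <= 10_000:
        extra_model_args["max_tokens"] = int(max_tokens * 0.9)
    else:
        # ollama models often overstate their context window, so cap much harder
        extra_model_args["max_tokens"] = int(max(max_tokens * 0.2, 4096))
    return extra_model_args


if __name__ == "__main__":
    print(cap_max_tokens("ollama", 128_000, {}))  # {'max_tokens': 25600}
    print(cap_max_tokens("openai", 128_000, {}))  # {'max_tokens': 115200}
```

Keeping `litellm.drop_params = True` after construction is meant to let LiteLLM drop kwargs that a given provider does not support (such as max_tokens on an unusual backend) rather than raise an error.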