
Commit 30a0f0c

new: stop using both ChatOpenAI and ChatLiteLLM
ChatLiteLLM seems to now work reliably.

Signed-off-by: thiswillbeyourgithub <[email protected]>
1 parent 9104f86 commit 30a0f0c

5 files changed: +34 −59 lines changed


tests/test_wdoc.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@
 os.environ["WDOC_TYPECHECKING"] = "crash"
 
 # Default model names if not specified in environment
-# openai needs to be specifically tested because it uses the langchain backend ChatOpenai instead of ChatLiteLLM like the others
+# we are testing different providers just in case there are unexpected backend issues
 WDOC_TEST_OPENAI_MODEL = os.getenv("WDOC_TEST_OPENAI_MODEL", "gpt-4o")
 WDOC_TEST_OPENAI_EVAL_MODEL = os.getenv("WDOC_TEST_OPENAI_EVAL_MODEL", "gpt-4o-mini")
 WDOC_TEST_OPENAI_EMBED_MODEL = os.getenv(
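
The defaults above are plain `os.getenv` fallbacks, so any provider can be exercised by exporting the corresponding variable before running the tests. A small illustration of the pattern (the override value below is only an example, not a project default):

import os

# tests/test_wdoc.py falls back to these defaults when the variables are unset;
# exporting e.g. WDOC_TEST_OPENAI_MODEL before pytest points the suite at another model.
os.environ.setdefault("WDOC_TEST_OPENAI_MODEL", "gpt-4o-mini")
WDOC_TEST_OPENAI_MODEL = os.getenv("WDOC_TEST_OPENAI_MODEL", "gpt-4o")
print(WDOC_TEST_OPENAI_MODEL)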

wdoc/docs/help.md

Lines changed: 1 addition & 8 deletions
@@ -315,14 +315,7 @@
   using ntfy.sh to get summaries.
 
 * `--disable_llm_cache`: bool, default `False`
-    * WARNING: The cache is temporarily ignored in non openaillms
-    generations because of an error with langchain's ChatLiteLLM.
-    Basically if you don't use `--private` and use llm form openai,
-    wdoc will use ChatOpenAI with regular caching, otherwise
-    we use ChatLiteLLM with LLM caching disabled.
-    More at https://github.com/langchain-ai/langchain/issues/22389
-
-    disable caching for LLM. All caches are stored in the usual
+    * disables caching for LLM. All caches are stored in the usual
     cache folder for your system. This does not disable caching
     for documents.
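
Note on the simplified help entry: with the ChatOpenAI branch gone, `--disable_llm_cache` now maps onto the single `cache=` argument that wdoc hands to ChatLiteLLM (see the llm.py diff below). A minimal sketch of that mechanism, assuming langchain's SQLiteCache and a placeholder model string; wdoc's actual cache location and wiring may differ:

from langchain_community.cache import SQLiteCache
from langchain_community.chat_models import ChatLiteLLM

disable_llm_cache = False  # what the --disable_llm_cache flag controls

# A BaseCache instance enables prompt-level caching; False disables it for this model.
llm_cache = False if disable_llm_cache else SQLiteCache(database_path="llm_cache.db")

llm = ChatLiteLLM(
    model_name="openai/gpt-4o",  # placeholder; any litellm model string would do
    cache=llm_cache,
)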

wdoc/utils/llm.py

Lines changed: 27 additions & 43 deletions
@@ -14,7 +14,6 @@
 from langchain_core.callbacks import BaseCallbackHandler
 from langchain_core.messages.base import BaseMessage
 from langchain_core.outputs.llm_result import LLMResult
-from langchain_openai import ChatOpenAI
 from loguru import logger
 
 from .env import env
@@ -46,7 +45,7 @@ def load_llm(
     private: bool,
     tags: List[str],
     **extra_model_args,
-) -> Union[ChatLiteLLM, ChatOpenAI, FakeListChatModel]:
+) -> Union[ChatLiteLLM, FakeListChatModel]:
     """load language model"""
     if extra_model_args is None:
         extra_model_args = {}
@@ -145,50 +144,35 @@ def load_llm(
 
     assert private == env.WDOC_PRIVATE_MODE
 
-    if (not private) and (modelname.backend == "openai") and (api_base is None):
-        max_tokens = get_model_max_tokens(modelname)
-        logger.debug(f"Detected max token for model {modelname.original}: {max_tokens}")
-        if "max_tokens" not in extra_model_args:
+    max_tokens = get_model_max_tokens(modelname)
+    logger.debug(f"Detected max token for model {modelname.original}: {max_tokens}")
+    if "max_tokens" not in extra_model_args:
+        # intentionaly limiting max tokens because it can cause bugs
+        if modelname.backend != "ollama":
             extra_model_args["max_tokens"] = int(max_tokens * 0.9)
-        logger.debug(f"Using ChatOpenAI backend for model {modelname.original}")
-        llm = ChatOpenAI(
-            model_name=modelname.model,
-            cache=llm_cache,
-            disable_streaming=True,  # Not needed and might break cache
-            verbose=llm_verbosity,
-            callbacks=[PriceCountingCallback(verbose=llm_verbosity)]
-            + langfuse_callback_holder,  # use langchain's callback to langfuse
-            **extra_model_args,
-        )
-    else:
-        max_tokens = get_model_max_tokens(modelname)
-        logger.debug(f"Detected max token for model {modelname.original}: {max_tokens}")
-        if "max_tokens" not in extra_model_args:
-            # intentionaly limiting max tokens because it can cause bugs
-            if modelname.backend != "ollama":
+        else:
+            if max_tokens <= 10_000:
                 extra_model_args["max_tokens"] = int(max_tokens * 0.9)
             else:
-                if max_tokens <= 10_000:
-                    extra_model_args["max_tokens"] = int(max_tokens * 0.9)
-                else:
-                    logger.debug(
-                        f"Detected an ollama model with large max_tokens ({max_tokens}), they usually overestimate their context window capabilities so we reduce it if the user does not specify a max_tokens kwarg"
-                    )
-                    extra_model_args["max_tokens"] = int(max(max_tokens * 0.2, 4096))
-        logger.debug(f"Using ChatLiteLLM backend for model {modelname.original}")
-        llm = ChatLiteLLM(
-            model_name=modelname.original,
-            disable_streaming=True,  # Not needed and might break cache
-            api_base=api_base,
-            cache=llm_cache,
-            verbose=llm_verbosity,
-            tags=tags,
-            callbacks=[PriceCountingCallback(verbose=llm_verbosity)]
-            + langfuse_callback_holder,
-            user=env.WDOC_LITELLM_USER,
-            **extra_model_args,
-        )
-        litellm.drop_params = True
+                logger.debug(
+                    f"Detected an ollama model with large max_tokens ({max_tokens}), they usually overestimate their context window capabilities so we reduce it if the user does not specify a max_tokens kwarg"
+                )
+                extra_model_args["max_tokens"] = int(max(max_tokens * 0.2, 4096))
+    logger.debug(f"Using ChatLiteLLM backend for model {modelname.original}")
+    llm = ChatLiteLLM(
+        model_name=modelname.original,
+        disable_streaming=True,  # Not needed and might break cache
+        api_base=api_base,
+        cache=llm_cache,
+        verbose=llm_verbosity,
+        tags=tags,
+        callbacks=[PriceCountingCallback(verbose=llm_verbosity)]
+        + langfuse_callback_holder,
+        user=env.WDOC_LITELLM_USER,
+        **extra_model_args,
+    )
+    litellm.drop_params = True
+
     if private:
         assert llm.api_base, "private is set but no api_base for llm were found"
         assert (
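
Taken together, these hunks collapse load_llm onto a single ChatLiteLLM path. A simplified, self-contained sketch of the resulting control flow — litellm.get_max_tokens stands in here for wdoc's internal get_model_max_tokens, and the callback/langfuse/user plumbing is left out:

import litellm
from langchain_community.chat_models import ChatLiteLLM

def build_llm_sketch(model: str, backend: str, api_base=None, cache=None, **extra_model_args):
    # Cap max_tokens unless the caller already picked a value, mirroring the diff above.
    max_tokens = litellm.get_max_tokens(model)
    if "max_tokens" not in extra_model_args:
        if backend != "ollama":
            extra_model_args["max_tokens"] = int(max_tokens * 0.9)
        elif max_tokens <= 10_000:
            extra_model_args["max_tokens"] = int(max_tokens * 0.9)
        else:
            # ollama models tend to overstate their context window, so cap much harder
            extra_model_args["max_tokens"] = int(max(max_tokens * 0.2, 4096))

    litellm.drop_params = True  # silently drop kwargs a given provider does not support
    return ChatLiteLLM(
        model_name=model,        # litellm routes "provider/model" strings itself
        api_base=api_base,
        cache=cache,
        disable_streaming=True,  # streaming is not needed and might interfere with caching
        **extra_model_args,
    )

llm = build_llm_sketch("gpt-4o", backend="openai")  # placeholder model string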

wdoc/utils/retrievers.py

Lines changed: 1 addition & 2 deletions
@@ -10,7 +10,6 @@
 # from langchain.storage import LocalFileStore
 from langchain_community.chat_models import ChatLiteLLM
 from langchain_core.retrievers import BaseRetriever
-from langchain_openai import ChatOpenAI
 
 from .misc import cache_dir, get_splitter
 from .prompts import multiquery_parser, prompts
@@ -20,7 +19,7 @@
 
 @optional_typecheck
 def create_multiquery_retriever(
-    llm: Union[ChatLiteLLM, ChatOpenAI],
+    llm: Union[ChatLiteLLM],
     retriever: BaseRetriever,
 ) -> MultiQueryRetriever:
     # advanced mode using pydantic parsers

wdoc/utils/tasks/query.py

Lines changed: 4 additions & 5 deletions
@@ -18,7 +18,6 @@
 from langchain_community.chat_models.fake import FakeListChatModel
 from langchain_core.runnables import chain
 from langchain_core.runnables.base import RunnableLambda
-from langchain_openai import ChatOpenAI
 from numpy.typing import NDArray
 from tqdm import tqdm
 from loguru import logger
@@ -537,7 +536,7 @@ def semantic_batching(
 
 @optional_typecheck
 def pbar_chain(
-    llm: Union[ChatLiteLLM, ChatOpenAI, FakeListChatModel],
+    llm: Union[ChatLiteLLM, FakeListChatModel],
     len_func: str,
     **tqdm_kwargs,
 ) -> RunnableLambda:
@@ -546,7 +545,7 @@ def pbar_chain(
     @chain
     def actual_pbar_chain(
         inputs: Union[dict, List],
-        llm: Union[ChatLiteLLM, ChatOpenAI, FakeListChatModel] = llm,
+        llm: Union[ChatLiteLLM, FakeListChatModel] = llm,
     ) -> Union[dict, List]:
 
         llm.callbacks[0].pbar.append(
@@ -565,14 +564,14 @@ def actual_pbar_chain(
 
 @optional_typecheck
 def pbar_closer(
-    llm: Union[ChatLiteLLM, ChatOpenAI, FakeListChatModel],
+    llm: Union[ChatLiteLLM, FakeListChatModel],
 ) -> RunnableLambda:
     "close a pbar created by pbar_chain"
 
     @chain
     def actual_pbar_closer(
         inputs: Union[dict, List],
-        llm: Union[ChatLiteLLM, ChatOpenAI, FakeListChatModel] = llm,
+        llm: Union[ChatLiteLLM, FakeListChatModel] = llm,
     ) -> Union[dict, List]:
         pbar = llm.callbacks[0].pbar[-1]
         pbar.update(pbar.total - pbar.n)
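
The narrowed unions above keep FakeListChatModel alongside ChatLiteLLM, which is what lets these progress-bar helpers run in tests without any real provider. A tiny stand-alone illustration of that fake model (independent of wdoc's pbar wiring):

from langchain_community.chat_models.fake import FakeListChatModel

# Replays canned responses in order, so code expecting a chat model can run offline.
fake_llm = FakeListChatModel(responses=["first canned answer", "second canned answer"])

print(fake_llm.invoke("anything").content)       # -> "first canned answer"
print(fake_llm.invoke("anything else").content)  # -> "second canned answer"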
