Skip to content

Commit 778677d

Browse files
finish updating how env variables are handled including for managing the faiss scoring function
Signed-off-by: thiswillbeyourgithub <[email protected]>
1 parent 8976eb9 commit 778677d

File tree

1 file changed

+32
-32
lines changed

1 file changed

+32
-32
lines changed

wdoc/utils/embeddings.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,7 @@
3030
# from langchain.storage import LocalFileStore
3131
from .customs.compressed_embeddings_cacher import LocalFileStore
3232
from .customs.litellm_embeddings import LiteLLMEmbeddings
33-
from .env import (
34-
WDOC_DEFAULT_EMBED_DIMENSION,
35-
WDOC_DISABLE_EMBEDDINGS_CACHE,
36-
WDOC_EXPIRE_CACHE_DAYS,
37-
WDOC_MOD_FAISS_SCORE_FN,
38-
)
33+
from .env import env
3934
from .flags import is_verbose
4035
from .logger import red, whi, deb
4136
from .misc import ModelName, cache_dir, get_tkn_length, cache_file_in_memory
@@ -51,27 +46,22 @@
5146
)
5247

5348

54-
if WDOC_MOD_FAISS_SCORE_FN:
55-
56-
def score_function(distance: float) -> float:
57-
"""
58-
Scoring function for faiss to make sure it's positive.
59-
Related issue: https://github.com/langchain-ai/langchain/issues/17333
60-
61-
In langchain the default value is the euclidean relevance score:
62-
return 1.0 - distance / math.sqrt(2)
49+
def faiss_custom_score_function(distance: float) -> float:
50+
"""
51+
Scoring function for faiss to make sure it's positive.
52+
Related issue: https://github.com/langchain-ai/langchain/issues/17333
6353
64-
The output is a similarity score: it must be [0,1] such that
65-
0 is the most dissimilar, 1 is the most similar document.
66-
"""
67-
# To disable it but simply check: uncomment this and add "import math"
68-
# assert distance >= 0, distance
69-
# return 1.0 - distance / math.sqrt(2)
70-
new = 1 - ((1 + distance) / 2)
71-
return new
54+
In langchain the default value is the euclidean relevance score:
55+
return 1.0 - distance / math.sqrt(2)
7256
73-
else:
74-
score_function = None
57+
The output is a similarity score: it must be [0,1] such that
58+
0 is the most dissimilar, 1 is the most similar document.
59+
"""
60+
# To disable it but simply check: uncomment this and add "import math"
61+
# assert distance >= 0, distance
62+
# return 1.0 - distance / math.sqrt(2)
63+
new = 1 - ((1 + distance) / 2)
64+
return new
7565

7666

7767
@optional_typecheck
@@ -99,7 +89,7 @@ def load_embeddings_engine(
9989
try:
10090
embeddings = LiteLLMEmbeddings(
10191
model=modelname.original,
102-
dimensions=WDOC_DEFAULT_EMBED_DIMENSION, # defaults to None
92+
dimensions=env.WDOC_DEFAULT_EMBED_DIMENSION, # defaults to None
10393
api_base=api_base,
10494
private=private,
10595
**embed_kwargs,
@@ -129,7 +119,7 @@ def load_embeddings_engine(
129119
model=modelname.model,
130120
openai_api_key=os.environ["OPENAI_API_KEY"],
131121
api_base=api_base,
132-
dimensions=WDOC_DEFAULT_EMBED_DIMENSION, # defaults to None
122+
dimensions=env.WDOC_DEFAULT_EMBED_DIMENSION, # defaults to None
133123
**embed_kwargs,
134124
)
135125

@@ -198,12 +188,12 @@ def load_embeddings_engine(
198188

199189
lfs = LocalFileStore(
200190
database_path=cache_dir / "CacheEmbeddings" / modelname.sanitized,
201-
expiration_days=WDOC_EXPIRE_CACHE_DAYS,
191+
expiration_days=env.WDOC_EXPIRE_CACHE_DAYS,
202192
verbose=is_verbose,
203193
name="Embeddings_" + modelname.sanitized,
204194
)
205195

206-
if WDOC_DISABLE_EMBEDDINGS_CACHE:
196+
if env.WDOC_DISABLE_EMBEDDINGS_CACHE:
207197
whi("Embeddings cache is disabled - using direct embeddings without caching")
208198
cached_embeddings = embeddings
209199
else:
@@ -254,7 +244,9 @@ def create_embeddings(
254244
db = FAISS.load_local(
255245
str(path),
256246
cached_embeddings,
257-
relevance_score_fn=score_function,
247+
relevance_score_fn=(
248+
faiss_custom_score_function if env.WDOC_MOD_FAISS_SCORE_FN else None
249+
),
258250
allow_dangerous_deserialization=True,
259251
)
260252
n_doc = len(db.index_to_docstore_id.keys())
@@ -322,7 +314,11 @@ def embed_one_batch(
322314
batch,
323315
cached_embeddings,
324316
normalize_L2=True,
325-
relevance_score_fn=score_function,
317+
relevance_score_fn=(
318+
faiss_custom_score_function
319+
if env.WDOC_MOD_FAISS_SCORE_FN
320+
else None
321+
),
326322
)
327323
break
328324
except Exception as e:
@@ -335,7 +331,11 @@ def embed_one_batch(
335331
batch,
336332
cached_embeddings.underlying_embeddings,
337333
normalize_L2=True,
338-
relevance_score_fn=score_function,
334+
relevance_score_fn=(
335+
faiss_custom_score_function
336+
if env.WDOC_MOD_FAISS_SCORE_FN
337+
else None
338+
),
339339
)
340340
break
341341
else:

0 commit comments

Comments
 (0)