30
30
# from langchain.storage import LocalFileStore
31
31
from .customs .compressed_embeddings_cacher import LocalFileStore
32
32
from .customs .litellm_embeddings import LiteLLMEmbeddings
33
- from .env import (
34
- WDOC_DEFAULT_EMBED_DIMENSION ,
35
- WDOC_DISABLE_EMBEDDINGS_CACHE ,
36
- WDOC_EXPIRE_CACHE_DAYS ,
37
- WDOC_MOD_FAISS_SCORE_FN ,
38
- )
33
+ from .env import env
39
34
from .flags import is_verbose
40
35
from .logger import red , whi , deb
41
36
from .misc import ModelName , cache_dir , get_tkn_length , cache_file_in_memory
51
46
)
52
47
53
48
54
def faiss_custom_score_function(distance: float) -> float:
    """Turn a FAISS distance into a similarity score.

    Related issue: https://github.com/langchain-ai/langchain/issues/17333

    Langchain's default is the euclidean relevance score,
    ``1.0 - distance / math.sqrt(2)``; this module uses
    ``1 - (1 + distance) / 2`` instead.

    The intended output is a similarity score in [0, 1] such that
    0 is the most dissimilar, 1 is the most similar document.
    NOTE(review): for distance > 1 the formula drops below 0 — confirm
    that upstream FAISS distances stay in the expected range.
    """
    # Algebraically equivalent to (1 - distance) / 2; the original
    # expression shape is kept so float results match exactly.
    # To disable it but simply check: uncomment this and add "import math"
    # assert distance >= 0, distance
    # return 1.0 - distance / math.sqrt(2)
    return 1 - ((1 + distance) / 2)
75
65
76
66
77
67
@optional_typecheck
@@ -99,7 +89,7 @@ def load_embeddings_engine(
99
89
try :
100
90
embeddings = LiteLLMEmbeddings (
101
91
model = modelname .original ,
102
- dimensions = WDOC_DEFAULT_EMBED_DIMENSION , # defaults to None
92
+ dimensions = env . WDOC_DEFAULT_EMBED_DIMENSION , # defaults to None
103
93
api_base = api_base ,
104
94
private = private ,
105
95
** embed_kwargs ,
@@ -129,7 +119,7 @@ def load_embeddings_engine(
129
119
model = modelname .model ,
130
120
openai_api_key = os .environ ["OPENAI_API_KEY" ],
131
121
api_base = api_base ,
132
- dimensions = WDOC_DEFAULT_EMBED_DIMENSION , # defaults to None
122
+ dimensions = env . WDOC_DEFAULT_EMBED_DIMENSION , # defaults to None
133
123
** embed_kwargs ,
134
124
)
135
125
@@ -198,12 +188,12 @@ def load_embeddings_engine(
198
188
199
189
lfs = LocalFileStore (
200
190
database_path = cache_dir / "CacheEmbeddings" / modelname .sanitized ,
201
- expiration_days = WDOC_EXPIRE_CACHE_DAYS ,
191
+ expiration_days = env . WDOC_EXPIRE_CACHE_DAYS ,
202
192
verbose = is_verbose ,
203
193
name = "Embeddings_" + modelname .sanitized ,
204
194
)
205
195
206
- if WDOC_DISABLE_EMBEDDINGS_CACHE :
196
+ if env . WDOC_DISABLE_EMBEDDINGS_CACHE :
207
197
whi ("Embeddings cache is disabled - using direct embeddings without caching" )
208
198
cached_embeddings = embeddings
209
199
else :
@@ -254,7 +244,9 @@ def create_embeddings(
254
244
db = FAISS .load_local (
255
245
str (path ),
256
246
cached_embeddings ,
257
- relevance_score_fn = score_function ,
247
+ relevance_score_fn = (
248
+ faiss_custom_score_function if env .WDOC_MOD_FAISS_SCORE_FN else None
249
+ ),
258
250
allow_dangerous_deserialization = True ,
259
251
)
260
252
n_doc = len (db .index_to_docstore_id .keys ())
@@ -322,7 +314,11 @@ def embed_one_batch(
322
314
batch ,
323
315
cached_embeddings ,
324
316
normalize_L2 = True ,
325
- relevance_score_fn = score_function ,
317
+ relevance_score_fn = (
318
+ faiss_custom_score_function
319
+ if env .WDOC_MOD_FAISS_SCORE_FN
320
+ else None
321
+ ),
326
322
)
327
323
break
328
324
except Exception as e :
@@ -335,7 +331,11 @@ def embed_one_batch(
335
331
batch ,
336
332
cached_embeddings .underlying_embeddings ,
337
333
normalize_L2 = True ,
338
- relevance_score_fn = score_function ,
334
+ relevance_score_fn = (
335
+ faiss_custom_score_function
336
+ if env .WDOC_MOD_FAISS_SCORE_FN
337
+ else None
338
+ ),
339
339
)
340
340
break
341
341
else :
0 commit comments