Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
)
from submodules.model import session

import traceback

app = FastAPI()


Expand All @@ -17,6 +19,9 @@ async def handle_db_session(request: Request, call_next):
session_token = general.get_ctx_token()
try:
response = await call_next(request)
except Exception:
print(traceback.format_exc(), flush=True)
response = None
finally:
general.remove_and_refresh_session(session_token)

Expand Down Expand Up @@ -66,6 +71,7 @@ class MostSimilarByEmbeddingRequest(BaseModel):
att_filter: Optional[List[Dict[str, Any]]] = None
threshold: Optional[Union[float, int]] = None
question: Optional[str] = None
user_id: Optional[str] = None


@app.post("/most_similar_by_embedding")
Expand Down Expand Up @@ -99,6 +105,7 @@ def most_similar_by_embedding(
request.att_filter,
request.threshold,
include_scores,
request.user_id,
)

if request.question:
Expand Down
108 changes: 80 additions & 28 deletions neural_search/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,18 @@
embedding,
record_label_association,
record,
project,
user,
)
from submodules.model.enums import EmbeddingPlatform, LabelSource
from submodules.model.cognition_objects import group_member
from submodules.model.integration_objects.helper import (
REFINERY_ATTRIBUTE_ACCESS_GROUPS,
REFINERY_ATTRIBUTE_ACCESS_USERS,
)
from submodules.model.enums import EmbeddingPlatform, LabelSource, UserRoles

from .similarity_threshold import SimilarityThreshold, NO_THRESHOLD_INDICATOR
import traceback

port = int(os.environ["QDRANT_PORT"])
qdrant_client = QdrantClient(host="qdrant", port=port, timeout=60)
Expand Down Expand Up @@ -48,9 +56,25 @@ def most_similar_by_embedding(
att_filter: Optional[List[Dict[str, Any]]] = None,
threshold: Optional[float] = None,
include_scores: bool = False,
user_id: Optional[str] = None,
) -> List[str]:
if not is_filter_valid_for_embedding(project_id, embedding_id, att_filter):
return []
if project.check_access_management_active(project_id):
if not user_id:
return []
requesting_user = user.get(user_id)
if not requesting_user:
return []
if requesting_user.role != UserRoles.ENGINEER.value:
check_access = True
group_members = group_member.get_by_user_id(user_id)
group_ids = [str(group_member.group_id) for group_member in group_members]
else:
check_access = False
else:
check_access = False

tmp_limit = limit
has_sub_key = embedding.has_sub_key(project_id, embedding_id)
if has_sub_key:
Expand All @@ -66,14 +90,20 @@ def most_similar_by_embedding(
elif similarity_threshold == NO_THRESHOLD_INDICATOR:
similarity_threshold = None
try:
_filter = __build_filter(att_filter)
if check_access:
_filter = __add_access_management_filter(_filter, group_ids, user_id)

search_result = qdrant_client.search(
collection_name=embedding_id,
query_vector=query_vector,
query_filter=__build_filter(att_filter),
query_filter=_filter,
limit=tmp_limit,
score_threshold=similarity_threshold,
)
except Exception:
except Exception as e:
print(f"Error during search in Qdrant: {e}", flush=True)
print(traceback.format_exc(), flush=True)
return []

if include_scores:
Expand Down Expand Up @@ -118,39 +148,61 @@ def __is_label_filter(key: str) -> bool:
return parts[0] == LABELS_QDRANT


def __build_filter(att_filter: List[Dict[str, Any]]) -> models.Filter:
if att_filter is None or len(att_filter) == 0:
def __build_filter(att_filter: List[Dict[str, Any]]) -> Optional[models.Filter]:
if not att_filter:
return None
must = [__build_filter_item(filter_item) for filter_item in att_filter]
must = [__build_filter_item(item) for item in att_filter]
return models.Filter(must=must)


def __add_access_management_filter(
    base_filter: Optional[models.Filter], group_ids: List[str], user_id: str
) -> models.Filter:
    """Combine an optional attribute filter with an access restriction.

    The access restriction matches a point when at least one of these holds:
    the field named by ``REFINERY_ATTRIBUTE_ACCESS_GROUPS`` matches any of
    ``group_ids``, or the field named by ``REFINERY_ATTRIBUTE_ACCESS_USERS``
    matches ``user_id``.

    Args:
        base_filter: Filter built from the caller's attribute filter, or
            ``None`` when no attribute filter was supplied.
        group_ids: Group ids (as strings) the requesting user belongs to.
        user_id: Id of the requesting user.

    Returns:
        A filter that requires the access restriction and, when given, the
        complete ``base_filter``.
    """
    access_filter = models.Filter(
        should=[
            models.FieldCondition(
                key=REFINERY_ATTRIBUTE_ACCESS_GROUPS,
                match=models.MatchAny(any=group_ids),
            ),
            models.FieldCondition(
                key=REFINERY_ATTRIBUTE_ACCESS_USERS,
                match=models.MatchValue(value=user_id),
            ),
        ]
    )

    if base_filter is None:
        return access_filter

    # Nest both filters as `must` conditions instead of copying only
    # `base_filter.must`: copying silently discarded any `should` /
    # `must_not` clauses of the base filter. Nesting keeps every clause and
    # makes the AND between "attribute filter" and "access filter" explicit.
    return models.Filter(must=[base_filter, access_filter])


def __build_filter_item(filter_item: Dict[str, Any]) -> models.FieldCondition:
if isinstance(filter_item["value"], list):
if filter_item.get("type") == "between":
return models.FieldCondition(
key=filter_item["key"],
range=models.Range(
gte=filter_item["value"][0],
lte=filter_item["value"][1],
),
)
else:
should = [
models.FieldCondition(
key=filter_item["key"], match=models.MatchValue(value=value)
)
for value in filter_item["value"]
]
return models.Filter(should=should)
else:
key = filter_item["key"]
value = filter_item["value"]
typ = filter_item.get("type")

# BETWEEN
if isinstance(value, list) and typ == "between":
return models.FieldCondition(
key=key,
range=models.Range(gte=value[0], lte=value[1]),
)

# IN (...)
if isinstance(value, list):
return models.FieldCondition(
key=filter_item["key"],
match=models.MatchValue(
value=filter_item["value"],
),
key=key,
match=models.MatchAny(any=value),
)

# = single value
return models.FieldCondition(
key=key,
match=models.MatchValue(value=value),
)


def recreate_collection(project_id: str, embedding_id: str) -> int:
embedding_item = embedding.get(project_id, embedding_id)
Expand Down