|
| 1 | +from typing import List, Tuple, Optional, Any, Dict |
| 2 | +import json |
| 3 | +import heapq |
| 4 | +from pinecone.core.openapi.data.models import Usage |
| 5 | +from pinecone.core.openapi.data.models import QueryResponse as OpenAPIQueryResponse |
| 6 | + |
| 7 | +from dataclasses import dataclass, asdict |
| 8 | + |
| 9 | + |
@dataclass
class ScoredVectorWithNamespace:
    """A single query match annotated with the namespace it came from.

    Instances are constructed from the heap tuples built by
    ``QueryResultsAggregator``, of the shape
    ``(sort_key, -insertion_counter, match, namespace)`` where ``match`` is a
    dict-like object with ``id``/``score``/``values``/etc. keys.
    """

    namespace: str
    score: float
    id: str
    values: List[float]
    sparse_values: dict
    metadata: dict

    def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, object, str]):
        # Slot 2 holds the raw match payload, slot 3 the source namespace;
        # slots 0-1 are only heap ordering keys and are discarded here.
        json_vector = aggregate_results_heap_tuple[2]
        self.namespace = aggregate_results_heap_tuple[3]
        self.id = json_vector.get("id")  # type: ignore
        self.score = json_vector.get("score")  # type: ignore
        self.values = json_vector.get("values")  # type: ignore
        self.sparse_values = json_vector.get("sparse_values", None)  # type: ignore
        self.metadata = json_vector.get("metadata", None)  # type: ignore

    def __getitem__(self, key):
        """Dict-style read access to attributes; raises KeyError when absent."""
        if not hasattr(self, key):
            raise KeyError(f"'{key}' not found in ScoredVectorWithNamespace")
        return getattr(self, key)

    def get(self, key, default=None):
        """Dict-style ``get``: return the attribute or ``default`` if missing."""
        return getattr(self, key, default)

    def __repr__(self):
        return json.dumps(self._truncate(asdict(self)), indent=4)

    def __json__(self):
        """Return a display-friendly dict with None values dropped and long lists abbreviated."""
        return self._truncate(asdict(self))

    def _truncate(self, obj, max_items=2):
        """Recursively drop None values and abbreviate long lists for display.

        Lists are cut to ``max_items`` entries followed by an ``"... X more"``
        marker, but only when more than one entry would be hidden — replacing
        a single hidden entry with a marker would not shorten the output.
        """
        if obj is None:
            return None
        if isinstance(obj, list):
            kept = [self._truncate(item, max_items) for item in obj if item is not None]
            hidden = len(kept) - max_items
            # Truncate only when at least two entries would be hidden.
            if hidden > 1:
                return kept[:max_items] + [f"... {hidden} more"]
            return kept
        if isinstance(obj, dict):
            # Recurse into dict values, omitting None-valued entries entirely.
            return {k: self._truncate(v, max_items) for k, v in obj.items() if v is not None}
        return obj
| 65 | + |
| 66 | + |
@dataclass
class QueryNamespacesResults:
    """Aggregated top-k matches across namespaces plus the combined usage."""

    usage: Usage
    matches: List[ScoredVectorWithNamespace]

    def __getitem__(self, key):
        """Dict-style read access to attributes; raises KeyError when absent."""
        if not hasattr(self, key):
            raise KeyError(f"'{key}' not found in QueryNamespacesResults")
        return getattr(self, key)

    def get(self, key, default=None):
        """Dict-style ``get``: return the attribute or ``default`` if missing."""
        return getattr(self, key, default)

    def __repr__(self):
        payload = {
            "usage": self.usage.to_dict(),
            "matches": [m.__json__() for m in self.matches],
        }
        return json.dumps(payload, indent=4)
| 89 | + |
| 90 | + |
class QueryResultsAggregregatorNotEnoughResultsError(Exception):
    """Raised when a response has too few matches to infer the similarity metric."""

    def __init__(self):
        message = "Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores."
        super().__init__(message)
| 96 | + |
| 97 | + |
class QueryResultsAggregatorInvalidTopKError(Exception):
    """Raised when aggregation is requested with a ``top_k`` below 2."""

    def __init__(self, top_k: int):
        message = f"Invalid top_k value {top_k}. To aggregate results from multiple queries the top_k must be at least 2."
        super().__init__(message)
| 103 | + |
| 104 | + |
class QueryResultsAggregator:
    """Merge matches from multiple per-namespace query responses into one top-k list.

    Responses are streamed in via ``add_results``; the best ``top_k`` matches
    seen so far are kept in a bounded min-heap. ``get_results`` drains the heap
    once and returns the overall results ordered best-first.
    """

    def __init__(self, top_k: int):
        # top_k >= 2 is required because the similarity metric is inferred
        # from the ordering of at least two scores in a single response.
        if top_k < 2:
            raise QueryResultsAggregatorInvalidTopKError(top_k)
        self.top_k = top_k
        self.usage_read_units = 0
        # Min-heap of (sort_key, -insertion_counter, match, namespace). The
        # sort key is chosen so the heap root is always the worst kept match.
        self.heap: List[Tuple[float, int, object, str]] = []
        self.insertion_counter = 0
        # None until the first response with >= 2 matches lets us infer the metric.
        self.is_dotproduct = None
        # Guards against adding or reading after the heap has been drained.
        self.read = False
        self.final_results: Optional[QueryNamespacesResults] = None

    def _is_dotproduct_index(self, matches):
        # The interpretation of the score depends on the similar metric used.
        # Unlike other index types, in indexes configured for dotproduct,
        # a higher score is better. We have to infer this is the case by inspecting
        # the order of the scores in the results.
        for i in range(1, len(matches)):
            if matches[i].get("score") > matches[i - 1].get("score"):  # Found an increase
                return False
        # No increase found: scores are non-increasing, so higher-is-better.
        # NOTE(review): all-equal scores also land here — metric is ambiguous then.
        return True

    def _dotproduct_heap_item(self, match, ns):
        # Higher score is better: store the score as-is so the min-heap root
        # is the worst kept match. -insertion_counter makes later insertions
        # compare smaller on score ties, so they are evicted first.
        return (match.get("score"), -self.insertion_counter, match, ns)

    def _non_dotproduct_heap_item(self, match, ns):
        # Lower score is better: negate the score so the min-heap root is
        # still the worst kept match.
        return (-match.get("score"), -self.insertion_counter, match, ns)

    def _process_matches(self, matches, ns, heap_item_fn):
        # Feed one response's matches into the bounded heap. The early break
        # relies on each response being sorted best-first by the server.
        for match in matches:
            self.insertion_counter += 1
            if len(self.heap) < self.top_k:
                heapq.heappush(self.heap, heap_item_fn(match, ns))
            else:
                # Assume we have dotproduct scores sorted in descending order
                if self.is_dotproduct and match["score"] < self.heap[0][0]:
                    # No further matches can improve the top-K heap
                    break
                elif not self.is_dotproduct and match["score"] > -self.heap[0][0]:
                    # No further matches can improve the top-K heap
                    break
                # Candidate beats the current worst: swap it in, evict the worst.
                heapq.heappushpop(self.heap, heap_item_fn(match, ns))

    def add_results(self, results: Dict[str, Any]):
        """Fold one query response (dict or OpenAPI model) into the aggregate.

        Accumulates read units, infers the similarity metric on first use, and
        pushes the response's matches into the bounded heap.

        Raises:
            ValueError: if results were already read via ``get_results``.
            QueryResultsAggregregatorNotEnoughResultsError: if the metric is
                still unknown and this response has only one match.
        """
        if self.read:
            # This is mainly just to sanity check in test cases which get quite confusing
            # if you read results twice due to the heap being emptied when constructing
            # the ordered results.
            raise ValueError("Results have already been read. Cannot add more results.")

        matches = results.get("matches", [])
        ns: str = results.get("namespace", "")
        # Usage is exposed differently by the OpenAPI model vs a plain dict.
        if isinstance(results, OpenAPIQueryResponse):
            self.usage_read_units += results.usage.read_units
        else:
            self.usage_read_units += results.get("usage", {}).get("readUnits", 0)

        if len(matches) == 0:
            return

        if self.is_dotproduct is None:
            if len(matches) == 1:
                # This condition should match the second time we add results containing
                # only one match. We need at least two matches in a single response in order
                # to infer the similarity metric
                raise QueryResultsAggregregatorNotEnoughResultsError()
            self.is_dotproduct = self._is_dotproduct_index(matches)

        # The heap key construction differs per metric; the processing loop is shared.
        if self.is_dotproduct:
            self._process_matches(matches, ns, self._dotproduct_heap_item)
        else:
            self._process_matches(matches, ns, self._non_dotproduct_heap_item)

    def get_results(self) -> QueryNamespacesResults:
        """Drain the heap and return the aggregated top-k results, best-first.

        Idempotent: subsequent calls return the cached ``final_results``.
        """
        if self.read:
            if self.final_results is not None:
                return self.final_results
            else:
                # I don't think this branch can ever actually be reached, but the type checker disagrees
                raise ValueError("Results have already been read. Cannot get results again.")
        self.read = True

        # Popping the min-heap yields matches worst-first; [::-1] reverses to
        # best-first for the final ordering.
        self.final_results = QueryNamespacesResults(
            usage=Usage(read_units=self.usage_read_units),
            matches=[
                ScoredVectorWithNamespace(heapq.heappop(self.heap)) for _ in range(len(self.heap))
            ][::-1],
        )
        return self.final_results
0 commit comments