@@ -81,39 +81,85 @@ def __repr__(self):
81
81
)
82
82
83
83
84
+ class QueryResultsAggregationEmptyResultsError (Exception ):
85
+ def __init__ (self , namespace : str ):
86
+ super ().__init__ (
87
+ f"Cannot infer metric type from empty query results. Query result for namespace '{ namespace } ' is empty. Have you spelled the namespace name correctly?"
88
+ )
89
+
90
+
91
+ class QueryResultsAggregregatorNotEnoughResultsError (Exception ):
92
+ def __init__ (self , top_k : int , num_results : int ):
93
+ super ().__init__ (
94
+ f"Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores. Expected at least { top_k } results but got { num_results } ."
95
+ )
96
+
97
+
98
+ class QueryResultsAggregatorInvalidTopKError (Exception ):
99
+ def __init__ (self , top_k : int ):
100
+ super ().__init__ (f"Invalid top_k value { top_k } . top_k must be a positive integer." )
101
+
102
+
84
103
class QueryResultsAggregator :
85
104
def __init__ (self , top_k : int ):
105
+ if top_k < 1 :
106
+ raise QueryResultsAggregatorInvalidTopKError (top_k )
86
107
self .top_k = top_k
87
108
self .usage_read_units = 0
88
109
self .heap = []
89
110
self .insertion_counter = 0
111
+ self .is_dotproduct = None
90
112
self .read = False
91
113
114
+ def __is_dotproduct_index (self , matches ):
115
+ # The interpretation of the score depends on the similar metric used.
116
+ # Unlike other index types, in indexes configured for dotproduct,
117
+ # a higher score is better. We have to infer this is the case by inspecting
118
+ # the order of the scores in the results.
119
+ for i in range (1 , len (matches )):
120
+ if matches [i ].get ("score" ) > matches [i - 1 ].get ("score" ): # Found an increase
121
+ return False
122
+ return True
123
+
92
124
def add_results (self , results : QueryResponse ):
93
125
if self .read :
126
+ # This is mainly just to sanity check in test cases which get quite confusing
127
+ # if you read results twice due to the heap being emptied when constructing
128
+ # the ordered results.
94
129
raise ValueError ("Results have already been read. Cannot add more results." )
95
130
96
- self . usage_read_units + = results .get ("usage " , {}). get ( "readUnits" , 0 )
131
+ matches = results .get ("matches " , [] )
97
132
ns = results .get ("namespace" )
98
- for match in results .get ("matches" , []):
99
- self .insertion_counter += 1
100
- score = match .get ("score" )
101
- if len (self .heap ) < self .top_k :
102
- heapq .heappush (self .heap , (- score , - self .insertion_counter , match , ns ))
103
- else :
104
- heapq .heappushpop (self .heap , (- score , - self .insertion_counter , match , ns ))
133
+ self .usage_read_units += results .get ("usage" , {}).get ("readUnits" , 0 )
134
+
135
+ if self .is_dotproduct is None :
136
+ if len (matches ) == 0 :
137
+ raise QueryResultsAggregationEmptyResultsError (ns )
138
+ if len (matches ) == 1 :
139
+ raise QueryResultsAggregregatorNotEnoughResultsError (self .top_k , len (matches ))
140
+ self .is_dotproduct = self .__is_dotproduct_index (matches )
141
+
142
+ print ("is_dotproduct:" , self .is_dotproduct )
143
+ if self .is_dotproduct :
144
+ raise NotImplementedError ("Dotproduct indexes are not yet supported." )
145
+ else :
146
+ for match in matches :
147
+ self .insertion_counter += 1
148
+ score = match .get ("score" )
149
+ if len (self .heap ) < self .top_k :
150
+ heapq .heappush (self .heap , (- score , - self .insertion_counter , match , ns ))
151
+ else :
152
+ heapq .heappushpop (self .heap , (- score , - self .insertion_counter , match , ns ))
105
153
106
154
def get_results (self ) -> CompositeQueryResults :
107
155
if self .read :
108
- # This is mainly just to sanity check in test cases which get quite confusing
109
- # if you read results twice due to the heap being emptied each time you read
110
- # results into an ordered list.
111
- raise ValueError ("Results have already been read. Cannot read again." )
156
+ return self .final_results
112
157
self .read = True
113
158
114
- return CompositeQueryResults (
159
+ self . final_results = CompositeQueryResults (
115
160
usage = Usage (read_units = self .usage_read_units ),
116
161
matches = [
117
162
ScoredVectorWithNamespace (heapq .heappop (self .heap )) for _ in range (len (self .heap ))
118
163
][::- 1 ],
119
164
)
165
+ return self .final_results
0 commit comments