@@ -50,6 +50,22 @@ def _base_url(self) -> str:
50
50
f"{ self .collection } "
51
51
)
52
52
53
+ @property
54
+ def _count_url (self ) -> str :
55
+ """Url to count records."""
56
+ return (
57
+ f"https://{ self .host } /{ self .version } /{ self .database } /"
58
+ f"{ self .collection } /count_documents"
59
+ )
60
+
61
+ @property
62
+ def _find_url (self ) -> str :
63
+ """Url to find records."""
64
+ return (
65
+ f"https://{ self .host } /{ self .version } /{ self .database } /"
66
+ f"{ self .collection } /find"
67
+ )
68
+
53
69
@property
54
70
def _aggregate_url (self ) -> str :
55
71
"""Url to aggregate records."""
@@ -150,12 +166,56 @@ def _count_records(self, filter_query: Optional[dict] = None):
150
166
Has keys {"total_record_count": int, "filtered_record_count": int}
151
167
152
168
"""
153
- params = {
154
- "count_records" : str (True ),
155
- }
169
+ params = (
170
+ {"filter" : json .dumps (filter_query )}
171
+ if filter_query is not None
172
+ else None
173
+ )
174
+ response = self .session .get (self ._count_url , params = params )
175
+ response .raise_for_status ()
176
+ response_body = response .json ()
177
+ return response_body
178
+
179
+ def _find_records (
180
+ self ,
181
+ filter_query : Optional [dict ] = None ,
182
+ projection : Optional [dict ] = None ,
183
+ sort : Optional [dict ] = None ,
184
+ limit : int = 0 ,
185
+ skip : int = 0 ,
186
+ ) -> List [dict ]:
187
+ """
188
+ Retrieve records from collection. May return a smaller set of records
189
+ if requested records exceed the max payload size of the API Gateway.
190
+
191
+ Parameters
192
+ ----------
193
+ filter_query : Optional[dict]
194
+ Filter to apply to the records being returned. Default is None.
195
+ projection : Optional[dict]
196
+ Subset of document fields to return. Default is None.
197
+ sort : Optional[dict]
198
+ Sort records when returned. Default is None.
199
+ limit : int
200
+ Return a smaller set of records. 0 for all records. Default is 0.
201
+ skip : int
202
+ Skip this amount of records in index when applying search.
203
+
204
+ Returns
205
+ -------
206
+ List[dict]
207
+ The list of records returned from the DocumentDB.
208
+
209
+ """
210
+ params = {"limit" : str (limit ), "skip" : str (skip )}
156
211
if filter_query is not None :
157
212
params ["filter" ] = json .dumps (filter_query )
158
- response = self .session .get (self ._base_url , params = params )
213
+ if projection is not None :
214
+ params ["projection" ] = json .dumps (projection )
215
+ if sort is not None :
216
+ params ["sort" ] = json .dumps (sort )
217
+
218
+ response = self .session .get (self ._find_url , params = params )
159
219
response .raise_for_status ()
160
220
response_body = response .json ()
161
221
return response_body
@@ -169,7 +229,9 @@ def _get_records(
169
229
skip : int = 0 ,
170
230
) -> List [dict ]:
171
231
"""
172
- Retrieve records from collection.
232
+ Retrieve records from collection. May raise HTTP 413 error if
233
+ requested records exceed the max payload size of the API Gateway.
234
+
173
235
Parameters
174
236
----------
175
237
filter_query : Optional[dict]
@@ -306,12 +368,13 @@ def retrieve_docdb_records(
306
368
projection : Optional [dict ] = None ,
307
369
sort : Optional [dict ] = None ,
308
370
limit : int = 0 ,
309
- paginate : bool = True ,
310
- paginate_batch_size : int = 500 ,
311
- paginate_max_iterations : int = 20000 ,
371
+ paginate : Optional [ bool ] = None ,
372
+ paginate_batch_size : Optional [ int ] = None ,
373
+ paginate_max_iterations : Optional [ int ] = None ,
312
374
) -> List [dict ]:
313
375
"""
314
376
Retrieve raw json records from DocDB API Gateway as a list of dicts.
377
+ Queries to the API Gateway are paginated.
315
378
316
379
Parameters
317
380
----------
@@ -324,72 +387,40 @@ def retrieve_docdb_records(
324
387
limit : int
325
388
Return a smaller set of records. 0 for all records. Default is 0.
326
389
paginate : bool
327
- If set to true, will batch the queries to the API Gateway. It may
328
- be faster to set to false if the number of records expected to be
329
- returned is small.
390
+ (deprecated) If set to true, will batch the queries to the API
391
+ Gateway.
330
392
paginate_batch_size : int
331
- Number of records to return at a time. Default is 500.
393
+ (deprecated) Number of records to return at a time. Default is 500.
332
394
paginate_max_iterations : int
333
- Max number of iterations to run to prevent indefinite calls to the
334
- API Gateway. Default is 20000.
395
+ (deprecated) Max number of iterations to run to prevent indefinite
396
+ calls to the API Gateway. Default is 20000.
335
397
336
398
Returns
337
399
-------
338
400
List[dict]
339
401
340
402
"""
341
- if paginate is False :
342
- records = self ._get_records (
403
+ get_all_records = True if limit == 0 else False
404
+ records = []
405
+ skip = 0
406
+ while get_all_records or limit > 0 :
407
+ batched_records = self ._find_records (
343
408
filter_query = filter_query ,
344
409
projection = projection ,
345
410
sort = sort ,
346
411
limit = limit ,
412
+ skip = skip ,
347
413
)
348
- else :
349
- # Get record count
350
- record_counts = self ._count_records (filter_query )
351
- total_record_count = record_counts ["total_record_count" ]
352
- filtered_record_count = record_counts ["filtered_record_count" ]
353
- if filtered_record_count <= paginate_batch_size :
354
- records = self ._get_records (
355
- filter_query = filter_query ,
356
- projection = projection ,
357
- sort = sort ,
358
- limit = limit ,
359
- )
360
- else :
361
- records = []
362
- errors = []
363
- num_of_records_collected = 0
364
- limit = filtered_record_count if limit == 0 else limit
365
- skip = 0
366
- iter_count = 0
367
- while (
368
- skip < total_record_count
369
- and num_of_records_collected
370
- < min (filtered_record_count , limit )
371
- and iter_count < paginate_max_iterations
372
- ):
373
- try :
374
- batched_records = self ._get_records (
375
- filter_query = filter_query ,
376
- projection = projection ,
377
- sort = sort ,
378
- limit = paginate_batch_size ,
379
- skip = skip ,
380
- )
381
- num_of_records_collected += len (batched_records )
382
- records .extend (batched_records )
383
- except Exception as e :
384
- errors .append (repr (e ))
385
- skip = skip + paginate_batch_size
386
- iter_count += 1
387
- # TODO: Add optional progress bar?
388
- records = records [0 :limit ]
389
- if len (errors ) > 0 :
390
- logging .error (
391
- f"There were errors retrieving records. { errors } "
392
- )
414
+ batch_size = len (batched_records )
415
+ logging .debug (
416
+ f"(skip={ skip } , limit={ limit } ): Retrieved { batch_size } records"
417
+ )
418
+ if batch_size == 0 :
419
+ break
420
+ records .extend (batched_records )
421
+ skip += batch_size
422
+ if not get_all_records :
423
+ limit -= batch_size
393
424
return records
394
425
395
426
def aggregate_docdb_records (self , pipeline : List [dict ]) -> List [dict ]:
0 commit comments