diff --git a/docs/source/ExamplesDocDBRestApi.rst b/docs/source/ExamplesDocDBRestApi.rst index 9c8ae3f..38fa8d6 100644 --- a/docs/source/ExamplesDocDBRestApi.rst +++ b/docs/source/ExamplesDocDBRestApi.rst @@ -18,13 +18,12 @@ Count Example 1: Get # of records with a certain subject_id from aind_data_access_api.document_db import MetadataDbClient API_GATEWAY_HOST = "api.allenneuraldynamics.org" - DATABASE = "metadata_index" - COLLECTION = "data_assets" + # Default database and collection names are set in the client + # To override the defaults, provide the database and collection + # parameters in the constructor docdb_api_client = MetadataDbClient( host=API_GATEWAY_HOST, - database=DATABASE, - collection=COLLECTION, ) filter = {"subject.subject_id": "731015"} @@ -136,8 +135,6 @@ It's possible to attach a custom Session to retry certain requests errors: from aind_data_access_api.document_db import MetadataDbClient API_GATEWAY_HOST = "api.allenneuraldynamics.org" - DATABASE = "metadata_index" - COLLECTION = "data_assets" retry = Retry( total=5, @@ -151,8 +148,6 @@ It's possible to attach a custom Session to retry certain requests errors: with MetadataDbClient( host=API_GATEWAY_HOST, - database=DATABASE, - collection=COLLECTION, session=session, ) as docdb_api_client: records = docdb_api_client.retrieve_docdb_records(limit=10) diff --git a/docs/source/UserGuide.rst b/docs/source/UserGuide.rst index d03daf3..b4a11f7 100644 --- a/docs/source/UserGuide.rst +++ b/docs/source/UserGuide.rst @@ -54,13 +54,12 @@ REST API from aind_data_access_api.document_db import MetadataDbClient API_GATEWAY_HOST = "api.allenneuraldynamics.org" - DATABASE = "metadata_index" - COLLECTION = "data_assets" + # Default database and collection names are set in the client + # To override the defaults, provide the database and collection + # parameters in the constructor docdb_api_client = MetadataDbClient( host=API_GATEWAY_HOST, - database=DATABASE, - collection=COLLECTION, ) filter = {"subject.subject_id": "731015"} diff --git a/src/aind_data_access_api/document_db.py b/src/aind_data_access_api/document_db.py index d6b68ed..1771b16 100644 --- a/src/aind_data_access_api/document_db.py +++ b/src/aind_data_access_api/document_db.py @@ -362,6 +362,36 @@ def __exit__(self, exc_type, exc_val, exc_tb): class MetadataDbClient(Client): """Class to manage reading and writing to metadata db""" + def __init__( + self, + host: str, + database: str = "metadata_index", + collection: str = "data_assets", + version: str = "v1", + boto: Optional[BotoSession] = None, + session: Optional[Session] = None, + ): + """ + Instantiate a MetadataDbClient. + + Parameters + ---------- + host : str + database : str + collection : str + version : str + boto : Optional[BotoSession] + session : Optional[Session] + """ + super().__init__( + host=host, + database=database, + collection=collection, + version=version, + boto=boto, + session=session, + ) + def retrieve_docdb_records( self, filter_query: Optional[dict] = None, diff --git a/tests/test_document_db.py b/tests/test_document_db.py index d4be5a7..4f5dcae 100644 --- a/tests/test_document_db.py +++ b/tests/test_document_db.py @@ -437,8 +437,6 @@ class TestMetadataDbClient(unittest.TestCase): example_client_args = { "host": "example.com/", - "database": "metadata_db", - "collection": "data_assets", } example_record_list = [ @@ -452,6 +450,26 @@ class TestMetadataDbClient(unittest.TestCase): for id_num in range(0, 10) ] + def test_metadatadbclient_constructor(self): + """Tests class constructor""" + client = MetadataDbClient(**self.example_client_args) + + self.assertEqual("example.com", client.host) + self.assertEqual("metadata_index", client.database) + self.assertEqual("data_assets", client.collection) + self.assertEqual("v1", client.version) + self.assertEqual( + "https://example.com/v1/metadata_index/data_assets", + client._base_url, + ) + + client = MetadataDbClient(**self.example_client_args, version="v2") + self.assertEqual("v2", client.version) + self.assertEqual( + "https://example.com/v2/metadata_index/data_assets", + client._base_url, + ) + @patch("aind_data_access_api.document_db.Client._find_records") def test_retrieve_docdb_records( self,