Skip to content

Commit d58a0f2

Browse files
nailo2c and Lee-W authored
feat: add create_collection function to MongoHook (#50518)
* feat: add create_collection function with unit tests * rm comments * fix static checks * make the type of `create_kwargs` clearer Co-authored-by: Wei Lee <[email protected]> * modify variable name: create_if_exists -> return_if_exists * move import CollectionInvalid to top --------- Co-authored-by: Wei Lee <[email protected]>
1 parent f1ca1d1 commit d58a0f2

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

providers/mongo/src/airflow/providers/mongo/hooks/mongo.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import pymongo
2727
from pymongo import MongoClient, ReplaceOne
28+
from pymongo.errors import CollectionInvalid
2829

2930
from airflow.exceptions import AirflowConfigException
3031
from airflow.hooks.base import BaseHook
@@ -225,6 +226,37 @@ def get_collection(self, mongo_collection: str, mongo_db: str | None = None) ->
225226

226227
return mongo_conn.get_database(mongo_db).get_collection(mongo_collection)
227228

229+
def create_collection(
    self,
    mongo_collection: str,
    mongo_db: str | None = None,
    return_if_exists: bool = True,
    **create_kwargs: Any,
) -> MongoCollection:
    """
    Create a MongoDB collection and return a handle to it.

    Extra keyword arguments are passed straight through to pymongo, so this can
    also be used for specialised collections (for example time-series or capped
    ones).

    https://pymongo.readthedocs.io/en/stable/api/pymongo/database.html#pymongo.database.Database.create_collection

    :param mongo_collection: Name of the collection to create.
    :param mongo_db: Database to create the collection in; when not given, the
        schema of the hook's connection is used.
    :param return_if_exists: When True, a ``CollectionInvalid`` error raised by the
        create call is swallowed and the collection is fetched and returned;
        when False, the error is re-raised.
    :param create_kwargs: Keyword arguments forwarded unchanged to
        ``db.create_collection()``, e.g. ``timeseries={...}``, ``capped=True``.
    :return: The collection object obtained from the target database.
    """
    # Resolve the database name before opening the client, mirroring the
    # fallback to the connection's schema used when no database is named.
    target_db_name = mongo_db or self.connection.schema
    client: MongoClient = self.get_conn()
    database = client.get_database(target_db_name)

    try:
        database.create_collection(mongo_collection, **create_kwargs)
    except CollectionInvalid:
        # Treated as "the collection is already there": only propagate the
        # error when the caller asked for strict creation.
        if not return_if_exists:
            raise

    # Reached both after a successful create and after a tolerated failure.
    return database.get_collection(mongo_collection)
259+
228260
def aggregate(
229261
self, mongo_collection: str, aggregate_query: list, mongo_db: str | None = None, **kwargs
230262
) -> CommandCursor:

providers/mongo/tests/unit/mongo/hooks/test_mongo.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import pymongo
2424
import pytest
25+
from pymongo.errors import CollectionInvalid
2526

2627
from airflow.exceptions import AirflowConfigException
2728
from airflow.models import Connection
@@ -387,6 +388,36 @@ def test_distinct_with_filter(self):
387388
results = self.hook.distinct(collection, "test_id", {"test_status": "failure"})
388389
assert len(results) == 1
389390

391+
def test_create_standard_collection(self):
    """A plain collection is created in the connection's schema database and returned."""
    # Route the hook to an in-memory mongomock client instead of a real server.
    client = mongomock.MongoClient()
    self.hook.get_conn = lambda: client
    self.hook.connection.schema = "test_db"

    created = self.hook.create_collection(mongo_collection="plain_collection")

    assert created.name == "plain_collection"
    # No mongo_db was passed, so the collection must live in the schema database.
    assert "plain_collection" in client["test_db"].list_collection_names()
399+
400+
def test_return_if_exists_true_returns_existing(self):
    """Re-creating an existing collection with ``return_if_exists=True`` returns it instead of raising."""
    # Route the hook to an in-memory mongomock client instead of a real server.
    client = mongomock.MongoClient()
    self.hook.get_conn = lambda: client
    self.hook.connection.schema = "test_db"

    initial = self.hook.create_collection(mongo_collection="foo")
    repeated = self.hook.create_collection(mongo_collection="foo", return_if_exists=True)

    # Both calls must resolve to the same namespace (database + collection name).
    assert initial.full_name == repeated.full_name
    assert "foo" in client["test_db"].list_collection_names()
410+
411+
def test_return_if_exists_false_raises(self):
    """Re-creating an existing collection with ``return_if_exists=False`` propagates ``CollectionInvalid``."""
    # Route the hook to an in-memory mongomock client instead of a real server.
    client = mongomock.MongoClient()
    self.hook.get_conn = lambda: client
    self.hook.connection.schema = "test_db"

    # First creation succeeds and sets up the pre-existing collection.
    self.hook.create_collection(mongo_collection="bar")

    with pytest.raises(CollectionInvalid):
        self.hook.create_collection(mongo_collection="bar", return_if_exists=False)
420+
390421

391422
def test_context_manager():
392423
with MongoHook(mongo_conn_id="mongo_default") as ctx_hook:

0 commit comments

Comments
 (0)