Skip to content

Commit

Permalink
Add Mongo projections to hook and transfer (apache#17379)
Browse files Browse the repository at this point in the history
  • Loading branch information
JavierLopezT authored Aug 12, 2021
1 parent 4598eb0 commit 9875757
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 18 deletions.
7 changes: 7 additions & 0 deletions airflow/providers/amazon/aws/transfers/mongo_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ class MongoToS3Operator(BaseOperator):
:type mongo_collection: str
:param mongo_query: query to execute. A dict for a find query, or a list of dicts for an aggregation pipeline
:type mongo_query: Union[list, dict]
:param mongo_projection: optional parameter to filter the fields returned by
the query. It can be a list of field names to include or a dictionary
for excluding fields (e.g. ``mongo_projection={"_id": 0}``)
:type mongo_projection: Union[list, dict]
:param s3_bucket: reference to a specific S3 bucket to store the data
:type s3_bucket: str
:param s3_key: in which S3 key the file will be stored
Expand Down Expand Up @@ -71,6 +75,7 @@ def __init__(
s3_bucket: str,
s3_key: str,
mongo_db: Optional[str] = None,
mongo_projection: Optional[Union[list, dict]] = None,
replace: bool = False,
allow_disk_use: bool = False,
compression: Optional[str] = None,
Expand All @@ -89,6 +94,7 @@ def __init__(
# Grab query and determine if we need to run an aggregate pipeline
self.mongo_query = mongo_query
self.is_pipeline = isinstance(self.mongo_query, list)
self.mongo_projection = mongo_projection

self.s3_bucket = s3_bucket
self.s3_key = s3_key
Expand All @@ -113,6 +119,7 @@ def execute(self, context) -> bool:
results = MongoHook(self.mongo_conn_id).find(
mongo_collection=self.mongo_collection,
query=cast(dict, self.mongo_query),
projection=self.mongo_projection,
mongo_db=self.mongo_db,
)

Expand Down
29 changes: 15 additions & 14 deletions airflow/providers/mongo/hooks/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""Hook for Mongo DB"""
from ssl import CERT_NONE
from types import TracebackType
from typing import List, Optional, Type
from typing import List, Optional, Type, Union

import pymongo
from pymongo import MongoClient, ReplaceOne
Expand Down Expand Up @@ -122,8 +122,8 @@ def aggregate(
) -> pymongo.command_cursor.CommandCursor:
"""
Runs an aggregation pipeline and returns the results
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.aggregate
https://api.mongodb.com/python/current/examples/aggregation.html
https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.aggregate
https://pymongo.readthedocs.io/en/stable/examples/aggregation.html
"""
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)

Expand All @@ -135,25 +135,26 @@ def find(
query: dict,
find_one: bool = False,
mongo_db: Optional[str] = None,
projection: Optional[Union[list, dict]] = None,
**kwargs,
) -> pymongo.cursor.Cursor:
"""
Runs a mongo find query and returns the results
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.find
https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.find
"""
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)

if find_one:
return collection.find_one(query, **kwargs)
return collection.find_one(query, projection, **kwargs)
else:
return collection.find(query, **kwargs)
return collection.find(query, projection, **kwargs)

def insert_one(
self, mongo_collection: str, doc: dict, mongo_db: Optional[str] = None, **kwargs
) -> pymongo.results.InsertOneResult:
"""
Inserts a single document into a mongo collection
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.insert_one
https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.insert_one
"""
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)

Expand All @@ -164,7 +165,7 @@ def insert_many(
) -> pymongo.results.InsertManyResult:
"""
Inserts many docs into a mongo collection.
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.insert_many
https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.insert_many
"""
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)

Expand All @@ -180,7 +181,7 @@ def update_one(
) -> pymongo.results.UpdateResult:
"""
Updates a single document in a mongo collection.
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.update_one
https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.update_one
:param mongo_collection: The name of the collection to update.
:type mongo_collection: str
Expand All @@ -207,7 +208,7 @@ def update_many(
) -> pymongo.results.UpdateResult:
"""
Updates one or more documents in a mongo collection.
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.update_many
https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.update_many
:param mongo_collection: The name of the collection to update.
:type mongo_collection: str
Expand All @@ -234,7 +235,7 @@ def replace_one(
) -> pymongo.results.UpdateResult:
"""
Replaces a single document in a mongo collection.
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.replace_one
https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.replace_one
.. note::
If no ``filter_doc`` is given, it is assumed that the replacement
Expand Down Expand Up @@ -272,7 +273,7 @@ def replace_many(
Replaces many documents in a mongo collection.
Uses bulk_write with multiple ReplaceOne operations
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.bulk_write
https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.bulk_write
.. note::
If no ``filter_docs`` are given, it is assumed that all
Expand Down Expand Up @@ -314,7 +315,7 @@ def delete_one(
) -> pymongo.results.DeleteResult:
"""
Deletes a single document in a mongo collection.
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.delete_one
https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.delete_one
:param mongo_collection: The name of the collection to delete from.
:type mongo_collection: str
Expand All @@ -334,7 +335,7 @@ def delete_many(
) -> pymongo.results.DeleteResult:
"""
Deletes one or more documents in a mongo collection.
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.delete_many
https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.delete_many
:param mongo_collection: The name of the collection to delete from.
:type mongo_collection: str
Expand Down
4 changes: 2 additions & 2 deletions tests/providers/amazon/aws/transfers/test_mongo_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_execute(self, mock_s3_hook, mock_mongo_hook):
operator.execute(None)

mock_mongo_hook.return_value.find.assert_called_once_with(
mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None
mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None, projection=None
)

op_stringify = self.mock_operator._stringify
Expand All @@ -117,7 +117,7 @@ def test_execute_compress(self, mock_s3_hook, mock_mongo_hook):
operator.execute(None)

mock_mongo_hook.return_value.find.assert_called_once_with(
mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None
mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None, projection=None
)

op_stringify = self.mock_operator._stringify
Expand Down
20 changes: 18 additions & 2 deletions tests/providers/mongo/hooks/test_mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,13 +251,29 @@ def test_find_one(self):
@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_find_many(self):
collection = mongomock.MongoClient().db.collection
objs = [{'test_find_many_1': 'test_value'}, {'test_find_many_2': 'test_value'}]
objs = [{'_id': 1, 'test_find_many_1': 'test_value'}, {'_id': 2, 'test_find_many_2': 'test_value'}]
collection.insert(objs)

result_objs = self.hook.find(collection, {}, find_one=False)
result_objs = self.hook.find(mongo_collection=collection, query={}, projection={}, find_one=False)

assert len(list(result_objs)) > 1

@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_find_many_with_projection(self):
collection = mongomock.MongoClient().db.collection
objs = [
{'_id': '1', 'test_find_many_1': 'test_value', 'field_3': 'a'},
{'_id': '2', 'test_find_many_2': 'test_value', 'field_3': 'b'},
]
collection.insert(objs)

projection = {'_id': 0}
result_objs = self.hook.find(
mongo_collection=collection, query={}, projection=projection, find_one=False
)

self.assertRaises(KeyError, lambda x: x[0]['_id'], result_objs)

@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_aggregate(self):
collection = mongomock.MongoClient().db.collection
Expand Down

0 comments on commit 9875757

Please sign in to comment.