Skip to content

Commit b68a236

Browse files
committed
chore: Add tests
Also return data from an insert with a returning statement.
1 parent bc422a7 commit b68a236

File tree

6 files changed

+516
-119
lines changed

6 files changed

+516
-119
lines changed

.vscode/settings.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"python.testing.pytestArgs": ["tests"],
3+
"python.testing.unittestEnabled": false,
4+
"python.testing.pytestEnabled": true
5+
}

src/mock_alchemy/comparison.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from typing import Optional
1111
from unittest import mock
1212

13-
import sqlalchemy
1413
from packaging import version
14+
from sqlalchemy import __version__ as sqlalchemy_version
1515
from sqlalchemy import delete
1616
from sqlalchemy import func
1717
from sqlalchemy import insert
@@ -35,7 +35,7 @@
3535
ALCHEMY_FUNC_TYPE,
3636
ALCHEMY_LABEL_TYPE,
3737
)
38-
if version.parse(sqlalchemy.__version__) >= version.parse("1.4.0"):
38+
if version.parse(sqlalchemy_version) >= version.parse("1.4.0"):
3939
ALCHEMY_SELECT_TYPE = type(select(column("")))
4040
ALCHEMY_UPDATE_TYPE = type(update(table("")))
4141
ALCHEMY_DELETE_TYPE = type(delete(table("")))

src/mock_alchemy/mocking.py

Lines changed: 162 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
from typing import overload
1818
from unittest import mock
1919

20+
from packaging import version
21+
from sqlalchemy import __version__ as sqlalchemy_version
2022
from sqlalchemy import select
21-
from sqlalchemy.exc import ArgumentError
2223
from sqlalchemy.orm.exc import MultipleResultsFound
2324
from sqlalchemy.orm.exc import NoResultFound
2425
from sqlalchemy.sql.dml import Delete
@@ -459,7 +460,6 @@ class UnifiedAlchemyMagicMock(AlchemyMagicMock):
459460
unify: Dict[str, Optional[UnorderedCall]] = {
460461
"add_columns": None,
461462
"distinct": None,
462-
"execute": None,
463463
"filter": UnorderedCall,
464464
"filter_by": UnorderedCall,
465465
"group_by": None,
@@ -473,9 +473,9 @@ class UnifiedAlchemyMagicMock(AlchemyMagicMock):
473473
"where": None,
474474
}
475475

476-
mutate: Set[str] = {"add", "add_all", "delete", "execute"}
476+
mutate: Set[str] = {"add", "add_all", "delete"}
477477

478-
mutate_and_unify: Set[str] = {"execute"}
478+
execute_statement: Set[str] = {"execute"}
479479

480480
@overload
481481
def __init__(
@@ -498,7 +498,12 @@ def __init__(self, *args, **kwargs) -> None:
498498
"""Creates an UnifiedAlchemyMagicMock to mock a SQLAlchemy session."""
499499
kwargs["_mock_default"] = kwargs.pop("default", [])
500500
kwargs["_mock_data"] = kwargs.pop("data", None)
501-
kwargs.update({k: AlchemyMagicMock(side_effect=partial(self._get_data, _mock_name=k)) for k in self.boundary})
501+
kwargs.update(
502+
{
503+
k: AlchemyMagicMock(side_effect=partial(self._get_data, _mock_name=k))
504+
for k in self.boundary
505+
}
506+
)
502507

503508
kwargs.update(
504509
{
@@ -524,9 +529,9 @@ def __init__(self, *args, **kwargs) -> None:
524529
{
525530
k: AlchemyMagicMock(
526531
return_value=self,
527-
side_effect=partial(self._mutate_data, _mock_name=k),
532+
side_effect=partial(self._execute_statement, _mock_name=k),
528533
)
529-
for k in self.mutate_and_unify
534+
for k in self.execute_statement
530535
}
531536
)
532537

@@ -604,13 +609,25 @@ def _get_data(self, *args: Any, **kwargs: Any) -> Any:
604609
_mock_data = self._mock_data
605610
if _mock_data is not None:
606611
previous_calls = [
607-
sqlalchemy_call(i, with_name=True, base_call=self.unify.get(i[0]) or Call)
612+
sqlalchemy_call(
613+
i, with_name=True, base_call=self.unify.get(i[0]) or Call
614+
)
608615
for i in self._get_previous_calls(self.mock_calls[:-1])
609616
]
610617
sorted_mock_data = sorted(_mock_data, key=lambda x: len(x[0]), reverse=True)
611618
if _mock_name == "get":
612-
query_call = [c for c in previous_calls if c[0] in ["query", "execute"]][0]
613-
results = list(chain(*[result for calls, result in sorted_mock_data if query_call in calls]))
619+
query_call = [
620+
c for c in previous_calls if c[0] in ["query", "execute"]
621+
][0]
622+
results = list(
623+
chain(
624+
*[
625+
result
626+
for calls, result in sorted_mock_data
627+
if query_call in calls
628+
]
629+
)
630+
)
614631
return self.boundary[_mock_name](results, *args, **kwargs)
615632

616633
else:
@@ -636,13 +653,15 @@ def _mutate_data(self, *args: Any, **kwargs: Any) -> Optional[int]:
636653
to_add = args[0]
637654
query_call = mock.call.query(type(to_add))
638655

639-
mocked_data = next(iter(filter(lambda i: i[0] == [query_call], _mock_data)), None)
656+
mocked_data = next(
657+
iter(filter(lambda i: i[0] == [query_call], _mock_data)), None
658+
)
640659
if mocked_data:
641660
mocked_data[1].append(to_add)
642661
else:
643662
_mock_data.append(([query_call], [to_add]))
644663

645-
try:
664+
if version.parse(sqlalchemy_version) >= version.parse("1.4.0"):
646665
execute_call = mock.call.execute(select(type(to_add)))
647666

648667
execute_mocked_data = next(
@@ -658,29 +677,25 @@ def _mutate_data(self, *args: Any, **kwargs: Any) -> Optional[int]:
658677
execute_mocked_data[1].append(to_add)
659678
else:
660679
_mock_data.append(([execute_call], [to_add]))
661-
except (TypeError, ArgumentError):
662-
# TypeError indicates an old version of sqlalcemy that does not support
663-
# executing select(table) statements.
664-
# ArgumentError indicates a mocked table that cannot be selected, so
665-
# cannot be mocked this way.
666-
pass
667-
668680
elif _mock_name == "add_all":
669681
to_add = args[0]
670682
_kwargs = kwargs.copy()
671683
_kwargs["_mock_name"] = "add"
672684

673685
for i in to_add:
674686
self._mutate_data(i, *args[1:], **_kwargs)
675-
elif _mock_name == "delete":
687+
# delete case
688+
else:
676689
_kwargs = kwargs.copy()
677690
# pretend like all is being called to get data
678691
_kwargs["_mock_name"] = "all"
679692
_mock_name = _kwargs.pop("_mock_name")
680693
_mock_data = self._mock_data
681694
num_deleted = 0
682695
previous_calls = [
683-
sqlalchemy_call(i, with_name=True, base_call=self.unify.get(i[0]) or Call)
696+
sqlalchemy_call(
697+
i, with_name=True, base_call=self.unify.get(i[0]) or Call
698+
)
684699
for i in self._get_previous_calls(self.mock_calls[:-1])
685700
]
686701
sorted_mock_data = sorted(_mock_data, key=lambda x: len(x[0]), reverse=True)
@@ -703,100 +718,134 @@ def _mutate_data(self, *args: Any, **kwargs: Any) -> Optional[int]:
703718
temp_mock_data.append((calls, result))
704719
self._mock_data = temp_mock_data
705720
return num_deleted
706-
# execute case
721+
722+
def _execute_insert(
723+
self, execute_statement: Insert, *args: Any, **kwargs: Any
724+
) -> Any:
725+
"""Insert data from execute statement."""
726+
_kwargs = kwargs.copy()
727+
execute_statement = args[0]
728+
_kwargs["_mock_name"] = "add"
729+
table_type = execute_statement.entity_description["type"]
730+
# Values should either be a list of dictionaries as arg[1] or a list of
731+
# dictionaries as values.
732+
if len(args) > 1:
733+
for i in args[1]:
734+
self._mutate_data(table_type(**i), **_kwargs)
707735
else:
708-
_kwargs = kwargs.copy()
709-
# Need to check if the execute was an insert, update or delete. Ignore any other types
710-
execute_statement = args[0]
711-
712-
if isinstance(execute_statement, Insert):
713-
# Add insert data
714-
_kwargs["_mock_name"] = "add"
715-
table_type = execute_statement.entity_description["type"]
716-
# Values should either be a list of dictionaries ar arg[1] or a list of dictionaries as values.
717-
if len(args) > 1:
718-
for i in args[1]:
719-
self._mutate_data(table_type(**i), **_kwargs)
720-
else:
721-
# Values will be stored within _multi_values list
722-
values = execute_statement._multi_values[0]
723-
for i in values:
724-
self._mutate_data(table_type(**{k.name: v for k, v in i.items()}), **_kwargs)
725-
# Only unify if the insert statement is returning
726-
if execute_statement._returning:
727-
return self._unify(self, *args, **kwargs)
728-
else:
729-
# insert a boundary so that this is no longer part of a unified call.
730-
self.all()
731-
elif isinstance(execute_statement, Delete):
732-
# Create equivalent select statement as an Expression Matcher
733-
select_statement = [
734-
ExpressionMatcher(
735-
mock.call.execute(select(execute_statement.table).where(execute_statement.whereclause))
736-
)
737-
]
738-
_mock_data = self._mock_data
739-
sorted_mock_data = sorted(_mock_data, key=lambda x: len(x[0]), reverse=True)
740-
temp_mock_data = list()
741-
found_query = False
742-
num_deleted = 0
743-
for calls, result in sorted_mock_data:
744-
calls = [
745-
sqlalchemy_call(
746-
i,
747-
with_name=True,
748-
base_call=self.unify.get(i[0]) or Call,
736+
# Values will be stored within _multi_values list
737+
values = execute_statement._multi_values[0]
738+
for i in values:
739+
self._mutate_data(
740+
table_type(**{k.name: v for k, v in i.items()}), **_kwargs
741+
)
742+
# insert a boundary so that this is no longer part of a unified call.
743+
self.all()
744+
# Start a new unify if the insert statement is returning
745+
if execute_statement._returning:
746+
return self.execute(select(execute_statement._returning[0]))
747+
return None
748+
749+
def _execute_delete(self, execute_statement: Delete, *args: Any) -> mock.Mock:
750+
"""Delete data according to execute statement."""
751+
execute_statement = args[0]
752+
# Create equivalent select statement as an Expression Matcher
753+
select_statement = (
754+
[
755+
ExpressionMatcher(
756+
mock.call.execute(
757+
select(execute_statement.table).where(
758+
execute_statement.whereclause
749759
)
750-
for i in calls
751-
]
752-
if all(c in select_statement for c in calls) and not found_query:
753-
num_deleted = len(result)
754-
temp_mock_data.append((calls, []))
755-
found_query = True
756-
else:
757-
temp_mock_data.append((calls, result))
758-
self._mock_data = temp_mock_data
759-
delete_result = mock.Mock()
760-
delete_result.rowcount = num_deleted
761-
# insert a boundary so that this is no longer part of a unified call.
762-
self.all()
763-
return delete_result
764-
elif isinstance(execute_statement, Update):
765-
# Create equivalent select statement as an Expression Matcher
766-
select_statement = [
767-
ExpressionMatcher(
768-
mock.call.execute(select(execute_statement.table).where(execute_statement.whereclause))
769760
)
770-
]
771-
_mock_data = self._mock_data
772-
sorted_mock_data = sorted(_mock_data, key=lambda x: len(x[0]), reverse=True)
773-
temp_mock_data = list()
774-
found_query = False
775-
num_updated = 0
776-
for calls, result in sorted_mock_data:
777-
calls = [
778-
sqlalchemy_call(
779-
i,
780-
with_name=True,
781-
base_call=self.unify.get(i[0]) or Call,
761+
)
762+
]
763+
if execute_statement.whereclause is not None
764+
else [ExpressionMatcher(mock.call.execute(select(execute_statement.table)))]
765+
)
766+
_mock_data = self._mock_data = self._mock_data or []
767+
sorted_mock_data = sorted(_mock_data, key=lambda x: len(x[0]), reverse=True)
768+
temp_mock_data = list()
769+
found_query = False
770+
num_deleted = 0
771+
for calls, result in sorted_mock_data:
772+
calls = [
773+
sqlalchemy_call(
774+
i,
775+
with_name=True,
776+
base_call=self.unify.get(i[0]) or Call,
777+
)
778+
for i in calls
779+
]
780+
if all(c in select_statement for c in calls) and not found_query:
781+
num_deleted = len(result)
782+
temp_mock_data.append((calls, []))
783+
found_query = True
784+
else:
785+
temp_mock_data.append((calls, result))
786+
self._mock_data = temp_mock_data
787+
delete_result = mock.Mock()
788+
delete_result.rowcount = num_deleted
789+
# insert a boundary so that this is no longer part of a unified call.
790+
self.all()
791+
return delete_result
792+
793+
def _execute_update(self, execute_statement: Update) -> mock.Mock:
794+
"""Update data according to execute statement."""
795+
# Create equivalent select statement as an Expression Matcher
796+
select_statement = (
797+
[
798+
ExpressionMatcher(
799+
mock.call.execute(
800+
select(execute_statement.table).where(
801+
execute_statement.whereclause
782802
)
783-
for i in calls
784-
]
785-
if all(c in select_statement for c in calls) and not found_query:
786-
num_updated = len(result)
787-
for r in result:
788-
for k, v in execute_statement._values.items():
789-
setattr(r, k.name, v.value)
790-
temp_mock_data.append((calls, result))
791-
found_query = True
792-
else:
793-
temp_mock_data.append((calls, result))
794-
self._mock_data = temp_mock_data
795-
update_result = mock.Mock()
796-
update_result.rowcount = num_updated
797-
# insert a boundary so that this is no longer part of a unified call.
798-
self.all()
799-
return update_result
803+
)
804+
)
805+
]
806+
if execute_statement.whereclause is not None
807+
else [ExpressionMatcher(mock.call.execute(select(execute_statement.table)))]
808+
)
809+
_mock_data = self._mock_data = self._mock_data or []
810+
sorted_mock_data = sorted(_mock_data, key=lambda x: len(x[0]), reverse=True)
811+
temp_mock_data = list()
812+
found_query = False
813+
num_updated = 0
814+
for calls, result in sorted_mock_data:
815+
calls = [
816+
sqlalchemy_call(
817+
i,
818+
with_name=True,
819+
base_call=self.unify.get(i[0]) or Call,
820+
)
821+
for i in calls
822+
]
823+
if all(c in select_statement for c in calls) and not found_query:
824+
num_updated = len(result)
825+
for r in result:
826+
for k, v in execute_statement._values.items():
827+
setattr(r, k.name, v.value)
828+
temp_mock_data.append((calls, result))
829+
found_query = True
800830
else:
801-
# assume any other execute types need to unify
802-
return self._unify(self, *args, **kwargs)
831+
temp_mock_data.append((calls, result))
832+
self._mock_data = temp_mock_data
833+
update_result = mock.Mock()
834+
update_result.rowcount = num_updated
835+
# insert a boundary so that this is no longer part of a unified call.
836+
self.all()
837+
return update_result
838+
839+
def _execute_statement(self, *args: Any, **kwargs: Any) -> Any:
840+
"""Depending on statement being executed, update data and/or unify statement."""
841+
# Need to check if the execute was an insert, update or delete.
842+
execute_statement = args[0]
843+
if isinstance(execute_statement, Insert):
844+
return self._execute_insert(execute_statement, *args, **kwargs)
845+
elif isinstance(execute_statement, Delete):
846+
return self._execute_delete(execute_statement, *args)
847+
elif isinstance(execute_statement, Update):
848+
return self._execute_update(execute_statement)
849+
else:
850+
# assume any other execute types need to unify
851+
return self._unify(self, *args, **kwargs)

0 commit comments

Comments
 (0)