1717from typing import overload
1818from unittest import mock
1919
20+ from packaging import version
21+ from sqlalchemy import __version__ as sqlalchemy_version
2022from sqlalchemy import select
21- from sqlalchemy .exc import ArgumentError
2223from sqlalchemy .orm .exc import MultipleResultsFound
2324from sqlalchemy .orm .exc import NoResultFound
2425from sqlalchemy .sql .dml import Delete
@@ -459,7 +460,6 @@ class UnifiedAlchemyMagicMock(AlchemyMagicMock):
459460 unify : Dict [str , Optional [UnorderedCall ]] = {
460461 "add_columns" : None ,
461462 "distinct" : None ,
462- "execute" : None ,
463463 "filter" : UnorderedCall ,
464464 "filter_by" : UnorderedCall ,
465465 "group_by" : None ,
@@ -473,9 +473,9 @@ class UnifiedAlchemyMagicMock(AlchemyMagicMock):
473473 "where" : None ,
474474 }
475475
476- mutate : Set [str ] = {"add" , "add_all" , "delete" , "execute" }
476+ mutate : Set [str ] = {"add" , "add_all" , "delete" }
477477
478- mutate_and_unify : Set [str ] = {"execute" }
478+ execute_statement : Set [str ] = {"execute" }
479479
480480 @overload
481481 def __init__ (
@@ -498,7 +498,12 @@ def __init__(self, *args, **kwargs) -> None:
498498 """Creates an UnifiedAlchemyMagicMock to mock a SQLAlchemy session."""
499499 kwargs ["_mock_default" ] = kwargs .pop ("default" , [])
500500 kwargs ["_mock_data" ] = kwargs .pop ("data" , None )
501- kwargs .update ({k : AlchemyMagicMock (side_effect = partial (self ._get_data , _mock_name = k )) for k in self .boundary })
501+ kwargs .update (
502+ {
503+ k : AlchemyMagicMock (side_effect = partial (self ._get_data , _mock_name = k ))
504+ for k in self .boundary
505+ }
506+ )
502507
503508 kwargs .update (
504509 {
@@ -524,9 +529,9 @@ def __init__(self, *args, **kwargs) -> None:
524529 {
525530 k : AlchemyMagicMock (
526531 return_value = self ,
527- side_effect = partial (self ._mutate_data , _mock_name = k ),
532+ side_effect = partial (self ._execute_statement , _mock_name = k ),
528533 )
529- for k in self .mutate_and_unify
534+ for k in self .execute_statement
530535 }
531536 )
532537
@@ -604,13 +609,25 @@ def _get_data(self, *args: Any, **kwargs: Any) -> Any:
604609 _mock_data = self ._mock_data
605610 if _mock_data is not None :
606611 previous_calls = [
607- sqlalchemy_call (i , with_name = True , base_call = self .unify .get (i [0 ]) or Call )
612+ sqlalchemy_call (
613+ i , with_name = True , base_call = self .unify .get (i [0 ]) or Call
614+ )
608615 for i in self ._get_previous_calls (self .mock_calls [:- 1 ])
609616 ]
610617 sorted_mock_data = sorted (_mock_data , key = lambda x : len (x [0 ]), reverse = True )
611618 if _mock_name == "get" :
612- query_call = [c for c in previous_calls if c [0 ] in ["query" , "execute" ]][0 ]
613- results = list (chain (* [result for calls , result in sorted_mock_data if query_call in calls ]))
619+ query_call = [
620+ c for c in previous_calls if c [0 ] in ["query" , "execute" ]
621+ ][0 ]
622+ results = list (
623+ chain (
624+ * [
625+ result
626+ for calls , result in sorted_mock_data
627+ if query_call in calls
628+ ]
629+ )
630+ )
614631 return self .boundary [_mock_name ](results , * args , ** kwargs )
615632
616633 else :
@@ -636,13 +653,15 @@ def _mutate_data(self, *args: Any, **kwargs: Any) -> Optional[int]:
636653 to_add = args [0 ]
637654 query_call = mock .call .query (type (to_add ))
638655
639- mocked_data = next (iter (filter (lambda i : i [0 ] == [query_call ], _mock_data )), None )
656+ mocked_data = next (
657+ iter (filter (lambda i : i [0 ] == [query_call ], _mock_data )), None
658+ )
640659 if mocked_data :
641660 mocked_data [1 ].append (to_add )
642661 else :
643662 _mock_data .append (([query_call ], [to_add ]))
644663
645- try :
664+ if version . parse ( sqlalchemy_version ) >= version . parse ( "1.4.0" ) :
646665 execute_call = mock .call .execute (select (type (to_add )))
647666
648667 execute_mocked_data = next (
@@ -658,29 +677,25 @@ def _mutate_data(self, *args: Any, **kwargs: Any) -> Optional[int]:
658677 execute_mocked_data [1 ].append (to_add )
659678 else :
660679 _mock_data .append (([execute_call ], [to_add ]))
661- except (TypeError , ArgumentError ):
662- # TypeError indicates an old version of sqlalcemy that does not support
663- # executing select(table) statements.
664- # ArgumentError indicates a mocked table that cannot be selected, so
665- # cannot be mocked this way.
666- pass
667-
668680 elif _mock_name == "add_all" :
669681 to_add = args [0 ]
670682 _kwargs = kwargs .copy ()
671683 _kwargs ["_mock_name" ] = "add"
672684
673685 for i in to_add :
674686 self ._mutate_data (i , * args [1 :], ** _kwargs )
675- elif _mock_name == "delete" :
687+ # delete case
688+ else :
676689 _kwargs = kwargs .copy ()
677690 # pretend like all is being called to get data
678691 _kwargs ["_mock_name" ] = "all"
679692 _mock_name = _kwargs .pop ("_mock_name" )
680693 _mock_data = self ._mock_data
681694 num_deleted = 0
682695 previous_calls = [
683- sqlalchemy_call (i , with_name = True , base_call = self .unify .get (i [0 ]) or Call )
696+ sqlalchemy_call (
697+ i , with_name = True , base_call = self .unify .get (i [0 ]) or Call
698+ )
684699 for i in self ._get_previous_calls (self .mock_calls [:- 1 ])
685700 ]
686701 sorted_mock_data = sorted (_mock_data , key = lambda x : len (x [0 ]), reverse = True )
@@ -703,100 +718,134 @@ def _mutate_data(self, *args: Any, **kwargs: Any) -> Optional[int]:
703718 temp_mock_data .append ((calls , result ))
704719 self ._mock_data = temp_mock_data
705720 return num_deleted
706- # execute case
721+
722+ def _execute_insert (
723+ self , execute_statement : Insert , * args : Any , ** kwargs : Any
724+ ) -> Any :
725+ """Insert data from execute statement."""
726+ _kwargs = kwargs .copy ()
727+ execute_statement = args [0 ]
728+ _kwargs ["_mock_name" ] = "add"
729+ table_type = execute_statement .entity_description ["type" ]
730+ # Values should either be a list of dictionaries as arg[1] or a list of
731+ # dictionaries as values.
732+ if len (args ) > 1 :
733+ for i in args [1 ]:
734+ self ._mutate_data (table_type (** i ), ** _kwargs )
707735 else :
708- _kwargs = kwargs .copy ()
709- # Need to check if the execute was an insert, update or delete. Ignore any other types
710- execute_statement = args [0 ]
711-
712- if isinstance (execute_statement , Insert ):
713- # Add insert data
714- _kwargs ["_mock_name" ] = "add"
715- table_type = execute_statement .entity_description ["type" ]
716- # Values should either be a list of dictionaries ar arg[1] or a list of dictionaries as values.
717- if len (args ) > 1 :
718- for i in args [1 ]:
719- self ._mutate_data (table_type (** i ), ** _kwargs )
720- else :
721- # Values will be stored within _multi_values list
722- values = execute_statement ._multi_values [0 ]
723- for i in values :
724- self ._mutate_data (table_type (** {k .name : v for k , v in i .items ()}), ** _kwargs )
725- # Only unify if the insert statement is returning
726- if execute_statement ._returning :
727- return self ._unify (self , * args , ** kwargs )
728- else :
729- # insert a boundary so that this is no longer part of a unified call.
730- self .all ()
731- elif isinstance (execute_statement , Delete ):
732- # Create equivalent select statement as an Expression Matcher
733- select_statement = [
734- ExpressionMatcher (
735- mock .call .execute (select (execute_statement .table ).where (execute_statement .whereclause ))
736- )
737- ]
738- _mock_data = self ._mock_data
739- sorted_mock_data = sorted (_mock_data , key = lambda x : len (x [0 ]), reverse = True )
740- temp_mock_data = list ()
741- found_query = False
742- num_deleted = 0
743- for calls , result in sorted_mock_data :
744- calls = [
745- sqlalchemy_call (
746- i ,
747- with_name = True ,
748- base_call = self .unify .get (i [0 ]) or Call ,
736+ # Values will be stored within _multi_values list
737+ values = execute_statement ._multi_values [0 ]
738+ for i in values :
739+ self ._mutate_data (
740+ table_type (** {k .name : v for k , v in i .items ()}), ** _kwargs
741+ )
742+ # insert a boundary so that this is no longer part of a unified call.
743+ self .all ()
744+ # Start a new unify if the insert statement is returning
745+ if execute_statement ._returning :
746+ return self .execute (select (execute_statement ._returning [0 ]))
747+ return None
748+
749+ def _execute_delete (self , execute_statement : Delete , * args : Any ) -> mock .Mock :
750+ """Delete data according to execute statement."""
751+ execute_statement = args [0 ]
752+ # Create equivalent select statement as an Expression Matcher
753+ select_statement = (
754+ [
755+ ExpressionMatcher (
756+ mock .call .execute (
757+ select (execute_statement .table ).where (
758+ execute_statement .whereclause
749759 )
750- for i in calls
751- ]
752- if all (c in select_statement for c in calls ) and not found_query :
753- num_deleted = len (result )
754- temp_mock_data .append ((calls , []))
755- found_query = True
756- else :
757- temp_mock_data .append ((calls , result ))
758- self ._mock_data = temp_mock_data
759- delete_result = mock .Mock ()
760- delete_result .rowcount = num_deleted
761- # insert a boundary so that this is no longer part of a unified call.
762- self .all ()
763- return delete_result
764- elif isinstance (execute_statement , Update ):
765- # Create equivalent select statement as an Expression Matcher
766- select_statement = [
767- ExpressionMatcher (
768- mock .call .execute (select (execute_statement .table ).where (execute_statement .whereclause ))
769760 )
770- ]
771- _mock_data = self ._mock_data
772- sorted_mock_data = sorted (_mock_data , key = lambda x : len (x [0 ]), reverse = True )
773- temp_mock_data = list ()
774- found_query = False
775- num_updated = 0
776- for calls , result in sorted_mock_data :
777- calls = [
778- sqlalchemy_call (
779- i ,
780- with_name = True ,
781- base_call = self .unify .get (i [0 ]) or Call ,
761+ )
762+ ]
763+ if execute_statement .whereclause is not None
764+ else [ExpressionMatcher (mock .call .execute (select (execute_statement .table )))]
765+ )
766+ _mock_data = self ._mock_data = self ._mock_data or []
767+ sorted_mock_data = sorted (_mock_data , key = lambda x : len (x [0 ]), reverse = True )
768+ temp_mock_data = list ()
769+ found_query = False
770+ num_deleted = 0
771+ for calls , result in sorted_mock_data :
772+ calls = [
773+ sqlalchemy_call (
774+ i ,
775+ with_name = True ,
776+ base_call = self .unify .get (i [0 ]) or Call ,
777+ )
778+ for i in calls
779+ ]
780+ if all (c in select_statement for c in calls ) and not found_query :
781+ num_deleted = len (result )
782+ temp_mock_data .append ((calls , []))
783+ found_query = True
784+ else :
785+ temp_mock_data .append ((calls , result ))
786+ self ._mock_data = temp_mock_data
787+ delete_result = mock .Mock ()
788+ delete_result .rowcount = num_deleted
789+ # insert a boundary so that this is no longer part of a unified call.
790+ self .all ()
791+ return delete_result
792+
793+ def _execute_update (self , execute_statement : Update ) -> mock .Mock :
794+ """Update data according to execute statement."""
795+ # Create equivalent select statement as an Expression Matcher
796+ select_statement = (
797+ [
798+ ExpressionMatcher (
799+ mock .call .execute (
800+ select (execute_statement .table ).where (
801+ execute_statement .whereclause
782802 )
783- for i in calls
784- ]
785- if all (c in select_statement for c in calls ) and not found_query :
786- num_updated = len (result )
787- for r in result :
788- for k , v in execute_statement ._values .items ():
789- setattr (r , k .name , v .value )
790- temp_mock_data .append ((calls , result ))
791- found_query = True
792- else :
793- temp_mock_data .append ((calls , result ))
794- self ._mock_data = temp_mock_data
795- update_result = mock .Mock ()
796- update_result .rowcount = num_updated
797- # insert a boundary so that this is no longer part of a unified call.
798- self .all ()
799- return update_result
803+ )
804+ )
805+ ]
806+ if execute_statement .whereclause is not None
807+ else [ExpressionMatcher (mock .call .execute (select (execute_statement .table )))]
808+ )
809+ _mock_data = self ._mock_data = self ._mock_data or []
810+ sorted_mock_data = sorted (_mock_data , key = lambda x : len (x [0 ]), reverse = True )
811+ temp_mock_data = list ()
812+ found_query = False
813+ num_updated = 0
814+ for calls , result in sorted_mock_data :
815+ calls = [
816+ sqlalchemy_call (
817+ i ,
818+ with_name = True ,
819+ base_call = self .unify .get (i [0 ]) or Call ,
820+ )
821+ for i in calls
822+ ]
823+ if all (c in select_statement for c in calls ) and not found_query :
824+ num_updated = len (result )
825+ for r in result :
826+ for k , v in execute_statement ._values .items ():
827+ setattr (r , k .name , v .value )
828+ temp_mock_data .append ((calls , result ))
829+ found_query = True
800830 else :
801- # assume any other execute types need to unify
802- return self ._unify (self , * args , ** kwargs )
831+ temp_mock_data .append ((calls , result ))
832+ self ._mock_data = temp_mock_data
833+ update_result = mock .Mock ()
834+ update_result .rowcount = num_updated
835+ # insert a boundary so that this is no longer part of a unified call.
836+ self .all ()
837+ return update_result
838+
839+ def _execute_statement (self , * args : Any , ** kwargs : Any ) -> Any :
840+ """Depending on statement being executed, update data and/or unify statement."""
841+ # Need to check if the execute was an insert, update or delete.
842+ execute_statement = args [0 ]
843+ if isinstance (execute_statement , Insert ):
844+ return self ._execute_insert (execute_statement , * args , ** kwargs )
845+ elif isinstance (execute_statement , Delete ):
846+ return self ._execute_delete (execute_statement , * args )
847+ elif isinstance (execute_statement , Update ):
848+ return self ._execute_update (execute_statement )
849+ else :
850+ # assume any other execute types need to unify
851+ return self ._unify (self , * args , ** kwargs )
0 commit comments