Skip to content

Make pandas/io/sql.py work with sqlalchemy 2.0 #48576

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Feb 9, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/deps/actions-310.yaml
Original file line number Diff line number Diff line change
@@ -48,7 +48,7 @@ dependencies:
- pyxlsb
- s3fs>=2021.08.0
- scipy
- sqlalchemy<1.4.46
- sqlalchemy
- tabulate
- tzdata>=2022a
- xarray
2 changes: 1 addition & 1 deletion ci/deps/actions-311.yaml
Original file line number Diff line number Diff line change
@@ -48,7 +48,7 @@ dependencies:
- pyxlsb
- s3fs>=2021.08.0
- scipy
- sqlalchemy<1.4.46
- sqlalchemy
- tabulate
- tzdata>=2022a
- xarray
2 changes: 1 addition & 1 deletion ci/deps/actions-38-downstream_compat.yaml
Original file line number Diff line number Diff line change
@@ -48,7 +48,7 @@ dependencies:
- pyxlsb
- s3fs>=2021.08.0
- scipy
- sqlalchemy<1.4.46
- sqlalchemy
- tabulate
- xarray
- xlrd
2 changes: 1 addition & 1 deletion ci/deps/actions-38.yaml
Original file line number Diff line number Diff line change
@@ -48,7 +48,7 @@ dependencies:
- pyxlsb
- s3fs>=2021.08.0
- scipy
- sqlalchemy<1.4.46
- sqlalchemy
- tabulate
- xarray
- xlrd
2 changes: 1 addition & 1 deletion ci/deps/actions-39.yaml
Original file line number Diff line number Diff line change
@@ -48,7 +48,7 @@ dependencies:
- pyxlsb
- s3fs>=2021.08.0
- scipy
- sqlalchemy<1.4.46
- sqlalchemy
- tabulate
- tzdata>=2022a
- xarray
2 changes: 1 addition & 1 deletion ci/deps/circle-38-arm64.yaml
Original file line number Diff line number Diff line change
@@ -49,7 +49,7 @@ dependencies:
- pyxlsb
- s3fs>=2021.08.0
- scipy
- sqlalchemy<1.4.46
- sqlalchemy
- tabulate
- xarray
- xlrd
4 changes: 2 additions & 2 deletions doc/source/user_guide/io.rst
Original file line number Diff line number Diff line change
@@ -5868,15 +5868,15 @@ If you have an SQLAlchemy description of your database you can express where con
sa.Column("Col_3", sa.Boolean),
)

pd.read_sql(sa.select([data_table]).where(data_table.c.Col_3 is True), engine)
pd.read_sql(sa.select(data_table).where(data_table.c.Col_3 is True), engine)

You can combine SQLAlchemy expressions with parameters passed to :func:`read_sql` using :func:`sqlalchemy.bindparam`

.. ipython:: python

import datetime as dt

expr = sa.select([data_table]).where(data_table.c.Date > sa.bindparam("date"))
expr = sa.select(data_table).where(data_table.c.Date > sa.bindparam("date"))
pd.read_sql(expr, engine, params={"date": dt.datetime(2010, 10, 18)})


1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.0.rst
Original file line number Diff line number Diff line change
@@ -291,6 +291,7 @@ Other enhancements
- Improved error message when trying to align :class:`DataFrame` objects (for example, in :func:`DataFrame.compare`) to clarify that "identically labelled" refers to both index and columns (:issue:`50083`)
- Added :meth:`DatetimeIndex.as_unit` and :meth:`TimedeltaIndex.as_unit` to convert to different resolutions; supported resolutions are "s", "ms", "us", and "ns" (:issue:`50616`)
- Added new argument ``dtype`` to :func:`read_sql` to be consistent with :func:`read_sql_query` (:issue:`50797`)
- Added support for SQLAlchemy 2.0 (:issue:`40686`)
-

.. ---------------------------------------------------------------------------
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
@@ -51,7 +51,7 @@ dependencies:
- pyxlsb
- s3fs>=2021.08.0
- scipy
- sqlalchemy<1.4.46
- sqlalchemy
- tabulate
- tzdata>=2022a
- xarray
17 changes: 11 additions & 6 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
@@ -2709,7 +2709,7 @@ def to_sql(
library. Legacy support is provided for sqlite3.Connection objects. The user
is responsible for engine disposal and connection closure for the SQLAlchemy
connectable. See `here \
<https://docs.sqlalchemy.org/en/14/core/connections.html>`_.
<https://docs.sqlalchemy.org/en/20/core/connections.html>`_.
If passing a sqlalchemy.engine.Connection which is already in a transaction,
the transaction will not be committed. If passing a sqlite3.Connection,
it will not be possible to roll back the record insertion.
@@ -2759,7 +2759,7 @@ def to_sql(
attribute of ``sqlite3.Cursor`` or SQLAlchemy connectable which may not
reflect the exact number of written rows as stipulated in the
`sqlite3 <https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.rowcount>`__ or
`SQLAlchemy <https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.BaseCursorResult.rowcount>`__.
`SQLAlchemy <https://docs.sqlalchemy.org/en/20/core/connections.html#sqlalchemy.engine.CursorResult.rowcount>`__.

.. versionadded:: 1.4.0

@@ -2803,7 +2803,9 @@ def to_sql(

>>> df.to_sql('users', con=engine)
3
>>> engine.execute("SELECT * FROM users").fetchall()
>>> from sqlalchemy import text
>>> with engine.connect() as conn:
... conn.execute(text("SELECT * FROM users")).fetchall()
[(0, 'User 1'), (1, 'User 2'), (2, 'User 3')]

An `sqlalchemy.engine.Connection` can also be passed to `con`:
@@ -2819,7 +2821,8 @@ def to_sql(
>>> df2 = pd.DataFrame({'name' : ['User 6', 'User 7']})
>>> df2.to_sql('users', con=engine, if_exists='append')
2
>>> engine.execute("SELECT * FROM users").fetchall()
>>> with engine.connect() as conn:
... conn.execute(text("SELECT * FROM users")).fetchall()
[(0, 'User 1'), (1, 'User 2'), (2, 'User 3'),
(0, 'User 4'), (1, 'User 5'), (0, 'User 6'),
(1, 'User 7')]
@@ -2829,7 +2832,8 @@ def to_sql(
>>> df2.to_sql('users', con=engine, if_exists='replace',
... index_label='id')
2
>>> engine.execute("SELECT * FROM users").fetchall()
>>> with engine.connect() as conn:
... conn.execute(text("SELECT * FROM users")).fetchall()
[(0, 'User 6'), (1, 'User 7')]

Specify the dtype (especially useful for integers with missing values).
@@ -2849,7 +2853,8 @@ def to_sql(
... dtype={"A": Integer()})
3

>>> engine.execute("SELECT * FROM integers").fetchall()
>>> with engine.connect() as conn:
... conn.execute(text("SELECT * FROM integers")).fetchall()
[(1,), (None,), (2,)]
""" # noqa:E501
from pandas.io import sql
84 changes: 45 additions & 39 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
@@ -69,23 +69,16 @@

if TYPE_CHECKING:
from sqlalchemy import Table
from sqlalchemy.sql.expression import (
Select,
TextClause,
)


# -----------------------------------------------------------------------------
# -- Helper functions


def _convert_params(sql, params):
"""Convert SQL and params args to DBAPI2.0 compliant format."""
args = [sql]
if params is not None:
if hasattr(params, "keys"): # test if params is a mapping
args += [params]
else:
args += [list(params)]
return args


def _process_parse_dates_argument(parse_dates):
"""Process parse_dates argument for read_sql functions"""
# handle non-list entries for parse_dates gracefully
@@ -224,8 +217,7 @@ def execute(sql, con, params=None):
if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Engine)):
raise TypeError("pandas.io.sql.execute requires a connection") # GH50185
with pandasSQL_builder(con, need_transaction=True) as pandas_sql:
args = _convert_params(sql, params)
return pandas_sql.execute(*args)
return pandas_sql.execute(sql, params)


# -----------------------------------------------------------------------------
@@ -348,7 +340,7 @@ def read_sql_table(
else using_nullable_dtypes()
)

with pandasSQL_builder(con, schema=schema) as pandas_sql:
with pandasSQL_builder(con, schema=schema, need_transaction=True) as pandas_sql:
if not pandas_sql.has_table(table_name):
raise ValueError(f"Table {table_name} not found")

@@ -951,7 +943,8 @@ def sql_schema(self) -> str:
def _execute_create(self) -> None:
# Inserting table into database, add to MetaData object
self.table = self.table.to_metadata(self.pd_sql.meta)
self.table.create(bind=self.pd_sql.con)
with self.pd_sql.run_transaction():
self.table.create(bind=self.pd_sql.con)

def create(self) -> None:
if self.exists():
@@ -1221,7 +1214,7 @@ def _create_table_setup(self):

column_names_and_types = self._get_column_names_and_types(self._sqlalchemy_type)

columns = [
columns: list[Any] = [
Column(name, typ, index=is_index)
for name, typ, is_index in column_names_and_types
]
@@ -1451,7 +1444,7 @@ def to_sql(
pass

@abstractmethod
def execute(self, *args, **kwargs):
def execute(self, sql: str | Select | TextClause, params=None):
pass

@abstractmethod
@@ -1511,7 +1504,7 @@ def insert_records(

try:
return table.insert(chunksize=chunksize, method=method)
except exc.SQLAlchemyError as err:
except exc.StatementError as err:
# GH34431
# https://stackoverflow.com/a/67358288/6067848
msg = r"""(\(1054, "Unknown column 'inf(e0)?' in 'field list'"\))(?#
@@ -1579,13 +1572,18 @@ def __init__(
from sqlalchemy.engine import Engine
from sqlalchemy.schema import MetaData

# self.exit_stack cleans up the Engine and Connection and commits the
# transaction if any of those objects was created below.
# Cleanup happens either in self.__exit__ or at the end of the iterator
# returned by read_sql when chunksize is not None.
self.exit_stack = ExitStack()
if isinstance(con, str):
con = create_engine(con)
self.exit_stack.callback(con.dispose)
if isinstance(con, Engine):
con = self.exit_stack.enter_context(con.connect())
if need_transaction:
self.exit_stack.enter_context(con.begin())
if need_transaction and not con.in_transaction():
self.exit_stack.enter_context(con.begin())
self.con = con
self.meta = MetaData(schema=schema)
self.returns_generator = False
@@ -1596,11 +1594,18 @@ def __exit__(self, *args) -> None:

@contextmanager
def run_transaction(self):
yield self.con
if not self.con.in_transaction():
with self.con.begin():
yield self.con
else:
yield self.con

def execute(self, *args, **kwargs):
def execute(self, sql: str | Select | TextClause, params=None):
"""Simple passthrough to SQLAlchemy connectable"""
return self.con.execute(*args, **kwargs)
args = [] if params is None else [params]
if isinstance(sql, str):
return self.con.exec_driver_sql(sql, *args)
return self.con.execute(sql, *args)

def read_table(
self,
@@ -1780,9 +1785,7 @@ def read_query(
read_sql

"""
args = _convert_params(sql, params)

result = self.execute(*args)
result = self.execute(sql, params)
columns = result.keys()

if chunksize is not None:
@@ -1838,13 +1841,14 @@ def prep_table(
else:
dtype = cast(dict, dtype)

from sqlalchemy.types import (
TypeEngine,
to_instance,
)
from sqlalchemy.types import TypeEngine

for col, my_type in dtype.items():
if not isinstance(to_instance(my_type), TypeEngine):
if isinstance(my_type, type) and issubclass(my_type, TypeEngine):
pass
elif isinstance(my_type, TypeEngine):
pass
else:
raise ValueError(f"The type of {col} is not a SQLAlchemy type")

table = SQLTable(
@@ -2005,7 +2009,8 @@ def drop_table(self, table_name: str, schema: str | None = None) -> None:
schema = schema or self.meta.schema
if self.has_table(table_name, schema):
self.meta.reflect(bind=self.con, only=[table_name], schema=schema)
self.get_table(table_name, schema).drop(bind=self.con)
with self.run_transaction():
self.get_table(table_name, schema).drop(bind=self.con)
self.meta.clear()

def _create_sql_schema(
@@ -2238,21 +2243,24 @@ def run_transaction(self):
finally:
cur.close()

def execute(self, *args, **kwargs):
def execute(self, sql: str | Select | TextClause, params=None):
if not isinstance(sql, str):
raise TypeError("Query must be a string unless using sqlalchemy.")
args = [] if params is None else [params]
cur = self.con.cursor()
try:
cur.execute(*args, **kwargs)
cur.execute(sql, *args)
return cur
except Exception as exc:
try:
self.con.rollback()
except Exception as inner_exc: # pragma: no cover
ex = DatabaseError(
f"Execution failed on sql: {args[0]}\n{exc}\nunable to rollback"
f"Execution failed on sql: {sql}\n{exc}\nunable to rollback"
)
raise ex from inner_exc

ex = DatabaseError(f"Execution failed on sql '{args[0]}': {exc}")
ex = DatabaseError(f"Execution failed on sql '{sql}': {exc}")
raise ex from exc

@staticmethod
@@ -2305,9 +2313,7 @@ def read_query(
dtype: DtypeArg | None = None,
use_nullable_dtypes: bool = False,
) -> DataFrame | Iterator[DataFrame]:

args = _convert_params(sql, params)
cursor = self.execute(*args)
cursor = self.execute(sql, params)
columns = [col_desc[0] for col_desc in cursor.description]

if chunksize is not None:
159 changes: 121 additions & 38 deletions pandas/tests/io/test_sql.py
Original file line number Diff line number Diff line change
@@ -149,8 +149,6 @@ def create_and_load_iris(conn, iris_file: Path, dialect: str):
from sqlalchemy.engine import Engine

iris = iris_table_metadata(dialect)
iris.drop(conn, checkfirst=True)
iris.create(bind=conn)

with iris_file.open(newline=None) as csvfile:
reader = csv.reader(csvfile)
@@ -160,9 +158,14 @@ def create_and_load_iris(conn, iris_file: Path, dialect: str):
if isinstance(conn, Engine):
with conn.connect() as conn:
with conn.begin():
iris.drop(conn, checkfirst=True)
iris.create(bind=conn)
conn.execute(stmt)
else:
conn.execute(stmt)
with conn.begin():
iris.drop(conn, checkfirst=True)
iris.create(bind=conn)
conn.execute(stmt)


def create_and_load_iris_view(conn):
@@ -180,7 +183,8 @@ def create_and_load_iris_view(conn):
with conn.begin():
conn.execute(stmt)
else:
conn.execute(stmt)
with conn.begin():
conn.execute(stmt)


def types_table_metadata(dialect: str):
@@ -243,16 +247,19 @@ def create_and_load_types(conn, types_data: list[dict], dialect: str):
from sqlalchemy.engine import Engine

types = types_table_metadata(dialect)
types.drop(conn, checkfirst=True)
types.create(bind=conn)

stmt = insert(types).values(types_data)
if isinstance(conn, Engine):
with conn.connect() as conn:
with conn.begin():
types.drop(conn, checkfirst=True)
types.create(bind=conn)
conn.execute(stmt)
else:
conn.execute(stmt)
with conn.begin():
types.drop(conn, checkfirst=True)
types.create(bind=conn)
conn.execute(stmt)


def check_iris_frame(frame: DataFrame):
@@ -269,25 +276,21 @@ def count_rows(conn, table_name: str):
cur = conn.cursor()
return cur.execute(stmt).fetchone()[0]
else:
from sqlalchemy import (
create_engine,
text,
)
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine

stmt = text(stmt)
if isinstance(conn, str):
try:
engine = create_engine(conn)
with engine.connect() as conn:
return conn.execute(stmt).scalar_one()
return conn.exec_driver_sql(stmt).scalar_one()
finally:
engine.dispose()
elif isinstance(conn, Engine):
with conn.connect() as conn:
return conn.execute(stmt).scalar_one()
return conn.exec_driver_sql(stmt).scalar_one()
else:
return conn.execute(stmt).scalar_one()
return conn.exec_driver_sql(stmt).scalar_one()


@pytest.fixture
@@ -417,7 +420,8 @@ def mysql_pymysql_engine(iris_path, types_data):

@pytest.fixture
def mysql_pymysql_conn(mysql_pymysql_engine):
yield mysql_pymysql_engine.connect()
with mysql_pymysql_engine.connect() as conn:
yield conn


@pytest.fixture
@@ -443,7 +447,8 @@ def postgresql_psycopg2_engine(iris_path, types_data):

@pytest.fixture
def postgresql_psycopg2_conn(postgresql_psycopg2_engine):
yield postgresql_psycopg2_engine.connect()
with postgresql_psycopg2_engine.connect() as conn:
yield conn


@pytest.fixture
@@ -463,7 +468,8 @@ def sqlite_engine(sqlite_str):

@pytest.fixture
def sqlite_conn(sqlite_engine):
yield sqlite_engine.connect()
with sqlite_engine.connect() as conn:
yield conn


@pytest.fixture
@@ -483,7 +489,8 @@ def sqlite_iris_engine(sqlite_engine, iris_path):

@pytest.fixture
def sqlite_iris_conn(sqlite_iris_engine):
yield sqlite_iris_engine.connect()
with sqlite_iris_engine.connect() as conn:
yield conn


@pytest.fixture
@@ -533,12 +540,20 @@ def sqlite_buildin_iris(sqlite_buildin, iris_path):
all_connectable_iris = sqlalchemy_connectable_iris + ["sqlite_buildin_iris"]


@pytest.mark.db
@pytest.mark.parametrize("conn", all_connectable)
def test_dataframe_to_sql(conn, test_frame1, request):
# GH 51086 if conn is sqlite_engine
conn = request.getfixturevalue(conn)
test_frame1.to_sql("test", conn, if_exists="append", index=False)


@pytest.mark.db
@pytest.mark.parametrize("conn", all_connectable)
@pytest.mark.parametrize("method", [None, "multi"])
def test_to_sql(conn, method, test_frame1, request):
conn = request.getfixturevalue(conn)
with pandasSQL_builder(conn) as pandasSQL:
with pandasSQL_builder(conn, need_transaction=True) as pandasSQL:
pandasSQL.to_sql(test_frame1, "test_frame", method=method)
assert pandasSQL.has_table("test_frame")
assert count_rows(conn, "test_frame") == len(test_frame1)
@@ -549,7 +564,7 @@ def test_to_sql(conn, method, test_frame1, request):
@pytest.mark.parametrize("mode, num_row_coef", [("replace", 1), ("append", 2)])
def test_to_sql_exist(conn, mode, num_row_coef, test_frame1, request):
conn = request.getfixturevalue(conn)
with pandasSQL_builder(conn) as pandasSQL:
with pandasSQL_builder(conn, need_transaction=True) as pandasSQL:
pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail")
pandasSQL.to_sql(test_frame1, "test_frame", if_exists=mode)
assert pandasSQL.has_table("test_frame")
@@ -560,7 +575,7 @@ def test_to_sql_exist(conn, mode, num_row_coef, test_frame1, request):
@pytest.mark.parametrize("conn", all_connectable)
def test_to_sql_exist_fail(conn, test_frame1, request):
conn = request.getfixturevalue(conn)
with pandasSQL_builder(conn) as pandasSQL:
with pandasSQL_builder(conn, need_transaction=True) as pandasSQL:
pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail")
assert pandasSQL.has_table("test_frame")

@@ -595,9 +610,45 @@ def test_read_iris_query_chunksize(conn, request):
assert "SepalWidth" in iris_frame.columns


@pytest.mark.db
@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris)
def test_read_iris_query_expression_with_parameter(conn, request):
conn = request.getfixturevalue(conn)
from sqlalchemy import (
MetaData,
Table,
create_engine,
select,
)

metadata = MetaData()
autoload_con = create_engine(conn) if isinstance(conn, str) else conn
iris = Table("iris", metadata, autoload_with=autoload_con)
iris_frame = read_sql_query(
select(iris), conn, params={"name": "Iris-setosa", "length": 5.1}
)
check_iris_frame(iris_frame)
if isinstance(conn, str):
autoload_con.dispose()


@pytest.mark.db
@pytest.mark.parametrize("conn", all_connectable_iris)
def test_read_iris_query_string_with_parameter(conn, request):
for db, query in SQL_STRINGS["read_parameters"].items():
if db in conn:
break
else:
raise KeyError(f"No part of {conn} found in SQL_STRINGS['read_parameters']")
conn = request.getfixturevalue(conn)
iris_frame = read_sql_query(query, conn, params=("Iris-setosa", 5.1))
check_iris_frame(iris_frame)


@pytest.mark.db
@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris)
def test_read_iris_table(conn, request):
# GH 51015 if conn = sqlite_iris_str
conn = request.getfixturevalue(conn)
iris_frame = read_sql_table("iris", conn)
check_iris_frame(iris_frame)
@@ -627,7 +678,7 @@ def sample(pd_table, conn, keys, data_iter):
data = [dict(zip(keys, row)) for row in data_iter]
conn.execute(pd_table.table.insert(), data)

with pandasSQL_builder(conn) as pandasSQL:
with pandasSQL_builder(conn, need_transaction=True) as pandasSQL:
pandasSQL.to_sql(test_frame1, "test_frame", method=sample)
assert pandasSQL.has_table("test_frame")
assert check == [1]
@@ -680,7 +731,8 @@ def test_read_procedure(conn, request):
with engine_conn.begin():
engine_conn.execute(proc)
else:
conn.execute(proc)
with conn.begin():
conn.execute(proc)

res1 = sql.read_sql_query("CALL get_testdb();", conn)
tm.assert_frame_equal(df, res1)
@@ -762,6 +814,8 @@ def teardown_method(self):
pass
else:
with conn:
for view in self._get_all_views(conn):
self.drop_view(view, conn)
for tbl in self._get_all_tables(conn):
self.drop_table(tbl, conn)

@@ -778,6 +832,14 @@ def _get_all_tables(self, conn):
c = conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
return [table[0] for table in c.fetchall()]

def drop_view(self, view_name, conn):
conn.execute(f"DROP VIEW IF EXISTS {sql._get_valid_sqlite_name(view_name)}")
conn.commit()

def _get_all_views(self, conn):
c = conn.execute("SELECT name FROM sqlite_master WHERE type='view'")
return [view[0] for view in c.fetchall()]


class SQLAlchemyMixIn(MixInBase):
@classmethod
@@ -788,6 +850,8 @@ def connect(self):
return self.engine.connect()

def drop_table(self, table_name, conn):
if conn.in_transaction():
conn.get_transaction().rollback()
with conn.begin():
sql.SQLDatabase(conn).drop_table(table_name)

@@ -796,6 +860,20 @@ def _get_all_tables(self, conn):

return inspect(conn).get_table_names()

def drop_view(self, view_name, conn):
quoted_view = conn.engine.dialect.identifier_preparer.quote_identifier(
view_name
)
if conn.in_transaction():
conn.get_transaction().rollback()
with conn.begin():
conn.exec_driver_sql(f"DROP VIEW IF EXISTS {quoted_view}")

def _get_all_views(self, conn):
from sqlalchemy import inspect

return inspect(conn).get_view_names()


class PandasSQLTest:
"""
@@ -822,7 +900,7 @@ def load_types_data(self, types_data):

def _read_sql_iris_parameter(self):
query = SQL_STRINGS["read_parameters"][self.flavor]
params = ["Iris-setosa", 5.1]
params = ("Iris-setosa", 5.1)
iris_frame = self.pandasSQL.read_query(query, params=params)
check_iris_frame(iris_frame)

@@ -951,8 +1029,6 @@ class _TestSQLApi(PandasSQLTest):
@pytest.fixture(autouse=True)
def setup_method(self, iris_path, types_data):
self.conn = self.connect()
if not isinstance(self.conn, sqlite3.Connection):
self.conn.begin()
self.load_iris_data(iris_path)
self.load_types_data(types_data)
self.load_test_data_and_sql()
@@ -1448,7 +1524,8 @@ def test_not_reflect_all_tables(self):
with conn.begin():
conn.execute(query)
else:
self.conn.execute(query)
with self.conn.begin():
self.conn.execute(query)

with tm.assert_produces_warning(None):
sql.read_sql_table("other_table", self.conn)
@@ -1698,7 +1775,6 @@ def setup_class(cls):
def setup_method(self, iris_path, types_data):
try:
self.conn = self.engine.connect()
self.conn.begin()
self.pandasSQL = sql.SQLDatabase(self.conn)
except sqlalchemy.exc.OperationalError:
pytest.skip(f"Can't connect to {self.flavor} server")
@@ -1729,8 +1805,8 @@ def test_create_table(self):
temp_frame = DataFrame(
{"one": [1.0, 2.0, 3.0, 4.0], "two": [4.0, 3.0, 2.0, 1.0]}
)
pandasSQL = sql.SQLDatabase(temp_conn)
assert pandasSQL.to_sql(temp_frame, "temp_frame") == 4
with sql.SQLDatabase(temp_conn, need_transaction=True) as pandasSQL:
assert pandasSQL.to_sql(temp_frame, "temp_frame") == 4

insp = inspect(temp_conn)
assert insp.has_table("temp_frame")
@@ -1749,6 +1825,10 @@ def test_drop_table(self):
assert insp.has_table("temp_frame")

pandasSQL.drop_table("temp_frame")
try:
insp.clear_cache() # needed with SQLAlchemy 2.0, unavailable prior
except AttributeError:
pass
assert not insp.has_table("temp_frame")

def test_roundtrip(self, test_frame1):
@@ -2098,7 +2178,6 @@ def _get_index_columns(self, tbl_name):
def test_to_sql_save_index(self):
self._to_sql_save_index()

@pytest.mark.xfail(reason="Nested transactions rollbacks don't work with Pandas")
def test_transactions(self):
self._transaction_test()

@@ -2120,7 +2199,8 @@ def test_get_schema_create_table(self, test_frame3):
with conn.begin():
conn.execute(create_sql)
else:
self.conn.execute(create_sql)
with self.conn.begin():
self.conn.execute(create_sql)
returned_df = sql.read_sql_table(tbl, self.conn)
tm.assert_frame_equal(returned_df, blank_test_df, check_index_type=False)
self.drop_table(tbl, self.conn)
@@ -2586,7 +2666,8 @@ class Test(BaseModel):
id = Column(Integer, primary_key=True)
string_column = Column(String(50))

BaseModel.metadata.create_all(self.conn)
with self.conn.begin():
BaseModel.metadata.create_all(self.conn)
Session = sessionmaker(bind=self.conn)
with Session() as session:
df = DataFrame({"id": [0, 1], "string_column": ["hello", "world"]})
@@ -2680,8 +2761,9 @@ def test_schema_support(self):
df = DataFrame({"col1": [1, 2], "col2": [0.1, 0.2], "col3": ["a", "n"]})

# create a schema
self.conn.execute("DROP SCHEMA IF EXISTS other CASCADE;")
self.conn.execute("CREATE SCHEMA other;")
with self.conn.begin():
self.conn.exec_driver_sql("DROP SCHEMA IF EXISTS other CASCADE;")
self.conn.exec_driver_sql("CREATE SCHEMA other;")

# write dataframe to different schema's
assert df.to_sql("test_schema_public", self.conn, index=False) == 2
@@ -2713,8 +2795,9 @@ def test_schema_support(self):
# different if_exists options

# create a schema
self.conn.execute("DROP SCHEMA IF EXISTS other CASCADE;")
self.conn.execute("CREATE SCHEMA other;")
with self.conn.begin():
self.conn.exec_driver_sql("DROP SCHEMA IF EXISTS other CASCADE;")
self.conn.exec_driver_sql("CREATE SCHEMA other;")

# write dataframe with different if_exists options
assert (
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -40,7 +40,7 @@ python-snappy
pyxlsb
s3fs>=2021.08.0
scipy
sqlalchemy<1.4.46
sqlalchemy
tabulate
tzdata>=2022.1
xarray