Skip to content

Commit 6ed9850

Browse files
feat: add overwrite_method to postgresql.to_sql (#2820)
1 parent 44ae3fb commit 6ed9850

File tree

2 files changed

+76
-4
lines changed

2 files changed

+76
-4
lines changed

awswrangler/postgresql.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,22 @@ def _validate_connection(con: "pg8000.Connection") -> None:
4242
)
4343

4444

45-
def _drop_table(cursor: "pg8000.Cursor", schema: str | None, table: str) -> None:
45+
def _drop_table(cursor: "pg8000.Cursor", schema: str | None, table: str, cascade: bool) -> None:
4646
schema_str = f"{_identifier(schema)}." if schema else ""
47-
sql = f"DROP TABLE IF EXISTS {schema_str}{_identifier(table)}"
47+
cascade_str = "CASCADE" if cascade else "RESTRICT"
48+
sql = f"DROP TABLE IF EXISTS {schema_str}{_identifier(table)} {cascade_str}"
4849
_logger.debug("Drop table query:\n%s", sql)
4950
cursor.execute(sql)
5051

5152

53+
def _truncate_table(cursor: "pg8000.Cursor", schema: str | None, table: str, cascade: bool) -> None:
54+
schema_str = f"{_identifier(schema)}." if schema else ""
55+
cascade_str = "CASCADE" if cascade else "RESTRICT"
56+
sql = f"TRUNCATE TABLE {schema_str}{_identifier(table)} {cascade_str}"
57+
_logger.debug("Truncate table query:\n%s", sql)
58+
cursor.execute(sql)
59+
60+
5261
def _does_table_exist(cursor: "pg8000.Cursor", schema: str | None, table: str) -> bool:
5362
schema_str = f"TABLE_SCHEMA = {pg8000_native.literal(schema)} AND" if schema else ""
5463
cursor.execute(
@@ -66,12 +75,21 @@ def _create_table(
6675
table: str,
6776
schema: str,
6877
mode: str,
78+
overwrite_method: _ToSqlOverwriteModeLiteral,
6979
index: bool,
7080
dtype: dict[str, str] | None,
7181
varchar_lengths: dict[str, int] | None,
7282
) -> None:
7383
if mode == "overwrite":
74-
_drop_table(cursor=cursor, schema=schema, table=table)
84+
if overwrite_method in ["drop", "cascade"]:
85+
_drop_table(cursor=cursor, schema=schema, table=table, cascade=bool(overwrite_method == "cascade"))
86+
elif overwrite_method in ["truncate", "truncate cascade"]:
87+
if _does_table_exist(cursor=cursor, schema=schema, table=table):
88+
_truncate_table(
89+
cursor=cursor, schema=schema, table=table, cascade=bool(overwrite_method == "truncate cascade")
90+
)
91+
else:
92+
raise exceptions.InvalidArgumentValue(f"Invalid overwrite_method: {overwrite_method}")
7593
elif _does_table_exist(cursor=cursor, schema=schema, table=table):
7694
return
7795
postgresql_types: dict[str, str] = _data_types.database_types_from_pandas(
@@ -485,6 +503,7 @@ def read_sql_table(
485503

486504

487505
_ToSqlModeLiteral = Literal["append", "overwrite", "upsert"]
506+
_ToSqlOverwriteModeLiteral = Literal["drop", "cascade", "truncate", "truncate cascade"]
488507

489508

490509
@_utils.check_optional_dependency(pg8000, "pg8000")
@@ -495,6 +514,7 @@ def to_sql(
495514
table: str,
496515
schema: str,
497516
mode: _ToSqlModeLiteral = "append",
517+
overwrite_method: _ToSqlOverwriteModeLiteral = "drop",
498518
index: bool = False,
499519
dtype: dict[str, str] | None = None,
500520
varchar_lengths: dict[str, int] | None = None,
@@ -522,6 +542,13 @@ def to_sql(
522542
overwrite: Replaces the existing table. By default the table is dropped and recreated; see `overwrite_method` for the alternatives.
523543
upsert: Perform an upsert which checks for conflicts on columns given by `upsert_conflict_columns` and
524544
sets the new values on conflicts. Note that `upsert_conflict_columns` is required for this mode.
545+
overwrite_method : str
546+
One of "drop" (default), "cascade", "truncate", or "truncate cascade". Only applicable when `mode="overwrite"`.
547+
548+
"drop" - ``DROP ... RESTRICT`` - drops the table. Fails if any other objects (such as views) depend on it.
549+
"cascade" - ``DROP ... CASCADE`` - drops the table and all objects (such as views) that depend on it.
550+
"truncate" - ``TRUNCATE ... RESTRICT`` - truncates the table. Fails if the table is referenced by a foreign key from any other table.
551+
"truncate cascade" - ``TRUNCATE ... CASCADE`` - truncates the table and all tables that have foreign-key references to it, directly or through other truncated tables.
525552
index : bool
526553
True to store the DataFrame index as a column in the table,
527554
otherwise False to ignore it.
@@ -583,6 +610,7 @@ def to_sql(
583610
table=table,
584611
schema=schema,
585612
mode=mode,
613+
overwrite_method=overwrite_method,
586614
index=index,
587615
dtype=dtype,
588616
varchar_lengths=varchar_lengths,

tests/unit/test_postgresql.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,51 @@ def test_read_sql_query_simple(databases_parameters):
4949

5050
def test_to_sql_simple(postgresql_table, postgresql_con):
5151
df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"]})
52-
wr.postgresql.to_sql(df, postgresql_con, postgresql_table, "public", "overwrite", True)
52+
wr.postgresql.to_sql(
53+
df=df,
54+
con=postgresql_con,
55+
table=postgresql_table,
56+
schema="public",
57+
mode="overwrite",
58+
index=True,
59+
)
60+
61+
62+
@pytest.mark.parametrize("overwrite_method", ["drop", "cascade", "truncate", "truncate cascade"])
63+
def test_to_sql_overwrite(postgresql_table, postgresql_con, overwrite_method):
64+
df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"]})
65+
wr.postgresql.to_sql(
66+
df=df,
67+
con=postgresql_con,
68+
table=postgresql_table,
69+
schema="public",
70+
mode="overwrite",
71+
overwrite_method=overwrite_method,
72+
)
73+
df = pd.DataFrame({"c0": [4, 5, 6], "c1": ["xoo", "yoo", "zoo"]})
74+
wr.postgresql.to_sql(
75+
df=df,
76+
con=postgresql_con,
77+
table=postgresql_table,
78+
schema="public",
79+
mode="overwrite",
80+
overwrite_method=overwrite_method,
81+
)
82+
df = wr.postgresql.read_sql_table(table=postgresql_table, schema="public", con=postgresql_con)
83+
assert df.shape == (3, 2)
84+
85+
86+
def test_unknown_overwrite_method_error(postgresql_table, postgresql_con):
87+
df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"]})
88+
with pytest.raises(wr.exceptions.InvalidArgumentValue):
89+
wr.postgresql.to_sql(
90+
df=df,
91+
con=postgresql_con,
92+
table=postgresql_table,
93+
schema="public",
94+
mode="overwrite",
95+
overwrite_method="unknown",
96+
)
5397

5498

5599
def test_sql_types(postgresql_table, postgresql_con):

0 commit comments

Comments
 (0)