Skip to content

Commit c9347e2

Browse files
committed
honour connect_timeout and PGCONNECT_TIMEOUT
The wait_select callback previously installed to enable Ctrl+C during long queries was breaking the configurable connection timeout (see psycopg/psycopg2#944). It is replaced with a more visible async connection and a manual call to a custom wait_select that supports a timeout. The timeout mimics the default libpq behavior: it reads the connect_timeout connection parameter, falls back to the PGCONNECT_TIMEOUT environment variable, and defaults to 0 (no timeout). A secondary benefit is that PGMigrate can now be imported into another project without altering the global set_wait_callback.
1 parent 8e3a4db commit c9347e2

File tree

5 files changed

+121
-6
lines changed

5 files changed

+121
-6
lines changed

aiven_db_migrate/migrate/pgmigrate.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
PGDataDumpFailedError, PGDataNotFoundError, PGMigrateValidationFailedError, PGSchemaDumpFailedError, PGTooMuchDataError
55
)
66
from aiven_db_migrate.migrate.pgutils import (
7-
create_connection_string, find_pgbin_dir, get_connection_info, validate_pg_identifier_length
7+
create_connection_string, find_pgbin_dir, get_connection_info, validate_pg_identifier_length, wait_select
88
)
99
from aiven_db_migrate.migrate.version import __version__
1010
from concurrent import futures
@@ -31,8 +31,6 @@
3131
import threading
3232
import time
3333

34-
# https://www.psycopg.org/docs/faq.html#faq-interrupt-query
35-
psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select)
3634
MAX_CLI_LEN = 2097152 # getconf ARG_MAX
3735

3836

@@ -136,6 +134,12 @@ def conn_str(self, *, dbname: str = None) -> str:
136134
conn_info["application_name"] = conn_info["application_name"] + "/" + self.mangle_db_name(conn_info["dbname"])
137135
return create_connection_string(conn_info)
138136

137+
def connect_timeout(self):
    """Return the connection timeout in seconds, or ``None`` when it is unset or invalid.

    The value is read from the ``connect_timeout`` entry of ``self.conn_info``; when that
    key is absent, the ``PGCONNECT_TIMEOUT`` environment variable is used instead.

    Returns:
        The timeout as an ``int``, or ``None`` if no value is configured or the
        configured value is not a base-10 integer.
    """
    raw_timeout = self.conn_info.get("connect_timeout", os.environ.get("PGCONNECT_TIMEOUT", ""))
    try:
        # int() with an explicit base only accepts strings; normalise first so that a
        # numeric value (e.g. connect_timeout=5 in a dict) does not raise TypeError.
        return int(str(raw_timeout), 10)
    except ValueError:
        return None
142+
139143
@contextmanager
140144
def _cursor(self, *, dbname: str = None) -> RealDictCursor:
141145
conn: psycopg2.extensions.connection = None
@@ -146,8 +150,8 @@ def _cursor(self, *, dbname: str = None) -> RealDictCursor:
146150
# from multiple threads; allow only one connection at time
147151
self.conn_lock.acquire()
148152
try:
149-
conn = psycopg2.connect(**conn_info)
150-
conn.autocommit = True
153+
conn = psycopg2.connect(**conn_info, async_=True)
154+
wait_select(conn, self.connect_timeout())
151155
yield conn.cursor(cursor_factory=RealDictCursor)
152156
finally:
153157
if conn is not None:
@@ -165,7 +169,15 @@ def c(
165169
) -> List[Dict[str, Any]]:
166170
results: List[Dict[str, Any]] = []
167171
with self._cursor(dbname=dbname) as cur:
168-
cur.execute(query, args)
172+
try:
173+
cur.execute(query, args)
174+
wait_select(cur.connection)
175+
except KeyboardInterrupt:
176+
# We wrap the whole execute+wait block to make sure we cancel
177+
# the query in all cases, which we couldn't if KeyboardInterrupt
178+
# was only handled inside wait_select.
179+
cur.connection.cancel()
180+
raise
169181
if return_rows:
170182
results = cur.fetchall()
171183
if return_rows > 0 and len(results) != return_rows:

aiven_db_migrate/migrate/pgutils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
from typing import Any, Dict
55
from urllib.parse import parse_qs, urlparse
66

7+
import psycopg2
8+
import select
9+
import time
10+
711

812
def find_pgbin_dir(pgversion: str) -> Path:
913
def _pgbin_paths():
@@ -105,3 +109,37 @@ def parse_connection_string_url(url: str) -> Dict[str, str]:
105109
for k, v in parse_qs(p.query).items():
106110
fields[k] = v[-1]
107111
return fields
112+
113+
114+
# This enables interruptible queries with an approach similar to
# https://www.psycopg.org/docs/faq.html#faq-interrupt-query
# However, to handle timeouts we can't use psycopg2.extensions.set_wait_callback:
# https://github.com/psycopg/psycopg2/issues/944
# Instead we rely on manually calling wait_select after connection and queries.
# Since it's not a wait callback, we do not capture and transform KeyboardInterrupt here.
def wait_select(conn, timeout=None):
    """Block until ``conn.poll()`` reports ``POLL_OK``, optionally bounded by a timeout.

    Args:
        conn: a psycopg2 connection; only its ``poll()``, ``fileno()`` and
            ``OperationalError`` attributes are used.
        timeout: maximum number of seconds to wait. ``None``, zero or a negative
            value means no deadline.

    Raises:
        TimeoutError: the connection was still not ready after ``timeout`` seconds.
        conn.OperationalError: ``conn.poll()`` returned an unexpected state.
    """
    has_deadline = timeout is not None and timeout > 0
    start_time = time.monotonic()
    poller = select.poll()
    while True:
        # Ask the connection for its state before looking at the clock, so that a
        # connection which became ready right at the deadline is not reported as timed out.
        state = conn.poll()
        if state == psycopg2.extensions.POLL_OK:
            return
        if state == psycopg2.extensions.POLL_READ:
            event_mask = select.POLLIN
        elif state == psycopg2.extensions.POLL_WRITE:
            event_mask = select.POLLOUT
        else:
            raise conn.OperationalError("wait_select: invalid poll state")

        if has_deadline:
            time_left = start_time + timeout - time.monotonic()
            if time_left <= 0:
                raise TimeoutError("wait_select: timeout after {} seconds".format(timeout))
        else:
            time_left = 1

        # Register and unregister the very same descriptor.
        fd = conn.fileno()
        poller.register(fd, event_mask)
        try:
            # When the remote address does not exist at all, poll() waits its full timeout without any event.
            # However, in the same conditions, conn.poll() raises a psycopg2 exception almost immediately.
            # It is better to fail quickly instead of waiting the full timeout, so we keep our poll() below 1sec.
            poller.poll(min(1.0, time_left) * 1000)
        finally:
            poller.unregister(fd)

test/conftest.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,7 @@ def inject_pg_fixture(*, name: str, pgversion: str, scope="module"):
402402

403403

404404
SUPPORTED_PG_VERSIONS = ["9.5", "9.6", "10", "11", "12"]
405+
pg_cluster_for_tests: List[str] = list()
405406
pg_source_and_target_for_tests: List[Tuple[str, str]] = list()
406407
pg_source_and_target_for_replication_tests: List[Tuple[str, str]] = list()
407408

@@ -437,6 +438,10 @@ def generate_fixtures():
437438
pg_source_and_target_for_tests.append((source_name, target_name))
438439
if LooseVersion(source) >= "10":
439440
pg_source_and_target_for_replication_tests.append((source_name, target_name))
441+
for version in set(pg_source_versions).union(pg_target_versions):
442+
fixture_name = "pg{}".format(version.replace(".", ""))
443+
inject_pg_fixture(name=fixture_name, pgversion=version)
444+
pg_cluster_for_tests.append(fixture_name)
440445

441446

442447
generate_fixtures()
@@ -450,6 +455,17 @@ def test_pg_source_and_target_for_replication_tests():
450455
print(pg_source_and_target_for_replication_tests)
451456

452457

458+
@pytest.fixture(name="pg_cluster", params=pg_cluster_for_tests, scope="function")
def fixture_pg_cluster(request):
    """Yield one PG runner per entry of ``pg_cluster_for_tests``, then clean it up.

    After the test finishes, every registered cleanup callable is invoked, the
    cleanup list is emptied and the runner's databases are dropped.
    """
    runner = request.getfixturevalue(request.param)
    yield runner

    # Teardown: run pending cleanups first, then reset the runner for the next test.
    for run_cleanup in runner.cleanups:
        run_cleanup()
    runner.cleanups.clear()
    runner.drop_dbs()
467+
468+
453469
@pytest.fixture(name="pg_source_and_target", params=pg_source_and_target_for_tests, scope="function")
454470
def fixture_pg_source_and_target(request):
455471
source, target = request.param

test/test_pg_cluster.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright (c) 2021 Aiven, Helsinki, Finland. https://aiven.io/
2+
import signal
3+
4+
from aiven_db_migrate.migrate.pgmigrate import PGCluster
5+
from multiprocessing import Process
6+
from test.conftest import PGRunner
7+
from typing import Tuple
8+
9+
import os
10+
import pytest
11+
import time
12+
13+
14+
def test_interruptible_queries(pg_cluster: PGRunner):
    """Check that a long-running query is aborted promptly by SIGINT (Ctrl+C).

    A helper process sends SIGINT to this (parent) process after one second while
    ``pg_sleep(100)`` is running; the query must raise ``KeyboardInterrupt`` in
    under two seconds.
    """
    def wait_and_interrupt():
        time.sleep(1)
        os.kill(os.getppid(), signal.SIGINT)

    cluster = PGCluster(conn_info=pg_cluster.conn_info())
    interruptor = Process(target=wait_and_interrupt)
    interruptor.start()
    try:
        start_time = time.monotonic()
        with pytest.raises(KeyboardInterrupt):
            cluster.c("select pg_sleep(100)")
        assert time.monotonic() - start_time < 2
    finally:
        # Always stop and reap the helper: if the query failed early, a late SIGINT
        # must not hit the test runner while it is executing another test.
        interruptor.terminate()
        interruptor.join()

test/test_pg_migrate.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
from test.utils import random_string, Timer
77
from typing import Any, Dict, Optional
88

9+
import os
910
import psycopg2
1011
import pytest
12+
import time
1113

1214

1315
class PGMigrateTest:
@@ -154,6 +156,27 @@ def test_migrate_invalid_conn_str(self):
154156
PGMigrate(source_conn_info=source_conn_info, target_conn_info=target_conn_info).migrate()
155157
assert str(err.value) == "Invalid source or target connection string"
156158

159+
def test_migrate_connect_timeout_parameter(self):
    """A ``connect_timeout`` given in the source connection info must be honoured.

    Both the keyword/value and the URL forms of the connection string are tried;
    each must fail with ``TimeoutError`` in under two seconds.
    """
    conn_info_variants = (
        "host=example.org connect_timeout=1",
        "postgresql://example.org?connect_timeout=1",
    )
    for source_conn_info in conn_info_variants:
        started = time.monotonic()
        with pytest.raises(TimeoutError):
            migration = PGMigrate(source_conn_info=source_conn_info, target_conn_info=self.target.conn_info())
            migration.migrate()
        elapsed = time.monotonic() - started
        assert elapsed < 2
166+
167+
def test_migrate_connect_timeout_environment(self):
    """``PGCONNECT_TIMEOUT`` must be honoured when no ``connect_timeout`` parameter is given.

    The environment variable is set to one second for the duration of the test and
    restored to its previous state afterwards, including being removed again when
    it was not set before.
    """
    start_time = time.monotonic()
    original_timeout = os.environ.get("PGCONNECT_TIMEOUT")
    try:
        os.environ["PGCONNECT_TIMEOUT"] = "1"
        with pytest.raises(TimeoutError):
            PGMigrate(source_conn_info="host=example.org", target_conn_info=self.target.conn_info()).migrate()
        end_time = time.monotonic()
        assert end_time - start_time < 2
    finally:
        if original_timeout is None:
            # The variable did not exist before the test: remove it so the
            # one-second timeout does not leak into the following tests.
            os.environ.pop("PGCONNECT_TIMEOUT", None)
        else:
            os.environ["PGCONNECT_TIMEOUT"] = original_timeout
179+
157180
def test_migrate_same_server(self):
158181
source_conn_info = target_conn_info = self.target.conn_info()
159182
with pytest.raises(PGMigrateValidationFailedError) as err:

0 commit comments

Comments
 (0)