Skip to content

Support CopyBoth queries and replication mode in config #778

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docker/sql_setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ port = 5433
ssl = on
ssl_cert_file = 'server.crt'
ssl_key_file = 'server.key'
wal_level = logical
EOCONF

cat > "$PGDATA/pg_hba.conf" <<-EOCONF
Expand All @@ -82,6 +83,7 @@ host all ssl_user ::0/0 reject

# IPv4 local connections:
host all postgres 0.0.0.0/0 trust
host replication postgres 0.0.0.0/0 trust
# IPv6 local connections:
host all postgres ::0/0 trust
# Unix socket connections:
Expand Down
34 changes: 34 additions & 0 deletions postgres-protocol/src/message/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub const DATA_ROW_TAG: u8 = b'D';
pub const ERROR_RESPONSE_TAG: u8 = b'E';
pub const COPY_IN_RESPONSE_TAG: u8 = b'G';
pub const COPY_OUT_RESPONSE_TAG: u8 = b'H';
pub const COPY_BOTH_RESPONSE_TAG: u8 = b'W';
pub const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I';
pub const BACKEND_KEY_DATA_TAG: u8 = b'K';
pub const NO_DATA_TAG: u8 = b'n';
Expand Down Expand Up @@ -93,6 +94,7 @@ pub enum Message {
CopyDone,
CopyInResponse(CopyInResponseBody),
CopyOutResponse(CopyOutResponseBody),
CopyBothResponse(CopyBothResponseBody),
DataRow(DataRowBody),
EmptyQueryResponse,
ErrorResponse(ErrorResponseBody),
Expand Down Expand Up @@ -190,6 +192,16 @@ impl Message {
storage,
})
}
COPY_BOTH_RESPONSE_TAG => {
let format = buf.read_u8()?;
let len = buf.read_u16::<BigEndian>()?;
let storage = buf.read_all();
Message::CopyBothResponse(CopyBothResponseBody {
format,
len,
storage,
})
}
EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse,
BACKEND_KEY_DATA_TAG => {
let process_id = buf.read_i32::<BigEndian>()?;
Expand Down Expand Up @@ -524,6 +536,28 @@ impl CopyOutResponseBody {
}
}

#[derive(Debug, Clone)]
pub struct CopyBothResponseBody {
format: u8,
len: u16,
storage: Bytes,
}

impl CopyBothResponseBody {
#[inline]
pub fn format(&self) -> u8 {
self.format
}

#[inline]
pub fn column_formats(&self) -> ColumnFormats<'_> {
ColumnFormats {
remaining: self.len,
buf: &self.storage,
}
}
}

#[derive(Debug, Clone)]
pub struct DataRowBody {
storage: Bytes,
Expand Down
62 changes: 58 additions & 4 deletions tokio-postgres/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::codec::BackendMessages;
use crate::codec::{BackendMessages, FrontendMessage};
use crate::config::SslMode;
use crate::connection::{Request, RequestMessages};
use crate::copy_both::{CopyBothDuplex, CopyBothReceiver};
use crate::copy_out::CopyOutStream;
#[cfg(feature = "runtime")]
use crate::keepalive::KeepaliveConfig;
Expand All @@ -13,13 +14,14 @@ use crate::types::{Oid, ToSql, Type};
#[cfg(feature = "runtime")]
use crate::Socket;
use crate::{
copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error,
Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder,
copy_both, copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken,
CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction,
TransactionBuilder,
};
use bytes::{Buf, BytesMut};
use fallible_iterator::FallibleIterator;
use futures_channel::mpsc;
use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt};
use futures_util::{future, pin_mut, ready, Stream, StreamExt, TryStreamExt};
use parking_lot::Mutex;
use postgres_protocol::message::backend::Message;
use postgres_types::BorrowToSql;
Expand All @@ -29,6 +31,7 @@ use std::fmt;
use std::net::IpAddr;
#[cfg(feature = "runtime")]
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
#[cfg(feature = "runtime")]
Expand All @@ -40,6 +43,11 @@ pub struct Responses {
cur: BackendMessages,
}

pub struct CopyBothHandles {
pub(crate) stream_receiver: mpsc::Receiver<Result<Message, Error>>,
pub(crate) sink_sender: mpsc::Sender<FrontendMessage>,
}

impl Responses {
pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
loop {
Expand All @@ -61,6 +69,17 @@ impl Responses {
}
}

impl Stream for Responses {
type Item = Result<Message, Error>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match ready!((*self).poll_next(cx)) {
Err(err) if err.is_closed() => Poll::Ready(None),
msg => Poll::Ready(Some(msg)),
}
}
}

/// A cache of type info and prepared statements for fetching type info
/// (corresponding to the queries in the [prepare](prepare) module).
#[derive(Default)]
Expand Down Expand Up @@ -103,6 +122,32 @@ impl InnerClient {
})
}

pub fn start_copy_both(&self) -> Result<CopyBothHandles, Error> {
let (sender, receiver) = mpsc::channel(16);
let (stream_sender, stream_receiver) = mpsc::channel(16);
let (sink_sender, sink_receiver) = mpsc::channel(16);

let responses = Responses {
receiver,
cur: BackendMessages::empty(),
};
let messages = RequestMessages::CopyBoth(CopyBothReceiver::new(
responses,
sink_receiver,
stream_sender,
));

let request = Request { messages, sender };
self.sender
.unbounded_send(request)
.map_err(|_| Error::closed())?;

Ok(CopyBothHandles {
stream_receiver,
sink_sender,
})
}

pub fn typeinfo(&self) -> Option<Statement> {
self.cached_typeinfo.lock().typeinfo.clone()
}
Expand Down Expand Up @@ -493,6 +538,15 @@ impl Client {
copy_out::copy_out(self.inner(), statement).await
}

/// Executes a CopyBoth query, returning a combined Stream+Sink type to read and write copy
/// data.
pub async fn copy_both_simple<T>(&self, query: &str) -> Result<CopyBothDuplex<T>, Error>
where
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After the replication stream, if the timeline is historical, Postgres will send a tuple as a response. So we actually need a function that returns something like Result<(CopyBothDuplex<T>, Option<SimpleQueryMessage>), Error> (or maybe Result<(CopyBothDuplex<T>, Option<Vec<SimpleQueryMessage>>), Error> in case other commands are added in the future which use CopyBoth and return a set).

It's actually very specific to START_REPLICATION (and even more specifically, to physical replication), so it might make sense to have a more specific name or at least clarify what it's expecting the command to do. Maybe something like copy_both_simple_with_result()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point, I'll take a look on how we can expose this to users, ideally in a generic way.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In case you missed my other comment, there's a similar issue for BASE_BACKUP, except with CopyOut instead. That can be a separate PR, though.

T: Buf + 'static + Send,
{
copy_both::copy_both_simple(self.inner(), query).await
}

/// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
///
/// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
Expand Down
45 changes: 45 additions & 0 deletions tokio-postgres/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,21 @@ pub enum LoadBalanceHosts {
Random,
}

/// Replication mode configuration.
///
/// It is recommended that you use a PostgreSQL server patch version
/// of at least: 14.0, 13.2, 12.6, 11.11, 10.16, 9.6.21, or
/// 9.5.25. Earlier patch levels have a bug that doesn't properly
/// handle pipelined requests after streaming has stopped.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ReplicationMode {
/// Physical replication.
Physical,
/// Logical replication.
Logical,
}

/// A host specification.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Host {
Expand Down Expand Up @@ -209,6 +224,7 @@ pub struct Config {
pub(crate) target_session_attrs: TargetSessionAttrs,
pub(crate) channel_binding: ChannelBinding,
pub(crate) load_balance_hosts: LoadBalanceHosts,
pub(crate) replication_mode: Option<ReplicationMode>,
}

impl Default for Config {
Expand Down Expand Up @@ -242,6 +258,7 @@ impl Config {
target_session_attrs: TargetSessionAttrs::Any,
channel_binding: ChannelBinding::Prefer,
load_balance_hosts: LoadBalanceHosts::Disable,
replication_mode: None,
}
}

Expand Down Expand Up @@ -524,6 +541,22 @@ impl Config {
self.load_balance_hosts
}

/// Set replication mode.
///
/// It is recommended that you use a PostgreSQL server patch version
/// of at least: 14.0, 13.2, 12.6, 11.11, 10.16, 9.6.21, or
/// 9.5.25. Earlier patch levels have a bug that doesn't properly
/// handle pipelined requests after streaming has stopped.
pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config {
self.replication_mode = Some(replication_mode);
self
}

/// Get replication mode.
pub fn get_replication_mode(&self) -> Option<ReplicationMode> {
self.replication_mode
}

fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
match key {
"user" => {
Expand Down Expand Up @@ -660,6 +693,17 @@ impl Config {
};
self.load_balance_hosts(load_balance_hosts);
}
"replication" => {
let mode = match value {
"off" => None,
"true" => Some(ReplicationMode::Physical),
"database" => Some(ReplicationMode::Logical),
_ => return Err(Error::config_parse(Box::new(InvalidValue("replication")))),
};
if let Some(mode) = mode {
self.replication_mode(mode);
}
}
key => {
return Err(Error::config_parse(Box::new(UnknownOption(
key.to_string(),
Expand Down Expand Up @@ -744,6 +788,7 @@ impl fmt::Debug for Config {
config_dbg
.field("target_session_attrs", &self.target_session_attrs)
.field("channel_binding", &self.channel_binding)
.field("replication", &self.replication_mode)
.finish()
}
}
Expand Down
8 changes: 7 additions & 1 deletion tokio-postgres/src/connect_raw.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::config::{self, Config};
use crate::config::{self, Config, ReplicationMode};
use crate::connect_tls::connect_tls;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::{TlsConnect, TlsStream};
Expand Down Expand Up @@ -133,6 +133,12 @@ where
if let Some(application_name) = &config.application_name {
params.push(("application_name", &**application_name));
}
if let Some(replication_mode) = &config.replication_mode {
match replication_mode {
ReplicationMode::Physical => params.push(("replication", "true")),
ReplicationMode::Logical => params.push(("replication", "database")),
}
}

let mut buf = BytesMut::new();
frontend::startup_message(params, &mut buf).map_err(Error::encode)?;
Expand Down
20 changes: 20 additions & 0 deletions tokio-postgres/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::copy_both::CopyBothReceiver;
use crate::copy_in::CopyInReceiver;
use crate::error::DbError;
use crate::maybe_tls_stream::MaybeTlsStream;
Expand All @@ -20,6 +21,7 @@ use tokio_util::codec::Framed;
pub enum RequestMessages {
Single(FrontendMessage),
CopyIn(CopyInReceiver),
CopyBoth(CopyBothReceiver),
}

pub struct Request {
Expand Down Expand Up @@ -258,6 +260,24 @@ where
.map_err(Error::io)?;
self.pending_request = Some(RequestMessages::CopyIn(receiver));
}
RequestMessages::CopyBoth(mut receiver) => {
let message = match receiver.poll_next_unpin(cx) {
Poll::Ready(Some(message)) => message,
Poll::Ready(None) => {
trace!("poll_write: finished copy_both request");
continue;
}
Poll::Pending => {
trace!("poll_write: waiting on copy_both stream");
self.pending_request = Some(RequestMessages::CopyBoth(receiver));
return Ok(true);
}
};
Pin::new(&mut self.stream)
.start_send(message)
.map_err(Error::io)?;
self.pending_request = Some(RequestMessages::CopyBoth(receiver));
}
}
}
}
Expand Down
Loading
Loading