Support CopyBoth queries and replication mode in config #778

Open · wants to merge 5 commits into base: master
2 changes: 2 additions & 0 deletions docker/sql_setup.sh
@@ -64,6 +64,7 @@ port = 5433
ssl = on
ssl_cert_file = 'server.crt'
ssl_key_file = 'server.key'
wal_level = logical
EOCONF

cat > "$PGDATA/pg_hba.conf" <<-EOCONF
@@ -82,6 +83,7 @@ host all ssl_user ::0/0 reject

# IPv4 local connections:
host all postgres 0.0.0.0/0 trust
host replication postgres 0.0.0.0/0 trust
# IPv6 local connections:
host all postgres ::0/0 trust
# Unix socket connections:
34 changes: 34 additions & 0 deletions postgres-protocol/src/message/backend.rs
@@ -22,6 +22,7 @@ pub const DATA_ROW_TAG: u8 = b'D';
pub const ERROR_RESPONSE_TAG: u8 = b'E';
pub const COPY_IN_RESPONSE_TAG: u8 = b'G';
pub const COPY_OUT_RESPONSE_TAG: u8 = b'H';
pub const COPY_BOTH_RESPONSE_TAG: u8 = b'W';
pub const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I';
pub const BACKEND_KEY_DATA_TAG: u8 = b'K';
pub const NO_DATA_TAG: u8 = b'n';
@@ -93,6 +94,7 @@ pub enum Message {
CopyDone,
CopyInResponse(CopyInResponseBody),
CopyOutResponse(CopyOutResponseBody),
CopyBothResponse(CopyBothResponseBody),
DataRow(DataRowBody),
EmptyQueryResponse,
ErrorResponse(ErrorResponseBody),
@@ -190,6 +192,16 @@ impl Message {
storage,
})
}
COPY_BOTH_RESPONSE_TAG => {
let format = buf.read_u8()?;
let len = buf.read_u16::<BigEndian>()?;
let storage = buf.read_all();
Message::CopyBothResponse(CopyBothResponseBody {
format,
len,
storage,
})
}
EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse,
BACKEND_KEY_DATA_TAG => {
let process_id = buf.read_i32::<BigEndian>()?;
@@ -524,6 +536,28 @@ impl CopyOutResponseBody {
}
}

#[derive(Debug, Clone)]
pub struct CopyBothResponseBody {
format: u8,
len: u16,
storage: Bytes,
}

impl CopyBothResponseBody {
#[inline]
pub fn format(&self) -> u8 {
self.format
}

#[inline]
pub fn column_formats(&self) -> ColumnFormats<'_> {
ColumnFormats {
remaining: self.len,
buf: &self.storage,
}
}
}

#[derive(Debug, Clone)]
pub struct DataRowBody {
storage: Bytes,
62 changes: 58 additions & 4 deletions tokio-postgres/src/client.rs
@@ -1,6 +1,7 @@
use crate::codec::BackendMessages;
use crate::codec::{BackendMessages, FrontendMessage};
use crate::config::SslMode;
use crate::connection::{Request, RequestMessages};
use crate::copy_both::{CopyBothDuplex, CopyBothReceiver};
use crate::copy_out::CopyOutStream;
#[cfg(feature = "runtime")]
use crate::keepalive::KeepaliveConfig;
@@ -13,13 +14,14 @@ use crate::types::{Oid, ToSql, Type};
#[cfg(feature = "runtime")]
use crate::Socket;
use crate::{
copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error,
Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder,
copy_both, copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken,
CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction,
TransactionBuilder,
};
use bytes::{Buf, BytesMut};
use fallible_iterator::FallibleIterator;
use futures_channel::mpsc;
use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt};
use futures_util::{future, pin_mut, ready, Stream, StreamExt, TryStreamExt};
use parking_lot::Mutex;
use postgres_protocol::message::backend::Message;
use postgres_types::BorrowToSql;
@@ -29,6 +31,7 @@ use std::fmt;
use std::net::IpAddr;
#[cfg(feature = "runtime")]
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
#[cfg(feature = "runtime")]
@@ -40,6 +43,11 @@ pub struct Responses {
cur: BackendMessages,
}

pub struct CopyBothHandles {
pub(crate) stream_receiver: mpsc::Receiver<Result<Message, Error>>,
pub(crate) sink_sender: mpsc::Sender<FrontendMessage>,
}

impl Responses {
pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
loop {
@@ -61,6 +69,17 @@ impl Responses {
}
}

impl Stream for Responses {
type Item = Result<Message, Error>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match ready!((*self).poll_next(cx)) {
Err(err) if err.is_closed() => Poll::Ready(None),
msg => Poll::Ready(Some(msg)),
}
}
}

/// A cache of type info and prepared statements for fetching type info
/// (corresponding to the queries in the [prepare](prepare) module).
#[derive(Default)]
@@ -103,6 +122,32 @@ impl InnerClient {
})
}

pub fn start_copy_both(&self) -> Result<CopyBothHandles, Error> {
let (sender, receiver) = mpsc::channel(16);
let (stream_sender, stream_receiver) = mpsc::channel(16);
let (sink_sender, sink_receiver) = mpsc::channel(16);

let responses = Responses {
receiver,
cur: BackendMessages::empty(),
};
let messages = RequestMessages::CopyBoth(CopyBothReceiver::new(
responses,
sink_receiver,
stream_sender,
));

let request = Request { messages, sender };
self.sender
.unbounded_send(request)
.map_err(|_| Error::closed())?;

Ok(CopyBothHandles {
stream_receiver,
sink_sender,
})
}

pub fn typeinfo(&self) -> Option<Statement> {
self.cached_typeinfo.lock().typeinfo.clone()
}
@@ -493,6 +538,15 @@ impl Client {
copy_out::copy_out(self.inner(), statement).await
}

/// Executes a CopyBoth query, returning a combined Stream+Sink type to read and write copy
/// data.
pub async fn copy_both_simple<T>(&self, query: &str) -> Result<CopyBothDuplex<T>, Error>
where
Contributor:
After the replication stream ends, if the timeline is historical, Postgres will send a tuple as a response. So we actually need a function that returns something like Result<(CopyBothDuplex<T>, Option<SimpleQueryMessage>), Error> (or maybe Result<(CopyBothDuplex<T>, Option<Vec<SimpleQueryMessage>>), Error> in case other commands are added in the future which use CopyBoth and return a set).

It's actually very specific to START_REPLICATION (and even more specifically, to physical replication), so it might make sense to have a more specific name or at least clarify what it's expecting the command to do. Maybe something like copy_both_simple_with_result()?

Contributor Author:
That's a good point, I'll take a look at how we can expose this to users, ideally in a generic way.

Contributor:
In case you missed my other comment, there's a similar issue for BASE_BACKUP, except with CopyOut instead. That can be a separate PR, though.

T: Buf + 'static + Send,
{
copy_both::copy_both_simple(self.inner(), query).await
}

/// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
///
/// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
45 changes: 45 additions & 0 deletions tokio-postgres/src/config.rs
@@ -72,6 +72,21 @@ pub enum LoadBalanceHosts {
Random,
}

/// Replication mode configuration.
///
/// It is recommended that you use a PostgreSQL server patch version
/// of at least: 14.0, 13.2, 12.6, 11.11, 10.16, 9.6.21, or
/// 9.5.25. Earlier patch levels have a bug that doesn't properly
/// handle pipelined requests after streaming has stopped.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ReplicationMode {
/// Physical replication.
Physical,
/// Logical replication.
Logical,
}

/// A host specification.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Host {
@@ -209,6 +224,7 @@ pub struct Config {
pub(crate) target_session_attrs: TargetSessionAttrs,
pub(crate) channel_binding: ChannelBinding,
pub(crate) load_balance_hosts: LoadBalanceHosts,
pub(crate) replication_mode: Option<ReplicationMode>,
}

impl Default for Config {
@@ -242,6 +258,7 @@ impl Config {
target_session_attrs: TargetSessionAttrs::Any,
channel_binding: ChannelBinding::Prefer,
load_balance_hosts: LoadBalanceHosts::Disable,
replication_mode: None,
}
}

@@ -524,6 +541,22 @@ impl Config {
self.load_balance_hosts
}

/// Set replication mode.
///
/// It is recommended that you use a PostgreSQL server patch version
/// of at least: 14.0, 13.2, 12.6, 11.11, 10.16, 9.6.21, or
/// 9.5.25. Earlier patch levels have a bug that doesn't properly
/// handle pipelined requests after streaming has stopped.
pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config {
self.replication_mode = Some(replication_mode);
self
}

/// Get replication mode.
pub fn get_replication_mode(&self) -> Option<ReplicationMode> {
self.replication_mode
}

fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
match key {
"user" => {
@@ -660,6 +693,17 @@ impl Config {
};
self.load_balance_hosts(load_balance_hosts);
}
"replication" => {
let mode = match value {
"off" => None,
"true" => Some(ReplicationMode::Physical),
"database" => Some(ReplicationMode::Logical),
_ => return Err(Error::config_parse(Box::new(InvalidValue("replication")))),
};
if let Some(mode) = mode {
self.replication_mode(mode);
}
}
key => {
return Err(Error::config_parse(Box::new(UnknownOption(
key.to_string(),
@@ -744,6 +788,7 @@ impl fmt::Debug for Config {
config_dbg
.field("target_session_attrs", &self.target_session_attrs)
.field("channel_binding", &self.channel_binding)
.field("replication", &self.replication_mode)
.finish()
}
}
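The configuration surface above pairs with the `replication` startup parameter sent in connect_raw.rs below. A minimal sketch of opening a logical replication session with it; the host, user, `NoTls`, and the `tokio` runtime setup are illustrative assumptions, and the `runtime` feature is required for `Config::connect`:

```rust
use tokio_postgres::config::ReplicationMode;
use tokio_postgres::{Config, NoTls};

#[tokio::main]
async fn main() -> Result<(), tokio_postgres::Error> {
    // Equivalent to the connection-string form
    // "host=localhost user=postgres replication=database".
    let mut config = Config::new();
    config
        .host("localhost")
        .user("postgres")
        .replication_mode(ReplicationMode::Logical);

    let (client, connection) = config.connect(NoTls).await?;
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            eprintln!("connection error: {}", e);
        }
    });

    // The session is now in logical replication mode, so replication commands
    // can be issued through the simple-query protocol.
    client.simple_query("IDENTIFY_SYSTEM").await?;
    Ok(())
}
```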
8 changes: 7 additions & 1 deletion tokio-postgres/src/connect_raw.rs
@@ -1,5 +1,5 @@
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::config::{self, Config};
use crate::config::{self, Config, ReplicationMode};
use crate::connect_tls::connect_tls;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::{TlsConnect, TlsStream};
@@ -133,6 +133,12 @@ where
if let Some(application_name) = &config.application_name {
params.push(("application_name", &**application_name));
}
if let Some(replication_mode) = &config.replication_mode {
match replication_mode {
ReplicationMode::Physical => params.push(("replication", "true")),
ReplicationMode::Logical => params.push(("replication", "database")),
}
}

let mut buf = BytesMut::new();
frontend::startup_message(params, &mut buf).map_err(Error::encode)?;
20 changes: 20 additions & 0 deletions tokio-postgres/src/connection.rs
@@ -1,4 +1,5 @@
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::copy_both::CopyBothReceiver;
use crate::copy_in::CopyInReceiver;
use crate::error::DbError;
use crate::maybe_tls_stream::MaybeTlsStream;
@@ -20,6 +21,7 @@ use tokio_util::codec::Framed;
pub enum RequestMessages {
Single(FrontendMessage),
CopyIn(CopyInReceiver),
CopyBoth(CopyBothReceiver),
}

pub struct Request {
@@ -258,6 +260,24 @@ where
.map_err(Error::io)?;
self.pending_request = Some(RequestMessages::CopyIn(receiver));
}
RequestMessages::CopyBoth(mut receiver) => {
let message = match receiver.poll_next_unpin(cx) {
Poll::Ready(Some(message)) => message,
Poll::Ready(None) => {
trace!("poll_write: finished copy_both request");
continue;
}
Poll::Pending => {
trace!("poll_write: waiting on copy_both stream");
self.pending_request = Some(RequestMessages::CopyBoth(receiver));
return Ok(true);
}
};
Pin::new(&mut self.stream)
.start_send(message)
.map_err(Error::io)?;
self.pending_request = Some(RequestMessages::CopyBoth(receiver));
}
}
}
}
358 changes: 358 additions & 0 deletions tokio-postgres/src/copy_both.rs
@@ -0,0 +1,358 @@
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::{simple_query, Error};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures_channel::mpsc;
use futures_util::{ready, Sink, SinkExt, Stream, StreamExt};
use log::debug;
use pin_project_lite::pin_project;
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use postgres_protocol::message::frontend::CopyData;
use std::marker::{PhantomData, PhantomPinned};
use std::pin::Pin;
use std::task::{Context, Poll};

/// The state machine of CopyBothReceiver
///
/// ```ignore
///       CopyBoth
///        /    \
///       v      v
///   CopyOut   CopyIn
///        \    /
///         v  v
///       CopyNone
///           |
///           v
///     CopyComplete
///           |
///           v
///    CommandComplete
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CopyBothState {
/// The state before having entered the CopyBoth mode.
Setup,
/// Initial state where CopyData messages can go in both directions
CopyBoth,
/// The server->client stream is closed and we're in CopyIn mode
CopyIn,
/// The client->server stream is closed and we're in CopyOut mode
CopyOut,
/// Both directions are closed; we are waiting for the CommandComplete messages
CopyNone,
/// We have received the first CommandComplete message for the copy
CopyComplete,
/// We have received the final CommandComplete message for the statement
CommandComplete,
}

/// A CopyBothReceiver is responsible for handling the CopyBoth subprotocol. It ensures that, no
/// matter what the user does with their CopyBothDuplex handle, we always send the correct
/// messages to the backend in order to restore the connection to a usable state.
///
/// ```ignore
///                                           |
///          <tokio_postgres owned>           |    <userland owned>
///                                           |
///  pg -> Connection -> CopyBothReceiver  ---+--->  CopyBothDuplex
///                                           |          ^      \
///                                           |         /        v
///                                           |       Sink     Stream
/// ```
pub struct CopyBothReceiver {
/// Receiver of backend messages from the underlying [Connection](crate::Connection)
responses: Responses,
/// Receiver of frontend messages sent by the user using <CopyBothDuplex as Sink>
sink_receiver: mpsc::Receiver<FrontendMessage>,
/// Sender of CopyData contents to be consumed by the user using <CopyBothDuplex as Stream>
stream_sender: mpsc::Sender<Result<Message, Error>>,
/// The current state of the subprotocol
state: CopyBothState,
/// Holds a buffered message until we are ready to send it to the user's stream
buffered_message: Option<Result<Message, Error>>,
}

impl CopyBothReceiver {
pub(crate) fn new(
responses: Responses,
sink_receiver: mpsc::Receiver<FrontendMessage>,
stream_sender: mpsc::Sender<Result<Message, Error>>,
) -> CopyBothReceiver {
CopyBothReceiver {
responses,
sink_receiver,
stream_sender,
state: CopyBothState::Setup,
buffered_message: None,
}
}

/// Convenience method to set the subprotocol into an unexpected message state
fn unexpected_message(&mut self) {
self.sink_receiver.close();
self.buffered_message = Some(Err(Error::unexpected_message()));
self.state = CopyBothState::CommandComplete;
}

/// Processes messages from the backend; it resolves once all backend messages have been
/// processed.
fn poll_backend(&mut self, cx: &mut Context<'_>) -> Poll<()> {
use CopyBothState::*;

loop {
// Deliver the buffered message (if any) to the user to ensure we can potentially
// buffer a new one in response to a server message
if let Some(message) = self.buffered_message.take() {
match self.stream_sender.poll_ready(cx) {
Poll::Ready(_) => {
// If the receiver has hung up we'll just drop the message
let _ = self.stream_sender.start_send(message);
}
Poll::Pending => {
// Stash the message and try again later
self.buffered_message = Some(message);
return Poll::Pending;
}
}
}

match ready!(self.responses.poll_next_unpin(cx)) {
Some(Ok(Message::CopyBothResponse(body))) => match self.state {
Setup => {
self.buffered_message = Some(Ok(Message::CopyBothResponse(body)));
self.state = CopyBoth;
}
_ => self.unexpected_message(),
},
Some(Ok(Message::CopyData(body))) => match self.state {
CopyBoth | CopyOut => {
self.buffered_message = Some(Ok(Message::CopyData(body)));
}
_ => self.unexpected_message(),
},
// The server->client stream is done
Some(Ok(Message::CopyDone)) => {
match self.state {
CopyBoth => self.state = CopyIn,
CopyOut => self.state = CopyNone,
_ => self.unexpected_message(),
};
}
Some(Ok(Message::CommandComplete(_))) => {
match self.state {
CopyNone => self.state = CopyComplete,
CopyComplete => {
self.stream_sender.close_channel();
self.sink_receiver.close();
self.state = CommandComplete;
}
_ => self.unexpected_message(),
};
}
// The server indicated an error; terminate our side if we haven't already
Some(Err(err)) => {
match self.state {
Setup | CopyBoth | CopyOut | CopyIn => {
self.sink_receiver.close();
self.buffered_message = Some(Err(err));
self.state = CommandComplete;
}
_ => self.unexpected_message(),
};
}
Some(Ok(Message::ReadyForQuery(_))) => match self.state {
CommandComplete => {
self.sink_receiver.close();
self.stream_sender.close_channel();
}
_ => self.unexpected_message(),
},
Some(Ok(_)) => self.unexpected_message(),
None => return Poll::Ready(()),
}
}
}
}

/// The [Connection](crate::Connection) will keep polling this stream until it is exhausted. This
/// is the mechanism that drives the CopyBoth subprotocol forward
impl Stream for CopyBothReceiver {
type Item = FrontendMessage;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
use CopyBothState::*;

match self.poll_backend(cx) {
Poll::Ready(()) => Poll::Ready(None),
Poll::Pending => match self.state {
Setup | CopyBoth | CopyIn => match ready!(self.sink_receiver.poll_next_unpin(cx)) {
Some(msg) => Poll::Ready(Some(msg)),
None => {
self.state = match self.state {
CopyBoth => CopyOut,
CopyIn => CopyNone,
_ => unreachable!(),
};

let mut buf = BytesMut::new();
frontend::copy_done(&mut buf);
Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
}
},
_ => Poll::Pending,
},
}
}
}

pin_project! {
/// A duplex stream for consuming streaming replication data.
///
/// Users should ensure that the CopyBothDuplex is dropped before attempting to await a new
/// query. This ensures that the connection returns to normal processing mode.
///
/// ```no_run
/// use tokio_postgres::Client;
///
/// async fn foo(client: &Client) {
/// let duplex_stream = client.copy_both_simple::<&[u8]>("..").await;
///
/// // ⚠️ INCORRECT ⚠️
/// client.query("SELECT 1", &[]).await; // hangs forever
///
/// // duplex_stream is dropped here
/// }
/// ```
///
/// ```no_run
/// use tokio_postgres::Client;
///
/// async fn foo(client: &Client) {
/// let duplex_stream = client.copy_both_simple::<&[u8]>("..").await;
///
/// // ✅ CORRECT ✅
/// drop(duplex_stream);
///
/// client.query("SELECT 1", &[]).await;
/// }
/// ```
pub struct CopyBothDuplex<T> {
#[pin]
sink_sender: mpsc::Sender<FrontendMessage>,
#[pin]
stream_receiver: mpsc::Receiver<Result<Message, Error>>,
buf: BytesMut,
#[pin]
_p: PhantomPinned,
_p2: PhantomData<T>,
}
}

impl<T> Stream for CopyBothDuplex<T> {
type Item = Result<Bytes, Error>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Poll::Ready(match ready!(self.project().stream_receiver.poll_next(cx)) {
Some(Ok(Message::CopyData(body))) => Some(Ok(body.into_bytes())),
Some(Ok(_)) => Some(Err(Error::unexpected_message())),
Some(Err(err)) => Some(Err(err)),
None => None,
})
}
}

impl<T> Sink<T> for CopyBothDuplex<T>
where
T: Buf + 'static + Send,
{
type Error = Error;

fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.project()
.sink_sender
.poll_ready(cx)
.map_err(|_| Error::closed())
}

fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> {
let this = self.project();

let data: Box<dyn Buf + Send> = if item.remaining() > 4096 {
if this.buf.is_empty() {
Box::new(item)
} else {
Box::new(this.buf.split().freeze().chain(item))
}
} else {
this.buf.put(item);
if this.buf.len() > 4096 {
Box::new(this.buf.split().freeze())
} else {
return Ok(());
}
};

let data = CopyData::new(data).map_err(Error::encode)?;
this.sink_sender
.start_send(FrontendMessage::CopyData(data))
.map_err(|_| Error::closed())
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
let mut this = self.project();

if !this.buf.is_empty() {
ready!(this.sink_sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
let data: Box<dyn Buf + Send> = Box::new(this.buf.split().freeze());
let data = CopyData::new(data).map_err(Error::encode)?;
this.sink_sender
.as_mut()
.start_send(FrontendMessage::CopyData(data))
.map_err(|_| Error::closed())?;
}

this.sink_sender.poll_flush(cx).map_err(|_| Error::closed())
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
ready!(self.as_mut().poll_flush(cx))?;
let mut this = self.as_mut().project();
this.sink_sender.disconnect();
Poll::Ready(Ok(()))
}
}

pub async fn copy_both_simple<T>(
client: &InnerClient,
query: &str,
) -> Result<CopyBothDuplex<T>, Error>
where
T: Buf + 'static + Send,
{
debug!("executing copy both query {}", query);

let buf = simple_query::encode(client, query)?;

let mut handles = client.start_copy_both()?;

handles
.sink_sender
.send(FrontendMessage::Raw(buf))
.await
.map_err(|_| Error::closed())?;

match handles.stream_receiver.next().await.transpose()? {
Some(Message::CopyBothResponse(_)) => {}
_ => return Err(Error::unexpected_message()),
}

Ok(CopyBothDuplex {
stream_receiver: handles.stream_receiver,
sink_sender: handles.sink_sender,
buf: BytesMut::new(),
_p: PhantomPinned,
_p2: PhantomData,
})
}
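A minimal sketch of driving `copy_both_simple` from user code, assuming a connection opened with `replication=true` or `replication=database`; the slot name and the `0/0` start LSN are placeholder assumptions:

```rust
use futures_util::TryStreamExt;
use tokio_postgres::{Client, Error};

async fn stream_changes(client: &Client) -> Result<(), Error> {
    // Enter CopyBoth mode; the returned duplex is a Stream of raw CopyData
    // payloads and a Sink for messages going back to the server.
    let duplex = client
        .copy_both_simple::<bytes::Bytes>("START_REPLICATION SLOT my_slot LOGICAL 0/0")
        .await?;
    futures_util::pin_mut!(duplex);

    // XLogData messages are tagged with b'w', primary keepalives with b'k'.
    while let Some(buf) = duplex.try_next().await? {
        if buf.first() == Some(&b'w') {
            println!("received {} bytes of WAL data", buf.len());
        }
    }

    // Dropping the duplex sends CopyDone and lets the connection return to
    // normal processing mode.
    Ok(())
}
```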
2 changes: 2 additions & 0 deletions tokio-postgres/src/lib.rs
Original file line number Diff line number Diff line change
@@ -123,6 +123,7 @@ pub use crate::cancel_token::CancelToken;
pub use crate::client::Client;
pub use crate::config::Config;
pub use crate::connection::Connection;
pub use crate::copy_both::CopyBothDuplex;
pub use crate::copy_in::CopyInSink;
pub use crate::copy_out::CopyOutStream;
use crate::error::DbError;
@@ -160,6 +161,7 @@ mod connect_raw;
mod connect_socket;
mod connect_tls;
mod connection;
mod copy_both;
mod copy_in;
mod copy_out;
pub mod error;
2 changes: 1 addition & 1 deletion tokio-postgres/src/simple_query.rs
@@ -63,7 +63,7 @@ pub async fn batch_execute(client: &InnerClient, query: &str) -> Result<(), Erro
}
}

fn encode(client: &InnerClient, query: &str) -> Result<Bytes, Error> {
pub(crate) fn encode(client: &InnerClient, query: &str) -> Result<Bytes, Error> {
client.with_buf(|buf| {
frontend::query(query, buf).map_err(Error::encode)?;
Ok(buf.split().freeze())
125 changes: 125 additions & 0 deletions tokio-postgres/tests/test/copy_both.rs
@@ -0,0 +1,125 @@
use futures_util::{future, StreamExt, TryStreamExt};
use tokio_postgres::{error::SqlState, Client, SimpleQueryMessage, SimpleQueryRow};

async fn q(client: &Client, query: &str) -> Vec<SimpleQueryRow> {
let msgs = client.simple_query(query).await.unwrap();

msgs.into_iter()
.filter_map(|msg| match msg {
SimpleQueryMessage::Row(row) => Some(row),
_ => None,
})
.collect()
}

#[tokio::test]
async fn copy_both_error() {
let client = crate::connect("user=postgres replication=database").await;

let err = client
.copy_both_simple::<bytes::Bytes>("START_REPLICATION SLOT undefined LOGICAL 0000/0000")
.await
.err()
.unwrap();

assert_eq!(err.code(), Some(&SqlState::UNDEFINED_OBJECT));

// Ensure we can continue issuing queries
assert_eq!(q(&client, "SELECT 1").await[0].get(0), Some("1"));
}

#[tokio::test]
async fn copy_both_stream_error() {
let client = crate::connect("user=postgres replication=true").await;

q(&client, "CREATE_REPLICATION_SLOT err2 PHYSICAL").await;

// This will immediately error out after entering CopyBoth mode
let duplex_stream = client
.copy_both_simple::<bytes::Bytes>("START_REPLICATION SLOT err2 PHYSICAL FFFF/FFFF")
.await
.unwrap();

let mut msgs: Vec<_> = duplex_stream.collect().await;
let result = msgs.pop().unwrap();
assert_eq!(msgs.len(), 0);
assert!(result.unwrap_err().as_db_error().is_some());

// Ensure we can continue issuing queries
assert_eq!(q(&client, "DROP_REPLICATION_SLOT err2").await.len(), 0);
}

#[tokio::test]
async fn copy_both_stream_error_sync() {
let client = crate::connect("user=postgres replication=database").await;

q(&client, "CREATE_REPLICATION_SLOT err1 TEMPORARY PHYSICAL").await;

// This will immediately error out after entering CopyBoth mode
let duplex_stream = client
.copy_both_simple::<bytes::Bytes>("START_REPLICATION SLOT err1 PHYSICAL FFFF/FFFF")
.await
.unwrap();

// Immediately close our sink to send a CopyDone before receiving the ErrorResponse
drop(duplex_stream);

// Ensure we can continue issuing queries
assert_eq!(q(&client, "SELECT 1").await[0].get(0), Some("1"));
}

#[tokio::test]
async fn copy_both() {
let client = crate::connect("user=postgres replication=database").await;

q(&client, "DROP TABLE IF EXISTS replication").await;
q(&client, "CREATE TABLE replication (i text)").await;

let slot_query = "CREATE_REPLICATION_SLOT slot TEMPORARY LOGICAL \"test_decoding\"";
let lsn = q(&client, slot_query).await[0]
.get("consistent_point")
.unwrap()
.to_owned();

// We will attempt to read this from the other end
q(&client, "BEGIN").await;
let xid = q(&client, "SELECT txid_current()").await[0]
.get("txid_current")
.unwrap()
.to_owned();
q(&client, "INSERT INTO replication VALUES ('processed')").await;
q(&client, "COMMIT").await;

// Insert a second row to generate unprocessed messages in the stream
q(&client, "INSERT INTO replication VALUES ('ignored')").await;

let query = format!("START_REPLICATION SLOT slot LOGICAL {}", lsn);
let duplex_stream = client
.copy_both_simple::<bytes::Bytes>(&query)
.await
.unwrap();

let expected = vec![
format!("BEGIN {}", xid),
"table public.replication: INSERT: i[text]:'processed'".to_string(),
format!("COMMIT {}", xid),
];

let actual: Vec<_> = duplex_stream
// Process only XLogData messages
.try_filter(|buf| future::ready(buf[0] == b'w'))
// Play back the stream until the first expected message
.try_skip_while(|buf| future::ready(Ok(!buf.ends_with(expected[0].as_ref()))))
// Take only the expected number of messages; the rest will be discarded by tokio_postgres
.take(expected.len())
.try_collect()
.await
.unwrap();

for (msg, ending) in actual.into_iter().zip(expected.into_iter()) {
assert!(msg.ends_with(ending.as_ref()));
}

// Ensure we can continue issuing queries
assert_eq!(q(&client, "SELECT 1").await[0].get(0), Some("1"));
}
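The tests above only exercise the Stream half of the duplex. For completeness, a hedged sketch of the Sink half being used to acknowledge progress with a Standby Status Update; the 'r' message layout follows the PostgreSQL streaming replication protocol documentation and is not defined by this PR:

```rust
use std::pin::Pin;

use bytes::{BufMut, Bytes, BytesMut};
use futures_util::SinkExt;
use tokio_postgres::{CopyBothDuplex, Error};

async fn send_status_update(
    mut duplex: Pin<&mut CopyBothDuplex<Bytes>>,
    lsn: u64,
) -> Result<(), Error> {
    let mut buf = BytesMut::with_capacity(34);
    buf.put_u8(b'r'); // Standby Status Update tag
    buf.put_u64(lsn); // last WAL byte + 1 received and written to disk
    buf.put_u64(lsn); // last WAL byte + 1 flushed to disk
    buf.put_u64(lsn); // last WAL byte + 1 applied
    buf.put_i64(0); // client clock in microseconds since the PostgreSQL epoch (left at 0 here)
    buf.put_u8(0); // 1 would ask the server to reply immediately

    // The Sink impl wraps the payload in a CopyData frame before it reaches the server.
    duplex.send(buf.freeze()).await
}
```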
1 change: 1 addition & 0 deletions tokio-postgres/tests/test/main.rs
@@ -20,6 +20,7 @@ use tokio_postgres::{
};

mod binary_copy;
mod copy_both;
mod parse;
#[cfg(feature = "runtime")]
mod runtime;