Skip to content

Work with pools that don't support prepared statements #1147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions tokio-postgres/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,88 @@ impl Client {
query::query(&self.inner, statement, params).await
}

/// Like `query`, but requires the types of query parameters to be explicitly specified.
///
/// Compared to `query`, this method allows performing queries without three round trips (for prepare, execute, and close). Thus,
/// this is suitable in environments where prepared statements aren't supported (such as Cloudflare Workers with Hyperdrive).
///
/// # Examples
///
/// ```no_run
/// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> {
/// use tokio_postgres::types::ToSql;
/// use tokio_postgres::types::Type;
/// use futures_util::{pin_mut, TryStreamExt};
///
/// let rows = client.query_with_param_types(
/// "SELECT foo FROM bar WHERE biz = $1 AND baz = $2",
/// &[(&"first param", Type::TEXT), (&2i32, Type::INT4)],
/// ).await?;
///
/// for row in rows {
/// let foo: i32 = row.get("foo");
/// println!("foo: {}", foo);
/// }
/// # Ok(())
/// # }
/// ```
pub async fn query_with_param_types(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<Vec<Row>, Error> {
self.query_raw_with_param_types(statement, params)
.await?
.try_collect()
.await
}

/// The maximally flexible version of [`query_with_param_types`].
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The parameters must specify value along with their Postgres type. This allows performing
/// queries without three round trips (for prepare, execute, and close).
///
/// [`query_with_param_types`]: #method.query_with_param_types
///
/// # Examples
///
/// ```no_run
/// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> {
/// use tokio_postgres::types::ToSql;
/// use tokio_postgres::types::Type;
/// use futures_util::{pin_mut, TryStreamExt};
///
/// let mut it = client.query_raw_with_param_types(
/// "SELECT foo FROM bar WHERE biz = $1 AND baz = $2",
/// &[(&"first param", Type::TEXT), (&2i32, Type::INT4)],
/// ).await?;
///
/// pin_mut!(it);
/// while let Some(row) = it.try_next().await? {
/// let foo: i32 = row.get("foo");
/// println!("foo: {}", foo);
/// }
/// # Ok(())
/// # }
/// ```
pub async fn query_raw_with_param_types(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<RowStream, Error> {
fn slice_iter<'a>(
s: &'a [(&'a (dyn ToSql + Sync), Type)],
) -> impl ExactSizeIterator<Item = (&'a dyn ToSql, Type)> + 'a {
s.iter()
.map(|(param, param_type)| (*param as _, param_type.clone()))
}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no need to factor this into a separate function here since it's only being called one place.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Earlier thought was to allow access to the raw RowStream.


query::query_with_param_types(&self.inner, statement, slice_iter(params)).await
}

/// Executes a statement, returning the number of rows modified.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
Expand Down
46 changes: 46 additions & 0 deletions tokio-postgres/src/generic_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,20 @@ pub trait GenericClient: private::Sealed {
I: IntoIterator<Item = P> + Sync + Send,
I::IntoIter: ExactSizeIterator;

/// Like `Client::query_with_param_types`
async fn query_with_param_types(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<Vec<Row>, Error>;

/// Like `Client::query_raw_with_param_types`.
async fn query_raw_with_param_types(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<RowStream, Error>;

/// Like `Client::prepare`.
async fn prepare(&self, query: &str) -> Result<Statement, Error>;

Expand Down Expand Up @@ -136,6 +150,22 @@ impl GenericClient for Client {
self.query_raw(statement, params).await
}

async fn query_with_param_types(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<Vec<Row>, Error> {
self.query_with_param_types(statement, params).await
}

async fn query_raw_with_param_types(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<RowStream, Error> {
self.query_raw_with_param_types(statement, params).await
}

async fn prepare(&self, query: &str) -> Result<Statement, Error> {
self.prepare(query).await
}
Expand Down Expand Up @@ -222,6 +252,22 @@ impl GenericClient for Transaction<'_> {
self.query_raw(statement, params).await
}

async fn query_with_param_types(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<Vec<Row>, Error> {
self.query_with_param_types(statement, params).await
}

async fn query_raw_with_param_types(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<RowStream, Error> {
self.query_raw_with_param_types(statement, params).await
}

async fn prepare(&self, query: &str) -> Result<Statement, Error> {
self.prepare(query).await
}
Expand Down
2 changes: 1 addition & 1 deletion tokio-postgres/src/prepare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu
})
}

async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
pub(crate) async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
if let Some(type_) = Type::from_oid(oid) {
return Ok(type_);
}
Expand Down
146 changes: 141 additions & 5 deletions tokio-postgres/src/query.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::prepare::get_type;
use crate::types::{BorrowToSql, IsNull};
use crate::{Error, Portal, Row, Statement};
use crate::{Column, Error, Portal, Row, Statement};
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Stream};
use log::{debug, log_enabled, Level};
use pin_project_lite::pin_project;
use postgres_protocol::message::backend::{CommandCompleteBody, Message};
use postgres_protocol::message::backend::{CommandCompleteBody, Message, RowDescriptionBody};
use postgres_protocol::message::frontend;
use postgres_types::Type;
use std::fmt;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

struct BorrowToSqlParamsDebug<'a, T>(&'a [T]);
Expand Down Expand Up @@ -50,13 +54,125 @@ where
};
let responses = start(client, buf).await?;
Ok(RowStream {
statement,
statement: statement,
responses,
rows_affected: None,
_p: PhantomPinned,
})
}

enum QueryProcessingState {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels like this is a lot of boilerplate over a few match responses.next().await? { ... } in a sequence.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

Empty,
ParseCompleted,
BindCompleted,
ParameterDescribed,
Final(Vec<Column>),
}

/// State machine for processing messages for `query_with_param_types`.
impl QueryProcessingState {
pub async fn process_message(
self,
client: &Arc<InnerClient>,
message: Message,
) -> Result<Self, Error> {
match (self, message) {
(QueryProcessingState::Empty, Message::ParseComplete) => {
Ok(QueryProcessingState::ParseCompleted)
}
(QueryProcessingState::ParseCompleted, Message::BindComplete) => {
Ok(QueryProcessingState::BindCompleted)
}
(QueryProcessingState::BindCompleted, Message::ParameterDescription(_)) => {
Ok(QueryProcessingState::ParameterDescribed)
}
(
QueryProcessingState::ParameterDescribed,
Message::RowDescription(row_description),
) => Self::form_final(client, Some(row_description)).await,
(QueryProcessingState::ParameterDescribed, Message::NoData) => {
Self::form_final(client, None).await
}
(_, Message::ErrorResponse(body)) => Err(Error::db(body)),
_ => Err(Error::unexpected_message()),
}
}

async fn form_final(
client: &Arc<InnerClient>,
row_description: Option<RowDescriptionBody>,
) -> Result<Self, Error> {
let mut columns = vec![];
if let Some(row_description) = row_description {
let mut it = row_description.fields();
while let Some(field) = it.next().map_err(Error::parse)? {
let type_ = get_type(client, field.type_oid()).await?;
let column = Column {
name: field.name().to_string(),
table_oid: Some(field.table_oid()).filter(|n| *n != 0),
column_id: Some(field.column_id()).filter(|n| *n != 0),
r#type: type_,
};
columns.push(column);
}
}

Ok(Self::Final(columns))
}
}

pub async fn query_with_param_types<'a, P, I>(
client: &Arc<InnerClient>,
query: &str,
params: I,
) -> Result<RowStream, Error>
where
P: BorrowToSql,
I: IntoIterator<Item = (P, Type)>,
I::IntoIter: ExactSizeIterator,
{
let (params, param_types): (Vec<_>, Vec<_>) = params.into_iter().unzip();

let params = params.into_iter();

let param_oids = param_types.iter().map(|t| t.oid()).collect::<Vec<_>>();

let params = params.into_iter();

let buf = client.with_buf(|buf| {
frontend::parse("", query, param_oids.into_iter(), buf).map_err(Error::parse)?;

encode_bind_with_statement_name_and_param_types("", &param_types, params, "", buf)?;

frontend::describe(b'S', "", buf).map_err(Error::encode)?;

frontend::execute("", 0, buf).map_err(Error::encode)?;

frontend::sync(buf);

Ok(buf.split().freeze())
})?;

let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;

let mut state = QueryProcessingState::Empty;

loop {
let message = responses.next().await?;

state = state.process_message(client, message).await?;

if let QueryProcessingState::Final(columns) = state {
return Ok(RowStream {
statement: Statement::unnamed(vec![], columns),
responses,
rows_affected: None,
_p: PhantomPinned,
});
}
}
}

pub async fn query_portal(
client: &InnerClient,
portal: &Portal,
Expand Down Expand Up @@ -164,7 +280,27 @@ where
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator,
{
let param_types = statement.params();
encode_bind_with_statement_name_and_param_types(
statement.name(),
statement.params(),
params,
portal,
buf,
)
}

fn encode_bind_with_statement_name_and_param_types<P, I>(
statement_name: &str,
param_types: &[Type],
params: I,
portal: &str,
buf: &mut BytesMut,
) -> Result<(), Error>
where
P: BorrowToSql,
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator,
{
let params = params.into_iter();

if param_types.len() != params.len() {
Expand All @@ -181,7 +317,7 @@ where
let mut error_idx = 0;
let r = frontend::bind(
portal,
statement.name(),
statement_name,
param_formats,
params.zip(param_types).enumerate(),
|(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(ty, buf) {
Expand Down
13 changes: 13 additions & 0 deletions tokio-postgres/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ struct StatementInner {

impl Drop for StatementInner {
fn drop(&mut self) {
if self.name.is_empty() {
// Unnamed statements don't need to be closed
return;
}
if let Some(client) = self.client.upgrade() {
let buf = client.with_buf(|buf| {
frontend::close(b'S', &self.name, buf).unwrap();
Expand Down Expand Up @@ -46,6 +50,15 @@ impl Statement {
}))
}

pub(crate) fn unnamed(params: Vec<Type>, columns: Vec<Column>) -> Statement {
Statement(Arc::new(StatementInner {
client: Weak::new(),
name: String::new(),
params,
columns,
}))
}

pub(crate) fn name(&self) -> &str {
&self.0.name
}
Expand Down
Loading