Skip to content

Commit c473f35

Browse files
committed
Add query_raw_txt client method
It takes all the extended protocol params as text and passes them to postgres to sort out types. With that we can avoid situations when postgres derived different type compared to what was passed in arguments. There is also propare_typed method, but since we receive data in text format anyway it makes more sense to avoid dealing with types in params. This way we also can save on roundtrip and send Parse+Bind+Describe+Execute right away without waiting for params description before Bind.
1 parent 0bc41d8 commit c473f35

File tree

4 files changed

+122
-2
lines changed

4 files changed

+122
-2
lines changed

tokio-postgres/src/client.rs

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::copy_both::CopyBothDuplex;
77
use crate::copy_out::CopyOutStream;
88
#[cfg(feature = "runtime")]
99
use crate::keepalive::KeepaliveConfig;
10+
use crate::prepare::get_type;
1011
use crate::query::RowStream;
1112
use crate::simple_query::SimpleQueryStream;
1213
#[cfg(feature = "runtime")]
@@ -20,11 +21,12 @@ use crate::{
2021
CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction,
2122
TransactionBuilder,
2223
};
23-
use bytes::{Buf, BytesMut};
24+
use bytes::{Buf, BytesMut, BufMut};
2425
use fallible_iterator::FallibleIterator;
2526
use futures_channel::mpsc;
2627
use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt};
2728
use parking_lot::Mutex;
29+
use crate::statement::Column;
2830
use postgres_protocol::message::{backend::Message, frontend};
2931
use postgres_types::BorrowToSql;
3032
use std::collections::HashMap;
@@ -374,6 +376,90 @@ impl Client {
374376
query::query(&self.inner, statement, params).await
375377
}
376378

379+
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
380+
/// to save a roundtrip
381+
pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result<RowStream, Error>
382+
where
383+
S: AsRef<str>,
384+
I: IntoIterator<Item = S>,
385+
I::IntoIter: ExactSizeIterator,
386+
{
387+
let params = params.into_iter();
388+
let parems_len = params.len();
389+
390+
let buf = self.inner.with_buf(|buf| {
391+
// Parse, anonymous portal
392+
frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?;
393+
// Bind, pass params as text, retrieve as binary
394+
match frontend::bind(
395+
"", // empty string selects the unnamed portal
396+
"", // empty string selects the unnamed prepared statement
397+
std::iter::empty(), // all parameters use the default format (text)
398+
params,
399+
|param, buf| {
400+
buf.put_slice(param.as_ref().as_bytes());
401+
Ok(postgres_protocol::IsNull::No)
402+
},
403+
Some(1), // all binary
404+
buf,
405+
) {
406+
Ok(()) => Ok(()),
407+
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)), // TODO
408+
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
409+
}?;
410+
411+
// Describe portal to typecast results
412+
frontend::describe(b'P', "", buf).map_err(Error::encode)?;
413+
// Execute
414+
frontend::execute("", 0, buf).map_err(Error::encode)?;
415+
// Sync
416+
frontend::sync(buf);
417+
418+
Ok(buf.split().freeze())
419+
})?;
420+
421+
let mut responses = self.inner.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
422+
423+
// now read the responses
424+
425+
match responses.next().await? {
426+
Message::ParseComplete => {}
427+
_ => return Err(Error::unexpected_message()),
428+
}
429+
match responses.next().await? {
430+
Message::BindComplete => {}
431+
_ => return Err(Error::unexpected_message()),
432+
}
433+
let row_description = match responses.next().await? {
434+
Message::RowDescription(body) => Some(body),
435+
Message::NoData => None,
436+
_ => return Err(Error::unexpected_message()),
437+
};
438+
439+
// construct statement object
440+
441+
let parameters = vec![Type::TEXT; parems_len];
442+
443+
let mut columns = vec![];
444+
if let Some(row_description) = row_description {
445+
let mut it = row_description.fields();
446+
while let Some(field) = it.next().map_err(Error::parse)? {
447+
let type_ = get_type(&self.inner, field.type_oid()).await?;
448+
let column = Column::new(field.name().to_string(), type_);
449+
columns.push(column);
450+
}
451+
}
452+
453+
let statement = Statement::new(
454+
&self.inner,
455+
"".to_owned(),
456+
parameters,
457+
columns
458+
);
459+
460+
Ok(RowStream::new(statement, responses))
461+
}
462+
377463
/// Executes a statement, returning the number of rows modified.
378464
///
379465
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list

tokio-postgres/src/prepare.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu
126126
})
127127
}
128128

129-
async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
129+
pub async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
130130
if let Some(type_) = Type::from_oid(oid) {
131131
return Ok(type_);
132132
}

tokio-postgres/src/query.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,17 @@ pin_project! {
207207
}
208208
}
209209

210+
impl RowStream {
211+
/// Creates a new `RowStream`.
212+
pub fn new(statement: Statement, responses: Responses) -> Self {
213+
RowStream {
214+
statement,
215+
responses,
216+
_p: PhantomPinned,
217+
}
218+
}
219+
}
220+
210221
impl Stream for RowStream {
211222
type Item = Result<Row, Error>;
212223

tokio-postgres/tests/test/main.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,29 @@ async fn custom_array() {
251251
}
252252
}
253253

254+
#[tokio::test]
255+
async fn query_raw_txt() {
256+
let client = connect("user=postgres").await;
257+
258+
let rows: Vec<tokio_postgres::Row> = client
259+
.query_raw_txt("SELECT 55 * $1", ["42"])
260+
.await.unwrap()
261+
.try_collect()
262+
.await.unwrap();
263+
264+
assert_eq!(rows.len(), 1);
265+
assert_eq!(rows[0].get::<_, i32>(0), 55 * 42);
266+
267+
let rows: Vec<tokio_postgres::Row> = client
268+
.query_raw_txt("SELECT $1", ["42"])
269+
.await.unwrap()
270+
.try_collect()
271+
.await.unwrap();
272+
273+
assert_eq!(rows.len(), 1);
274+
assert_eq!(rows[0].get::<_, &str>(0), "42");
275+
}
276+
254277
#[tokio::test]
255278
async fn custom_composite() {
256279
let client = connect("user=postgres").await;

0 commit comments

Comments
 (0)