Skip to content

Commit 7cd61de

Browse files
committed
Add query_raw_txt client method
It takes all the extended protocol params as text and passes them to postgres to sort out types. With that we can avoid situations when postgres derived different type compared to what was passed in arguments. There is also propare_typed method, but since we receive data in text format anyway it makes more sense to avoid dealing with types in params. This way we also can save on roundtrip and send Parse+Bind+Describe+Execute right away without waiting for params description before Bind. Also use text protocol for responses -- that allows to grab postgres-provided serializations for types.
1 parent 0bc41d8 commit 7cd61de

File tree

7 files changed

+181
-3
lines changed

7 files changed

+181
-3
lines changed

postgres-types/src/lib.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,22 @@ impl WrongType {
395395
}
396396
}
397397

398+
/// An error indicating that a as_text conversion was attempted on a binary
399+
/// result.
400+
#[derive(Debug)]
401+
pub struct WrongFormat {}
402+
403+
impl Error for WrongFormat {}
404+
405+
impl fmt::Display for WrongFormat {
406+
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
407+
write!(
408+
fmt,
409+
"cannot read column as text while it is in binary format"
410+
)
411+
}
412+
}
413+
398414
/// A trait for types that can be created from a Postgres value.
399415
///
400416
/// # Types
@@ -846,7 +862,7 @@ pub trait ToSql: fmt::Debug {
846862
/// Supported Postgres message format types
847863
///
848864
/// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8`
849-
#[derive(Clone, Copy, Debug)]
865+
#[derive(Clone, Copy, Debug, PartialEq)]
850866
pub enum Format {
851867
/// Text format (UTF-8)
852868
Text,

tokio-postgres/src/client.rs

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ use crate::copy_both::CopyBothDuplex;
77
use crate::copy_out::CopyOutStream;
88
#[cfg(feature = "runtime")]
99
use crate::keepalive::KeepaliveConfig;
10+
use crate::prepare::get_type;
1011
use crate::query::RowStream;
1112
use crate::simple_query::SimpleQueryStream;
13+
use crate::statement::Column;
1214
#[cfg(feature = "runtime")]
1315
use crate::tls::MakeTlsConnect;
1416
use crate::tls::TlsConnect;
@@ -20,7 +22,7 @@ use crate::{
2022
CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction,
2123
TransactionBuilder,
2224
};
23-
use bytes::{Buf, BytesMut};
25+
use bytes::{Buf, BufMut, BytesMut};
2426
use fallible_iterator::FallibleIterator;
2527
use futures_channel::mpsc;
2628
use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt};
@@ -374,6 +376,87 @@ impl Client {
374376
query::query(&self.inner, statement, params).await
375377
}
376378

379+
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
380+
/// to save a roundtrip
381+
pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result<RowStream, Error>
382+
where
383+
S: AsRef<str>,
384+
I: IntoIterator<Item = S>,
385+
I::IntoIter: ExactSizeIterator,
386+
{
387+
let params = params.into_iter();
388+
let params_len = params.len();
389+
390+
let buf = self.inner.with_buf(|buf| {
391+
// Parse, anonymous portal
392+
frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?;
393+
// Bind, pass params as text, retrieve as binary
394+
match frontend::bind(
395+
"", // empty string selects the unnamed portal
396+
"", // empty string selects the unnamed prepared statement
397+
std::iter::empty(), // all parameters use the default format (text)
398+
params,
399+
|param, buf| {
400+
buf.put_slice(param.as_ref().as_bytes());
401+
Ok(postgres_protocol::IsNull::No)
402+
},
403+
Some(0), // all text
404+
buf,
405+
) {
406+
Ok(()) => Ok(()),
407+
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)),
408+
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
409+
}?;
410+
411+
// Describe portal to typecast results
412+
frontend::describe(b'P', "", buf).map_err(Error::encode)?;
413+
// Execute
414+
frontend::execute("", 0, buf).map_err(Error::encode)?;
415+
// Sync
416+
frontend::sync(buf);
417+
418+
Ok(buf.split().freeze())
419+
})?;
420+
421+
let mut responses = self
422+
.inner
423+
.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
424+
425+
// now read the responses
426+
427+
match responses.next().await? {
428+
Message::ParseComplete => {}
429+
_ => return Err(Error::unexpected_message()),
430+
}
431+
match responses.next().await? {
432+
Message::BindComplete => {}
433+
_ => return Err(Error::unexpected_message()),
434+
}
435+
let row_description = match responses.next().await? {
436+
Message::RowDescription(body) => Some(body),
437+
Message::NoData => None,
438+
_ => return Err(Error::unexpected_message()),
439+
};
440+
441+
// construct statement object
442+
443+
let parameters = vec![Type::UNKNOWN; params_len];
444+
445+
let mut columns = vec![];
446+
if let Some(row_description) = row_description {
447+
let mut it = row_description.fields();
448+
while let Some(field) = it.next().map_err(Error::parse)? {
449+
let type_ = get_type(&self.inner, field.type_oid()).await?;
450+
let column = Column::new(field.name().to_string(), type_);
451+
columns.push(column);
452+
}
453+
}
454+
455+
let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns);
456+
457+
Ok(RowStream::new(statement, responses))
458+
}
459+
377460
/// Executes a statement, returning the number of rows modified.
378461
///
379462
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list

tokio-postgres/src/prepare.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu
126126
})
127127
}
128128

129-
async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
129+
pub async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
130130
if let Some(type_) = Type::from_oid(oid) {
131131
return Ok(type_);
132132
}

tokio-postgres/src/query.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,17 @@ pin_project! {
207207
}
208208
}
209209

210+
impl RowStream {
211+
/// Creates a new `RowStream`.
212+
pub fn new(statement: Statement, responses: Responses) -> Self {
213+
RowStream {
214+
statement,
215+
responses,
216+
_p: PhantomPinned,
217+
}
218+
}
219+
}
220+
210221
impl Stream for RowStream {
211222
type Item = Result<Row, Error>;
212223

tokio-postgres/src/row.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::types::{FromSql, Type, WrongType};
77
use crate::{Error, Statement};
88
use fallible_iterator::FallibleIterator;
99
use postgres_protocol::message::backend::DataRowBody;
10+
use postgres_types::{Format, WrongFormat};
1011
use std::fmt;
1112
use std::ops::Range;
1213
use std::str;
@@ -187,6 +188,22 @@ impl Row {
187188
let range = self.ranges[idx].to_owned()?;
188189
Some(&self.body.buffer()[range])
189190
}
191+
192+
/// Interpret the column at the given index as text
193+
///
194+
/// Useful when using query_raw_txt() which sets text transfer mode
195+
pub fn as_text(&self, idx: usize) -> Result<Option<&str>, Error> {
196+
if self.statement.output_format() == Format::Text {
197+
match self.col_buffer(idx) {
198+
Some(raw) => {
199+
FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx))
200+
}
201+
None => Ok(None),
202+
}
203+
} else {
204+
Err(Error::from_sql(Box::new(WrongFormat {}), idx))
205+
}
206+
}
190207
}
191208

192209
impl AsName for SimpleColumn {

tokio-postgres/src/statement.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::codec::FrontendMessage;
33
use crate::connection::RequestMessages;
44
use crate::types::Type;
55
use postgres_protocol::message::frontend;
6+
use postgres_types::Format;
67
use std::{
78
fmt,
89
sync::{Arc, Weak},
@@ -13,6 +14,7 @@ struct StatementInner {
1314
name: String,
1415
params: Vec<Type>,
1516
columns: Vec<Column>,
17+
output_format: Format,
1618
}
1719

1820
impl Drop for StatementInner {
@@ -46,6 +48,22 @@ impl Statement {
4648
name,
4749
params,
4850
columns,
51+
output_format: Format::Binary,
52+
}))
53+
}
54+
55+
pub(crate) fn new_text(
56+
inner: &Arc<InnerClient>,
57+
name: String,
58+
params: Vec<Type>,
59+
columns: Vec<Column>,
60+
) -> Statement {
61+
Statement(Arc::new(StatementInner {
62+
client: Arc::downgrade(inner),
63+
name,
64+
params,
65+
columns,
66+
output_format: Format::Text,
4967
}))
5068
}
5169

@@ -62,6 +80,11 @@ impl Statement {
6280
pub fn columns(&self) -> &[Column] {
6381
&self.0.columns
6482
}
83+
84+
/// Returns output format for the statement.
85+
pub fn output_format(&self) -> Format {
86+
self.0.output_format
87+
}
6588
}
6689

6790
/// Information about a column of a query.

tokio-postgres/tests/test/main.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,34 @@ async fn custom_array() {
251251
}
252252
}
253253

254+
#[tokio::test]
255+
async fn query_raw_txt() {
256+
let client = connect("user=postgres").await;
257+
258+
let rows: Vec<tokio_postgres::Row> = client
259+
.query_raw_txt("SELECT 55 * $1", ["42"])
260+
.await
261+
.unwrap()
262+
.try_collect()
263+
.await
264+
.unwrap();
265+
266+
assert_eq!(rows.len(), 1);
267+
let res: i32 = rows[0].as_text(0).unwrap().parse::<i32>().unwrap();
268+
assert_eq!(res, 55 * 42);
269+
270+
let rows: Vec<tokio_postgres::Row> = client
271+
.query_raw_txt("SELECT $1", ["42"])
272+
.await
273+
.unwrap()
274+
.try_collect()
275+
.await
276+
.unwrap();
277+
278+
assert_eq!(rows.len(), 1);
279+
assert_eq!(rows[0].get::<_, &str>(0), "42");
280+
}
281+
254282
#[tokio::test]
255283
async fn custom_composite() {
256284
let client = connect("user=postgres").await;

0 commit comments

Comments
 (0)