Skip to content

Commit b114de3

Browse files
committed
Catch command tag
1 parent 7cd61de commit b114de3

File tree

2 files changed

+39
-4
lines changed

2 files changed

+39
-4
lines changed

tokio-postgres/src/query.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ where
5252
Ok(RowStream {
5353
statement,
5454
responses,
55+
command_tag: None,
5556
_p: PhantomPinned,
5657
})
5758
}
@@ -72,6 +73,7 @@ pub async fn query_portal(
7273
Ok(RowStream {
7374
statement: portal.statement().clone(),
7475
responses,
76+
command_tag: None,
7577
_p: PhantomPinned,
7678
})
7779
}
@@ -202,6 +204,7 @@ pin_project! {
202204
pub struct RowStream {
203205
statement: Statement,
204206
responses: Responses,
207+
command_tag: Option<String>,
205208
#[pin]
206209
_p: PhantomPinned,
207210
}
@@ -213,6 +216,7 @@ impl RowStream {
213216
RowStream {
214217
statement,
215218
responses,
219+
command_tag: None,
216220
_p: PhantomPinned,
217221
}
218222
}
@@ -228,12 +232,24 @@ impl Stream for RowStream {
228232
Message::DataRow(body) => {
229233
return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?)))
230234
}
231-
Message::EmptyQueryResponse
232-
| Message::CommandComplete(_)
233-
| Message::PortalSuspended => {}
235+
Message::EmptyQueryResponse | Message::PortalSuspended => {}
236+
Message::CommandComplete(body) => {
237+
if let Ok(tag) = body.tag() {
238+
*this.command_tag = Some(tag.to_string());
239+
}
240+
}
234241
Message::ReadyForQuery(_) => return Poll::Ready(None),
235242
_ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
236243
}
237244
}
238245
}
239246
}
247+
248+
impl RowStream {
249+
/// Returns the command tag of this query.
250+
///
251+
/// This is only available after the stream has been exhausted.
252+
pub fn command_tag(&self) -> Option<String> {
253+
self.command_tag.clone()
254+
}
255+
}

tokio-postgres/tests/test/main.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ async fn query_raw_txt() {
264264
.unwrap();
265265

266266
assert_eq!(rows.len(), 1);
267-
let res: i32 = rows[0].as_text(0).unwrap().parse::<i32>().unwrap();
267+
let res: i32 = rows[0].as_text(0).unwrap().unwrap().parse::<i32>().unwrap();
268268
assert_eq!(res, 55 * 42);
269269

270270
let rows: Vec<tokio_postgres::Row> = client
@@ -279,6 +279,25 @@ async fn query_raw_txt() {
279279
assert_eq!(rows[0].get::<_, &str>(0), "42");
280280
}
281281

282+
#[tokio::test]
283+
async fn command_tag() {
284+
let client = connect("user=postgres").await;
285+
286+
let row_stream = client
287+
.query_raw_txt("select unnest('{1,2,3}'::int[]);", [])
288+
.await
289+
.unwrap();
290+
291+
pin_mut!(row_stream);
292+
293+
let mut rows: Vec<tokio_postgres::Row> = Vec::new();
294+
while let Some(row) = row_stream.next().await {
295+
rows.push(row.unwrap());
296+
}
297+
298+
assert_eq!(row_stream.command_tag(), Some("SELECT 3".to_string()));
299+
}
300+
282301
#[tokio::test]
283302
async fn custom_composite() {
284303
let client = connect("user=postgres").await;

0 commit comments

Comments
 (0)