diff --git a/tokio-postgres/src/codec.rs b/tokio-postgres/src/codec.rs index 9d078044b..23c371542 100644 --- a/tokio-postgres/src/codec.rs +++ b/tokio-postgres/src/codec.rs @@ -35,7 +35,9 @@ impl FallibleIterator for BackendMessages { } } -pub struct PostgresCodec; +pub struct PostgresCodec { + pub max_message_size: Option, +} impl Encoder for PostgresCodec { type Error = io::Error; @@ -64,6 +66,15 @@ impl Decoder for PostgresCodec { break; } + if let Some(max) = self.max_message_size { + if len > max { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "message too large", + )); + } + } + match header.tag() { backend::NOTICE_RESPONSE_TAG | backend::NOTIFICATION_RESPONSE_TAG diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 4153fa250..fdb5e6359 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -185,6 +185,7 @@ pub struct Config { pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, pub(crate) replication_mode: Option, + pub(crate) max_backend_message_size: Option, } impl Default for Config { @@ -217,6 +218,7 @@ impl Config { target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, replication_mode: None, + max_backend_message_size: None, } } @@ -472,6 +474,17 @@ impl Config { self.replication_mode } + /// Set limit for backend messages size. + pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config { + self.max_backend_message_size = Some(max_backend_message_size); + self + } + + /// Get limit for backend messages size. + pub fn get_max_backend_message_size(&self) -> Option { + self.max_backend_message_size + } + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { @@ -586,6 +599,14 @@ impl Config { self.replication_mode(mode); } } + "max_backend_message_size" => { + let limit = value.parse::().map_err(|_| { + Error::config_parse(Box::new(InvalidValue("max_backend_message_size"))) + })?; + if limit > 0 { + self.max_backend_message_size(limit); + } + } key => { return Err(Error::config_parse(Box::new(UnknownOption( key.to_string(), diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index ddfca2894..0beead11f 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -90,7 +90,12 @@ where let stream = connect_tls(stream, config.ssl_mode, tls).await?; let mut stream = StartupStream { - inner: Framed::new(stream, PostgresCodec), + inner: Framed::new( + stream, + PostgresCodec { + max_message_size: config.max_backend_message_size, + }, + ), buf: BackendMessages::empty(), delayed: VecDeque::new(), }; diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 3213a6dad..551f6ec5c 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -280,6 +280,30 @@ async fn query_raw_txt() { assert!(rows[0].body_len() > 0); } +#[tokio::test] +async fn limit_max_backend_message_size() { + let client = connect("user=postgres max_backend_message_size=10000").await; + let small: Vec = client + .query_raw_txt("SELECT REPEAT('a', 20)", []) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(small.len(), 1); + assert_eq!(small[0].as_text(0).unwrap().unwrap().len(), 20); + + let large: Result, Error> = client + .query_raw_txt("SELECT REPEAT('a', 2000000)", []) + .await + .unwrap() + .try_collect() + .await; + + assert!(large.is_err()); +} + #[tokio::test] async fn command_tag() { let client = connect("user=postgres").await;