Skip to content

Commit 2629d23

Browse files
fix(io): add connection pool semaphore to Azure Blob Storage backend (#6305)
## Changes Made

Add a connection pool semaphore to `AzureBlobSource`, matching the existing pattern in the S3/GCS/TOS backends. Azure was the only major backend missing this, causing unbounded concurrent connections (400-800+) when reading multiple large parquet files in parallel.

- Add `max_connections_per_io_thread` (default 8) to `AzureConfig`
- Add `connection_pool_sema` to `AzureBlobSource`, with the permit lifecycle tied to `GetResult::Stream`
- Extract `get_size_internal()` to avoid a deadlock for `GetRange::Suffix` (the Azure SDK doesn't support native suffix ranges)
- Update Python bindings, type stubs, and SQL config

## Related Issues

Closes #6279
1 parent cd92fe3 commit 2629d23

7 files changed

Lines changed: 239 additions & 29 deletions

File tree

daft/daft/__init__.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,7 @@ class AzureConfig:
640640
anonymous: bool | None
641641
endpoint_url: str | None = None
642642
use_ssl: bool | None = None
643+
max_connections: int
643644

644645
def __init__(
645646
self,
@@ -654,6 +655,7 @@ class AzureConfig:
654655
anonymous: bool | None = None,
655656
endpoint_url: str | None = None,
656657
use_ssl: bool | None = None,
658+
max_connections: int | None = None,
657659
): ...
658660
def replace(
659661
self,
@@ -668,6 +670,7 @@ class AzureConfig:
668670
anonymous: bool | None = None,
669671
endpoint_url: str | None = None,
670672
use_ssl: bool | None = None,
673+
max_connections: int | None = None,
671674
) -> AzureConfig:
672675
"""Replaces values if provided, returning a new AzureConfig."""
673676
...

src/common/io-config/src/azure.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ pub struct AzureConfig {
1717
pub anonymous: bool,
1818
pub endpoint_url: Option<String>,
1919
pub use_ssl: bool,
20+
pub max_connections_per_io_thread: u32,
2021
}
2122

2223
impl Default for AzureConfig {
@@ -33,6 +34,7 @@ impl Default for AzureConfig {
3334
anonymous: false,
3435
endpoint_url: None,
3536
use_ssl: true,
37+
max_connections_per_io_thread: 8,
3638
}
3739
}
3840
}
@@ -71,6 +73,10 @@ impl AzureConfig {
7173
res.push(format!("Endpoint URL = {endpoint_url}"));
7274
}
7375
res.push(format!("Use SSL = {}", self.use_ssl));
76+
res.push(format!(
77+
"Max connections = {}",
78+
self.max_connections_per_io_thread
79+
));
7480
res
7581
}
7682
}
@@ -90,7 +96,8 @@ impl Display for AzureConfig {
9096
use_fabric_endpoint: {:?}
9197
anonymous: {:?}
9298
endpoint_url: {:?}
93-
use_ssl: {:?}",
99+
use_ssl: {:?}
100+
max_connections_per_io_thread: {:?}",
94101
self.storage_account,
95102
self.access_key,
96103
self.sas_token,
@@ -101,7 +108,8 @@ impl Display for AzureConfig {
101108
self.use_fabric_endpoint,
102109
self.anonymous,
103110
self.endpoint_url,
104-
self.use_ssl
111+
self.use_ssl,
112+
self.max_connections_per_io_thread
105113
)
106114
}
107115
}

src/common/io-config/src/python.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ pub struct S3Credentials {
9292
/// anonymous (bool, optional): Whether or not to use "anonymous mode", which will access Azure without any credentials
9393
/// endpoint_url (str, optional): Custom URL to the Azure endpoint, e.g. ``https://my-account-name.blob.core.windows.net``. Overrides `use_fabric_endpoint` if set
9494
/// use_ssl (bool, optional): Whether or not to use SSL, which require accessing Azure over HTTPS rather than HTTP, defaults to True
95+
/// max_connections (int, optional): Maximum number of connections to Azure at any time per io thread, defaults to 8
9596
///
9697
/// Examples:
9798
/// >>> io_config = IOConfig(azure=AzureConfig(storage_account="dafttestdata", access_key="xxx"))
@@ -924,7 +925,8 @@ impl AzureConfig {
924925
use_fabric_endpoint=None,
925926
anonymous=None,
926927
endpoint_url=None,
927-
use_ssl=None
928+
use_ssl=None,
929+
max_connections=None
928930
))]
929931
pub fn new(
930932
storage_account: Option<String>,
@@ -938,6 +940,7 @@ impl AzureConfig {
938940
anonymous: Option<bool>,
939941
endpoint_url: Option<String>,
940942
use_ssl: Option<bool>,
943+
max_connections: Option<u32>,
941944
) -> Self {
942945
let def = crate::AzureConfig::default();
943946
Self {
@@ -955,6 +958,8 @@ impl AzureConfig {
955958
anonymous: anonymous.unwrap_or(def.anonymous),
956959
endpoint_url: endpoint_url.or(def.endpoint_url),
957960
use_ssl: use_ssl.unwrap_or(def.use_ssl),
961+
max_connections_per_io_thread: max_connections
962+
.unwrap_or(def.max_connections_per_io_thread),
958963
},
959964
}
960965
}
@@ -972,7 +977,8 @@ impl AzureConfig {
972977
use_fabric_endpoint=None,
973978
anonymous=None,
974979
endpoint_url=None,
975-
use_ssl=None
980+
use_ssl=None,
981+
max_connections=None
976982
))]
977983
pub fn replace(
978984
&self,
@@ -987,6 +993,7 @@ impl AzureConfig {
987993
anonymous: Option<bool>,
988994
endpoint_url: Option<String>,
989995
use_ssl: Option<bool>,
996+
max_connections: Option<u32>,
990997
) -> Self {
991998
Self {
992999
config: crate::AzureConfig {
@@ -1005,6 +1012,8 @@ impl AzureConfig {
10051012
anonymous: anonymous.unwrap_or(self.config.anonymous),
10061013
endpoint_url: endpoint_url.or_else(|| self.config.endpoint_url.clone()),
10071014
use_ssl: use_ssl.unwrap_or(self.config.use_ssl),
1015+
max_connections_per_io_thread: max_connections
1016+
.unwrap_or(self.config.max_connections_per_io_thread),
10081017
},
10091018
}
10101019
}
@@ -1085,6 +1094,12 @@ impl AzureConfig {
10851094
pub fn use_ssl(&self) -> PyResult<bool> {
10861095
Ok(self.config.use_ssl)
10871096
}
1097+
1098+
/// Maximum number of connections per IO thread for Azure
1099+
#[getter]
1100+
pub fn max_connections(&self) -> PyResult<u32> {
1101+
Ok(self.config.max_connections_per_io_thread)
1102+
}
10881103
}
10891104

10901105
#[pymethods]

src/daft-io/src/azure_blob.rs

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use azure_storage_blobs::{
1212
prelude::*,
1313
};
1414
use common_io_config::AzureConfig;
15+
use common_runtime::get_io_pool_num_threads;
1516
use derive_builder::Builder;
1617
use futures::{StreamExt, TryStreamExt, stream::BoxStream};
1718
use snafu::{IntoError, ResultExt, Snafu};
@@ -70,6 +71,9 @@ enum Error {
7071

7172
#[snafu(display("Not a File: \"{}\"", path))]
7273
NotAFile { path: String },
74+
75+
#[snafu(display("Unable to grab semaphore. {}", source))]
76+
UnableToGrabSemaphore { source: tokio::sync::AcquireError },
7377
}
7478

7579
#[derive(Builder)]
@@ -164,6 +168,7 @@ impl From<Error> for super::Error {
164168

165169
pub struct AzureBlobSource {
166170
blob_client: Arc<BlobServiceClient>,
171+
connection_pool_sema: Arc<tokio::sync::Semaphore>,
167172
}
168173

169174
impl AzureBlobSource {
@@ -262,8 +267,14 @@ impl AzureBlobSource {
262267
BlobServiceClient::new(storage_account, storage_credentials)
263268
};
264269

270+
let connection_pool_sema = Arc::new(tokio::sync::Semaphore::new(
271+
(config.max_connections_per_io_thread.max(1) as usize)
272+
* get_io_pool_num_threads().expect("Should be running in tokio pool"),
273+
));
274+
265275
Ok(Self {
266276
blob_client: blob_client.into(),
277+
connection_pool_sema,
267278
}
268279
.into())
269280
}
@@ -519,6 +530,38 @@ impl AzureBlobSource {
519530
},
520531
}
521532
}
533+
534+
/// Internal get_size that does NOT acquire the semaphore.
535+
/// Used by `get()` which already holds a permit (avoids deadlock for GetRange::Suffix).
536+
async fn get_size_internal(
537+
&self,
538+
uri: &str,
539+
io_stats: Option<IOStatsRef>,
540+
) -> super::Result<usize> {
541+
let parsed_uri = parse_azure_uri(uri)?;
542+
let (container, key) = parsed_uri
543+
.container_and_key
544+
.ok_or_else(|| Error::InvalidUrl {
545+
path: uri.into(),
546+
source: url::ParseError::EmptyHost,
547+
})?;
548+
549+
if key.is_empty() {
550+
return Err(Error::NotAFile { path: uri.into() }.into());
551+
}
552+
553+
let container_client = self.blob_client.container_client(container);
554+
let blob_client = container_client.blob_client(key);
555+
let metadata = blob_client
556+
.get_properties()
557+
.await
558+
.context(UnableToOpenFileSnafu::<String> { path: uri.into() })?;
559+
if let Some(is) = io_stats.as_ref() {
560+
is.mark_head_requests(1);
561+
}
562+
563+
Ok(metadata.blob.properties.content_length as usize)
564+
}
522565
}
523566

524567
#[async_trait]
@@ -533,6 +576,13 @@ impl ObjectSource for AzureBlobSource {
533576
range: Option<GetRange>,
534577
io_stats: Option<IOStatsRef>,
535578
) -> super::Result<GetResult> {
579+
let permit = self
580+
.connection_pool_sema
581+
.clone()
582+
.acquire_owned()
583+
.await
584+
.context(UnableToGrabSemaphoreSnafu)?;
585+
536586
let parsed_uri = parse_azure_uri(uri)?;
537587
let (container, key) = parsed_uri
538588
.container_and_key
@@ -552,10 +602,11 @@ impl ObjectSource for AzureBlobSource {
552602
range.validate().context(InvalidRangeRequestSnafu)?;
553603
match range {
554604
GetRange::Bounded(u) => request_builder.range(u),
555-
// Note: if n is greater than file size, Azure will whole content.
605+
// Note: if n is greater than file size, Azure will return the whole content.
556606
GetRange::Offset(n) => request_builder.range(n..),
557607
GetRange::Suffix(n) => {
558-
let size = self.get_size(uri, io_stats.clone()).await?;
608+
// Use get_size_internal to avoid deadlock (we already hold a permit)
609+
let size = self.get_size_internal(uri, io_stats.clone()).await?;
559610
request_builder.range(size.saturating_sub(n)..)
560611
}
561612
}
@@ -580,7 +631,7 @@ impl ObjectSource for AzureBlobSource {
580631
Ok(GetResult::Stream(
581632
io_stats_on_bytestream(Box::pin(stream), io_stats),
582633
None,
583-
None,
634+
Some(permit),
584635
None,
585636
))
586637
}
@@ -595,29 +646,12 @@ impl ObjectSource for AzureBlobSource {
595646
}
596647

597648
async fn get_size(&self, uri: &str, io_stats: Option<IOStatsRef>) -> super::Result<usize> {
598-
let parsed_uri = parse_azure_uri(uri)?;
599-
let (container, key) = parsed_uri
600-
.container_and_key
601-
.ok_or_else(|| Error::InvalidUrl {
602-
path: uri.into(),
603-
source: url::ParseError::EmptyHost,
604-
})?;
605-
606-
if key.is_empty() {
607-
return Err(Error::NotAFile { path: uri.into() }.into());
608-
}
609-
610-
let container_client = self.blob_client.container_client(container);
611-
let blob_client = container_client.blob_client(key);
612-
let metadata = blob_client
613-
.get_properties()
649+
let _permit = self
650+
.connection_pool_sema
651+
.acquire()
614652
.await
615-
.context(UnableToOpenFileSnafu::<String> { path: uri.into() })?;
616-
if let Some(is) = io_stats.as_ref() {
617-
is.mark_head_requests(1);
618-
}
619-
620-
Ok(metadata.blob.properties.content_length as usize)
653+
.context(UnableToGrabSemaphoreSnafu)?;
654+
self.get_size_internal(uri, io_stats).await
621655
}
622656

623657
async fn glob(

src/daft-sql/src/modules/config.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ impl SQLFunction for AzureConfigFunction {
246246
"anonymous",
247247
"endpoint_url",
248248
"use_ssl",
249+
"max_connections_per_io_thread",
249250
],
250251
0,
251252
)?;
@@ -261,6 +262,9 @@ impl SQLFunction for AzureConfigFunction {
261262
let anonymous = args.try_get_named::<bool>("anonymous")?;
262263
let endpoint_url = args.try_get_named::<String>("endpoint_url")?;
263264
let use_ssl = args.try_get_named::<bool>("use_ssl")?;
265+
let max_connections_per_io_thread = args
266+
.try_get_named::<i64>("max_connections_per_io_thread")?
267+
.map(|t| t as u32);
264268

265269
let entries = vec![
266270
("variant".to_string(), "azure".into()),
@@ -275,6 +279,7 @@ impl SQLFunction for AzureConfigFunction {
275279
item!(anonymous),
276280
item!(endpoint_url),
277281
item!(use_ssl),
282+
item!(max_connections_per_io_thread),
278283
]
279284
.into_iter()
280285
.collect::<_>();
@@ -299,6 +304,7 @@ impl SQLFunction for AzureConfigFunction {
299304
"anonymous",
300305
"endpoint_url",
301306
"use_ssl",
307+
"max_connections_per_io_thread",
302308
]
303309
}
304310
}
@@ -560,6 +566,8 @@ pub(crate) fn expr_to_iocfg(expr: &ExprRef) -> SQLPlannerResult<IOConfig> {
560566
let anonymous = get_value!("anonymous", Boolean)?;
561567
let endpoint_url = get_value!("endpoint_url", Utf8)?;
562568
let use_ssl = get_value!("use_ssl", Boolean)?;
569+
let max_connections_per_io_thread =
570+
get_value!("max_connections_per_io_thread", UInt32)?;
563571

564572
let default = AzureConfig::default();
565573

@@ -576,6 +584,8 @@ pub(crate) fn expr_to_iocfg(expr: &ExprRef) -> SQLPlannerResult<IOConfig> {
576584
anonymous: anonymous.unwrap_or(default.anonymous),
577585
endpoint_url,
578586
use_ssl: use_ssl.unwrap_or(default.use_ssl),
587+
max_connections_per_io_thread: max_connections_per_io_thread
588+
.unwrap_or(default.max_connections_per_io_thread),
579589
},
580590
..Default::default()
581591
})

tests/dataframe/test_morsels.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def test_batch_size_from_udf_propagated_through_ops_to_scan():
123123
| anonymous: false
124124
| endpoint_url: None
125125
| use_ssl: true
126+
| max_connections_per_io_thread: 8
126127
| GCSConfig
127128
| project_id: None
128129
| anonymous: false

0 commit comments

Comments
 (0)