Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ class AzureConfig:
anonymous: bool | None
endpoint_url: str | None = None
use_ssl: bool | None = None
max_connections: int

def __init__(
self,
Expand All @@ -654,6 +655,7 @@ class AzureConfig:
anonymous: bool | None = None,
endpoint_url: str | None = None,
use_ssl: bool | None = None,
max_connections: int | None = None,
): ...
def replace(
self,
Expand All @@ -668,6 +670,7 @@ class AzureConfig:
anonymous: bool | None = None,
endpoint_url: str | None = None,
use_ssl: bool | None = None,
max_connections: int | None = None,
) -> AzureConfig:
"""Replaces values if provided, returning a new AzureConfig."""
...
Expand Down
12 changes: 10 additions & 2 deletions src/common/io-config/src/azure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub struct AzureConfig {
pub anonymous: bool,
pub endpoint_url: Option<String>,
pub use_ssl: bool,
pub max_connections_per_io_thread: u32,
}

impl Default for AzureConfig {
Expand All @@ -33,6 +34,7 @@ impl Default for AzureConfig {
anonymous: false,
endpoint_url: None,
use_ssl: true,
max_connections_per_io_thread: 8,
}
}
}
Expand Down Expand Up @@ -71,6 +73,10 @@ impl AzureConfig {
res.push(format!("Endpoint URL = {endpoint_url}"));
}
res.push(format!("Use SSL = {}", self.use_ssl));
res.push(format!(
"Max connections = {}",
self.max_connections_per_io_thread
));
res
}
}
Expand All @@ -90,7 +96,8 @@ impl Display for AzureConfig {
use_fabric_endpoint: {:?}
anonymous: {:?}
endpoint_url: {:?}
use_ssl: {:?}",
use_ssl: {:?}
max_connections_per_io_thread: {:?}",
self.storage_account,
self.access_key,
self.sas_token,
Expand All @@ -101,7 +108,8 @@ impl Display for AzureConfig {
self.use_fabric_endpoint,
self.anonymous,
self.endpoint_url,
self.use_ssl
self.use_ssl,
self.max_connections_per_io_thread
)
}
}
19 changes: 17 additions & 2 deletions src/common/io-config/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ pub struct S3Credentials {
/// anonymous (bool, optional): Whether or not to use "anonymous mode", which will access Azure without any credentials
/// endpoint_url (str, optional): Custom URL to the Azure endpoint, e.g. ``https://my-account-name.blob.core.windows.net``. Overrides `use_fabric_endpoint` if set
/// use_ssl (bool, optional): Whether or not to use SSL, which requires accessing Azure over HTTPS rather than HTTP, defaults to True
/// max_connections (int, optional): Maximum number of connections to Azure at any time per io thread, defaults to 8
///
/// Examples:
/// >>> io_config = IOConfig(azure=AzureConfig(storage_account="dafttestdata", access_key="xxx"))
Expand Down Expand Up @@ -924,7 +925,8 @@ impl AzureConfig {
use_fabric_endpoint=None,
anonymous=None,
endpoint_url=None,
use_ssl=None
use_ssl=None,
max_connections=None
))]
pub fn new(
storage_account: Option<String>,
Expand All @@ -938,6 +940,7 @@ impl AzureConfig {
anonymous: Option<bool>,
endpoint_url: Option<String>,
use_ssl: Option<bool>,
max_connections: Option<u32>,
) -> Self {
let def = crate::AzureConfig::default();
Self {
Expand All @@ -955,6 +958,8 @@ impl AzureConfig {
anonymous: anonymous.unwrap_or(def.anonymous),
endpoint_url: endpoint_url.or(def.endpoint_url),
use_ssl: use_ssl.unwrap_or(def.use_ssl),
max_connections_per_io_thread: max_connections
.unwrap_or(def.max_connections_per_io_thread),
},
}
}
Expand All @@ -972,7 +977,8 @@ impl AzureConfig {
use_fabric_endpoint=None,
anonymous=None,
endpoint_url=None,
use_ssl=None
use_ssl=None,
max_connections=None
))]
pub fn replace(
&self,
Expand All @@ -987,6 +993,7 @@ impl AzureConfig {
anonymous: Option<bool>,
endpoint_url: Option<String>,
use_ssl: Option<bool>,
max_connections: Option<u32>,
) -> Self {
Self {
config: crate::AzureConfig {
Expand All @@ -1005,6 +1012,8 @@ impl AzureConfig {
anonymous: anonymous.unwrap_or(self.config.anonymous),
endpoint_url: endpoint_url.or_else(|| self.config.endpoint_url.clone()),
use_ssl: use_ssl.unwrap_or(self.config.use_ssl),
max_connections_per_io_thread: max_connections
.unwrap_or(self.config.max_connections_per_io_thread),
},
}
}
Expand Down Expand Up @@ -1085,6 +1094,12 @@ impl AzureConfig {
pub fn use_ssl(&self) -> PyResult<bool> {
Ok(self.config.use_ssl)
}

/// Connection limit applied per IO thread for Azure
/// (the `max_connections_per_io_thread` value stored in the wrapped config).
#[getter]
pub fn max_connections(&self) -> PyResult<u32> {
    let per_thread_limit = self.config.max_connections_per_io_thread;
    Ok(per_thread_limit)
}
}

#[pymethods]
Expand Down
84 changes: 59 additions & 25 deletions src/daft-io/src/azure_blob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use azure_storage_blobs::{
prelude::*,
};
use common_io_config::AzureConfig;
use common_runtime::get_io_pool_num_threads;
use derive_builder::Builder;
use futures::{StreamExt, TryStreamExt, stream::BoxStream};
use snafu::{IntoError, ResultExt, Snafu};
Expand Down Expand Up @@ -70,6 +71,9 @@ enum Error {

#[snafu(display("Not a File: \"{}\"", path))]
NotAFile { path: String },

#[snafu(display("Unable to grab semaphore. {}", source))]
UnableToGrabSemaphore { source: tokio::sync::AcquireError },
}

#[derive(Builder)]
Expand Down Expand Up @@ -164,6 +168,7 @@ impl From<Error> for super::Error {

pub struct AzureBlobSource {
blob_client: Arc<BlobServiceClient>,
connection_pool_sema: Arc<tokio::sync::Semaphore>,
}

impl AzureBlobSource {
Expand Down Expand Up @@ -262,8 +267,14 @@ impl AzureBlobSource {
BlobServiceClient::new(storage_account, storage_credentials)
};

let connection_pool_sema = Arc::new(tokio::sync::Semaphore::new(
(config.max_connections_per_io_thread.max(1) as usize)
* get_io_pool_num_threads().expect("Should be running in tokio pool"),
));

Ok(Self {
blob_client: blob_client.into(),
connection_pool_sema,
}
.into())
}
Expand Down Expand Up @@ -519,6 +530,38 @@ impl AzureBlobSource {
},
}
}

/// Internal get_size that does NOT acquire the semaphore.
/// Used by `get()` which already holds a permit (avoids deadlock for GetRange::Suffix).
async fn get_size_internal(
&self,
uri: &str,
io_stats: Option<IOStatsRef>,
) -> super::Result<usize> {
let parsed_uri = parse_azure_uri(uri)?;
let (container, key) = parsed_uri
.container_and_key
.ok_or_else(|| Error::InvalidUrl {
path: uri.into(),
source: url::ParseError::EmptyHost,
})?;

if key.is_empty() {
return Err(Error::NotAFile { path: uri.into() }.into());
}

let container_client = self.blob_client.container_client(container);
let blob_client = container_client.blob_client(key);
let metadata = blob_client
.get_properties()
.await
.context(UnableToOpenFileSnafu::<String> { path: uri.into() })?;
if let Some(is) = io_stats.as_ref() {
is.mark_head_requests(1);
}

Ok(metadata.blob.properties.content_length as usize)
}
}

#[async_trait]
Expand All @@ -533,6 +576,13 @@ impl ObjectSource for AzureBlobSource {
range: Option<GetRange>,
io_stats: Option<IOStatsRef>,
) -> super::Result<GetResult> {
let permit = self
.connection_pool_sema
.clone()
.acquire_owned()
.await
.context(UnableToGrabSemaphoreSnafu)?;

let parsed_uri = parse_azure_uri(uri)?;
let (container, key) = parsed_uri
.container_and_key
Expand All @@ -552,10 +602,11 @@ impl ObjectSource for AzureBlobSource {
range.validate().context(InvalidRangeRequestSnafu)?;
match range {
GetRange::Bounded(u) => request_builder.range(u),
// Note: if n is greater than file size, Azure will whole content.
// Note: if n is greater than file size, Azure will return the whole content.
GetRange::Offset(n) => request_builder.range(n..),
GetRange::Suffix(n) => {
let size = self.get_size(uri, io_stats.clone()).await?;
// Use get_size_internal to avoid deadlock (we already hold a permit)
let size = self.get_size_internal(uri, io_stats.clone()).await?;
request_builder.range(size.saturating_sub(n)..)
}
}
Expand All @@ -580,7 +631,7 @@ impl ObjectSource for AzureBlobSource {
Ok(GetResult::Stream(
io_stats_on_bytestream(Box::pin(stream), io_stats),
None,
None,
Some(permit),
None,
))
}
Expand All @@ -595,29 +646,12 @@ impl ObjectSource for AzureBlobSource {
}

async fn get_size(&self, uri: &str, io_stats: Option<IOStatsRef>) -> super::Result<usize> {
let parsed_uri = parse_azure_uri(uri)?;
let (container, key) = parsed_uri
.container_and_key
.ok_or_else(|| Error::InvalidUrl {
path: uri.into(),
source: url::ParseError::EmptyHost,
})?;

if key.is_empty() {
return Err(Error::NotAFile { path: uri.into() }.into());
}

let container_client = self.blob_client.container_client(container);
let blob_client = container_client.blob_client(key);
let metadata = blob_client
.get_properties()
let _permit = self
.connection_pool_sema
.acquire()
.await
.context(UnableToOpenFileSnafu::<String> { path: uri.into() })?;
if let Some(is) = io_stats.as_ref() {
is.mark_head_requests(1);
}

Ok(metadata.blob.properties.content_length as usize)
.context(UnableToGrabSemaphoreSnafu)?;
self.get_size_internal(uri, io_stats).await
}

async fn glob(
Expand Down
10 changes: 10 additions & 0 deletions src/daft-sql/src/modules/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ impl SQLFunction for AzureConfigFunction {
"anonymous",
"endpoint_url",
"use_ssl",
"max_connections_per_io_thread",
],
0,
)?;
Expand All @@ -261,6 +262,9 @@ impl SQLFunction for AzureConfigFunction {
let anonymous = args.try_get_named::<bool>("anonymous")?;
let endpoint_url = args.try_get_named::<String>("endpoint_url")?;
let use_ssl = args.try_get_named::<bool>("use_ssl")?;
// SQL integer literals arrive as i64. Clamp into the u32 range before casting so that
// out-of-range input saturates instead of silently wrapping (a bare `as u32` would turn
// -1 into 4294967295). After the clamp the cast is lossless.
let max_connections_per_io_thread = args
    .try_get_named::<i64>("max_connections_per_io_thread")?
    .map(|t| t.clamp(0, i64::from(u32::MAX)) as u32);

let entries = vec![
("variant".to_string(), "azure".into()),
Expand All @@ -275,6 +279,7 @@ impl SQLFunction for AzureConfigFunction {
item!(anonymous),
item!(endpoint_url),
item!(use_ssl),
item!(max_connections_per_io_thread),
]
.into_iter()
.collect::<_>();
Expand All @@ -299,6 +304,7 @@ impl SQLFunction for AzureConfigFunction {
"anonymous",
"endpoint_url",
"use_ssl",
"max_connections_per_io_thread",
]
}
}
Expand Down Expand Up @@ -560,6 +566,8 @@ pub(crate) fn expr_to_iocfg(expr: &ExprRef) -> SQLPlannerResult<IOConfig> {
let anonymous = get_value!("anonymous", Boolean)?;
let endpoint_url = get_value!("endpoint_url", Utf8)?;
let use_ssl = get_value!("use_ssl", Boolean)?;
let max_connections_per_io_thread =
get_value!("max_connections_per_io_thread", UInt32)?;

let default = AzureConfig::default();

Expand All @@ -576,6 +584,8 @@ pub(crate) fn expr_to_iocfg(expr: &ExprRef) -> SQLPlannerResult<IOConfig> {
anonymous: anonymous.unwrap_or(default.anonymous),
endpoint_url,
use_ssl: use_ssl.unwrap_or(default.use_ssl),
max_connections_per_io_thread: max_connections_per_io_thread
.unwrap_or(default.max_connections_per_io_thread),
},
..Default::default()
})
Expand Down
1 change: 1 addition & 0 deletions tests/dataframe/test_morsels.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def test_batch_size_from_udf_propagated_through_ops_to_scan():
| anonymous: false
| endpoint_url: None
| use_ssl: true
| max_connections_per_io_thread: 8
| GCSConfig
| project_id: None
| anonymous: false
Expand Down
Loading
Loading