Skip to content

Commit eccd84b

Browse files
authored
Merge pull request #62 from stackhpc/missing
Add support for missing data
2 parents 691476e + 465920b commit eccd84b

File tree

11 files changed

+985
-30
lines changed

11 files changed

+985
-30
lines changed

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,18 @@ with a JSON payload of the form:
8282
// List of algorithms used to filter the data
8383
// - optional, defaults to no filters
8484
"filters": [{"id": "shuffle", "element_size": 4}],
85+
86+
// Missing data description
87+
// - optional, defaults to no missing data
88+
// - exactly one of the keys below should be specified (all five are listed here only to show the alternatives)
89+
// - the values should match the data type (dtype)
90+
"missing": {
91+
"missing_value": 42,
92+
"missing_values": [42, -42],
93+
"valid_min": 42,
94+
"valid_max": 42,
95+
"valid_range": [-42, 42],
96+
}
8597
}
8698
```
8799

scripts/client.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,23 @@ def get_args() -> argparse.Namespace:
3838
parser.add_argument("--selection", type=str)
3939
parser.add_argument("--compression", type=str)
4040
parser.add_argument("--shuffle", action=argparse.BooleanOptionalAction)
41+
missing = parser.add_mutually_exclusive_group()
42+
missing.add_argument("--missing-value", type=str)
43+
missing.add_argument("--missing-values", type=str)
44+
missing.add_argument("--valid-min", type=str)
45+
missing.add_argument("--valid-max", type=str)
46+
missing.add_argument("--valid-range", type=str)
4147
parser.add_argument("--verbose", action=argparse.BooleanOptionalAction)
4248
return parser.parse_args()
4349

4450

51+
def parse_number(s: str):
    """Convert a command-line string into a number.

    An integer parse is attempted first, so input such as "42" is returned
    as an int; anything int() rejects is handed to float().

    Raises:
        ValueError: if s is not accepted by either int() or float().
    """
    try:
        return int(s)
    except ValueError:
        # Not an integer literal; fall back to float, whose own ValueError
        # propagates to the caller when s is not numeric at all.
        return float(s)
56+
57+
4558
def build_request_data(args: argparse.Namespace) -> dict:
4659
request_data = {
4760
'source': args.source,
@@ -65,6 +78,20 @@ def build_request_data(args: argparse.Namespace) -> dict:
6578
filters.append({"id": "shuffle", "element_size": element_size})
6679
if filters:
6780
request_data["filters"] = filters
81+
missing = None
82+
if args.missing_value:
83+
missing = {"missing_value": parse_number(args.missing_value)}
84+
if args.missing_values:
85+
missing = {"missing_values": [parse_number(n) for n in args.missing_values.split(",")]}
86+
if args.valid_min:
87+
missing = {"valid_min": parse_number(args.valid_min)}
88+
if args.valid_max:
89+
missing = {"valid_max": parse_number(args.valid_max)}
90+
if args.valid_range:
91+
min, max = args.valid_range.split(",")
92+
missing = {"valid_range": [parse_number(min), parse_number(max)]}
93+
if missing:
94+
request_data["missing"] = missing
6895
return {k: v for k, v in request_data.items() if v is not None}
6996

7097

src/error.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ use std::error::Error;
1616
use thiserror::Error;
1717
use tracing::{event, Level};
1818

19+
use crate::types::DValue;
20+
1921
/// Active Storage server error type
2022
///
2123
/// This type encapsulates the various errors that may occur.
@@ -34,6 +36,9 @@ pub enum ActiveStorageError {
3436
#[error("failed to convert from bytes to {type_name}")]
3537
FromBytes { type_name: &'static str },
3638

39+
#[error("Incompatible value {0} for missing")]
40+
IncompatibleMissing(DValue),
41+
3742
/// Error deserialising request data into RequestData
3843
#[error("request data is not valid")]
3944
RequestDataJsonRejection(#[from] JsonRejection),
@@ -184,6 +189,7 @@ impl From<ActiveStorageError> for ErrorResponse {
184189
// Bad request
185190
ActiveStorageError::Decompression(_)
186191
| ActiveStorageError::EmptyArray { operation: _ }
192+
| ActiveStorageError::IncompatibleMissing(_)
187193
| ActiveStorageError::RequestDataJsonRejection(_)
188194
| ActiveStorageError::RequestDataValidationSingle(_)
189195
| ActiveStorageError::RequestDataValidation(_)
@@ -345,6 +351,15 @@ mod tests {
345351
.await;
346352
}
347353

354+
#[tokio::test]
355+
async fn incompatible_missing() {
356+
let value = 32.into();
357+
let error = ActiveStorageError::IncompatibleMissing(value);
358+
let message = "Incompatible value 32 for missing";
359+
let caused_by = None;
360+
test_active_storage_error(error, StatusCode::BAD_REQUEST, message, caused_by).await;
361+
}
362+
348363
#[tokio::test]
349364
async fn request_data_validation_single() {
350365
let validation_error = validator::ValidationError::new("foo");

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,5 @@ pub mod server;
3838
#[cfg(test)]
3939
pub mod test_utils;
4040
pub mod tracing;
41+
pub mod types;
4142
pub mod validated_json;

src/models.rs

Lines changed: 116 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ use strum_macros::Display;
66
use url::Url;
77
use validator::{Validate, ValidationError};
88

9+
use crate::types::{DValue, Missing};
10+
911
/// Supported numerical data types
1012
#[derive(Clone, Copy, Debug, Deserialize, Display, PartialEq)]
1113
#[serde(rename_all = "lowercase")]
@@ -14,11 +16,11 @@ pub enum DType {
1416
Int32,
1517
/// [i64]
1618
Int64,
17-
/// [u64]
19+
/// [u32]
1820
Uint32,
1921
/// [u64]
2022
Uint64,
21-
/// [f64]
23+
/// [f32]
2224
Float32,
2325
/// [f64]
2426
Float64,
@@ -142,6 +144,8 @@ pub struct RequestData {
142144
pub compression: Option<Compression>,
143145
/// List of filter algorithms
144146
pub filters: Option<Vec<Filter>>,
147+
/// Missing data
148+
pub missing: Option<Missing<DValue>>,
145149
}
146150

147151
/// Validate an array shape
@@ -230,6 +234,9 @@ fn validate_request_data(request_data: &RequestData) -> Result<(), ValidationErr
230234
}
231235
_ => (),
232236
};
237+
if let Some(missing) = &request_data.missing {
238+
missing.validate(request_data.dtype)?;
239+
};
233240
Ok(())
234241
}
235242

@@ -359,6 +366,11 @@ mod tests {
359366
Token::U32(4),
360367
Token::MapEnd,
361368
Token::SeqEnd,
369+
Token::Str("missing"),
370+
Token::Some,
371+
Token::Enum { name: "Missing" },
372+
Token::Str("missing_value"),
373+
Token::I32(42),
362374
Token::StructEnd,
363375
],
364376
);
@@ -664,14 +676,40 @@ mod tests {
664676
)
665677
}
666678

679+
// A `missing` object keyed by anything other than the five supported variants must fail
// deserialisation, and the error must enumerate the accepted variant names.
#[test]
fn test_invalid_missing() {
    assert_de_tokens_error::<RequestData>(
        &[
            Token::Struct {
                name: "RequestData",
                len: 2,
            },
            Token::Str("missing"),
            Token::Some,
            Token::Enum { name: "Missing" },
            Token::Str("foo"),
            Token::StructEnd,
        ],
        "unknown variant `foo`, expected one of `missing_value`, `missing_values`, `valid_min`, `valid_max`, `valid_range`",
    )
}
696+
697+
// Validation must reject a missing value that does not fit the request's dtype.
// NOTE(review): relies on the fixture's dtype being narrower than i64 - confirm in test_utils.
#[test]
#[should_panic(expected = "Incompatible value 9223372036854775807 for missing")]
fn test_missing_invalid_value_for_dtype() {
    let mut request_data = test_utils::get_test_request_data();
    // i64::MAX is the associated constant equivalent of the older i64::max_value().
    request_data.missing = Some(Missing::MissingValue(i64::MAX.into()));
    request_data.validate().unwrap()
}
704+
667705
// An unrecognised top-level key must be rejected during deserialisation; the expected
// message lists every field name RequestData accepts.
#[test]
668706
fn test_unknown_field() {
669707
assert_de_tokens_error::<RequestData>(&[
670708
Token::Struct { name: "RequestData", len: 2 },
671709
Token::Str("foo"),
672710
Token::StructEnd
673711
],
674-
"unknown field `foo`, expected one of `source`, `bucket`, `object`, `dtype`, `offset`, `size`, `shape`, `order`, `selection`, `compression`, `filters`"
712+
"unknown field `foo`, expected one of `source`, `bucket`, `object`, `dtype`, `offset`, `size`, `shape`, `order`, `selection`, `compression`, `filters`, `missing`"
675713
)
676714
}
677715

@@ -686,8 +724,82 @@ mod tests {
686724

687725
// Parses a JSON request with every optional field populated and compares the result
// against the shared "optional" fixture from test_utils.
#[test]
688726
fn test_json_optional_fields() {
689-
let json = r#"{"source": "http://example.com", "bucket": "bar", "object": "baz", "dtype": "int32", "offset": 4, "size": 8, "shape": [2, 5], "order": "C", "selection": [[1, 2, 3], [4, 5, 6]], "compression": {"id": "gzip"}, "filters": [{"id": "shuffle", "element_size": 4}]}"#;
727+
let json = r#"{
728+
"source": "http://example.com",
729+
"bucket": "bar",
730+
"object": "baz",
731+
"dtype": "int32",
732+
"offset": 4,
733+
"size": 8,
734+
"shape": [2, 5],
735+
"order": "C",
736+
"selection": [[1, 2, 3], [4, 5, 6]],
737+
"compression": {"id": "gzip"},
738+
"filters": [{"id": "shuffle", "element_size": 4}],
739+
"missing": {"missing_value": 42}
740+
}"#;
690741
let request_data = serde_json::from_str::<RequestData>(json).unwrap();
691742
assert_eq!(request_data, test_utils::get_test_request_data_optional());
692743
}
744+
745+
// Parses a float64 request with every optional field populated, using a `valid_range`
// pair for `missing`, and checks each field against an adjusted copy of the shared fixture.
#[test]
fn test_json_optional_fields2() {
    let json = r#"{
        "source": "http://example.com",
        "bucket": "bar",
        "object": "baz",
        "dtype": "float64",
        "offset": 4,
        "size": 8,
        "shape": [2, 5, 10],
        "order": "F",
        "selection": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        "compression": {"id": "zlib"},
        "filters": [{"id": "shuffle", "element_size": 8}],
        "missing": {"valid_range": [-1.0, 999.0]}
    }"#;
    let request_data = serde_json::from_str::<RequestData>(json).unwrap();

    // Start from the optional fixture and override every field the JSON above changes.
    let mut expected = test_utils::get_test_request_data_optional();
    expected.dtype = DType::Float64;
    expected.shape = Some(vec![2, 5, 10]);
    expected.order = Some(Order::F);
    expected.selection = Some(vec![
        Slice::new(1, 2, 3),
        Slice::new(4, 5, 6),
        Slice::new(7, 8, 9),
    ]);
    expected.compression = Some(Compression::Zlib);
    expected.filters = Some(vec![Filter::Shuffle { element_size: 8 }]);
    expected.missing = Some(Missing::ValidRange(
        DValue::from_f64(-1.0).unwrap(),
        DValue::from_f64(999.0).unwrap(),
    ));
    assert_eq!(request_data, expected);
}
779+
780+
// Parses a minimal int32 request whose `missing_values` list spans i64::MIN to u64::MAX.
// Only deserialisation is exercised: validate() is not called here, so the out-of-range
// entries are expected to parse successfully.
#[test]
fn test_json_optional_fields3() {
    // Doubled braces are literal braces inside format!; the two `{}` take the extreme values.
    let json = format!(
        r#"{{
            "source": "http://example.com",
            "bucket": "bar",
            "object": "baz",
            "dtype": "int32",
            "missing": {{"missing_values": [{}, -1, 0, 1, {}]}}
        }}"#,
        i64::MIN,
        u64::MAX
    );
    let request_data = serde_json::from_str::<RequestData>(&json).unwrap();

    let mut expected = test_utils::get_test_request_data();
    expected.dtype = DType::Int32;
    expected.missing = Some(Missing::MissingValues(vec![
        i64::MIN.into(),
        (-1).into(),
        0.into(),
        1.into(),
        u64::MAX.into(),
    ]));
    assert_eq!(request_data, expected);
}
693805
}

src/operation.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
use crate::error::ActiveStorageError;
44
use crate::models;
5+
use crate::types::dvalue::TryFromDValue;
56

67
use axum::body::Bytes;
78

@@ -12,9 +13,12 @@ pub trait Element:
1213
+ PartialOrd
1314
+ num_traits::FromPrimitive
1415
+ num_traits::Zero
16+
+ std::convert::From<u16>
1517
+ std::fmt::Debug
18+
+ std::iter::Sum
1619
+ std::ops::Add<Output = Self>
1720
+ std::ops::Div<Output = Self>
21+
+ TryFromDValue
1822
+ zerocopy::AsBytes
1923
+ zerocopy::FromBytes
2024
{
@@ -28,9 +32,12 @@ impl<T> Element for T where
2832
+ num_traits::FromPrimitive
2933
+ num_traits::One
3034
+ num_traits::Zero
35+
+ std::convert::From<u16>
3136
+ std::fmt::Debug
37+
+ std::iter::Sum
3238
+ std::ops::Add<Output = Self>
3339
+ std::ops::Div<Output = Self>
40+
+ TryFromDValue
3441
+ zerocopy::AsBytes
3542
+ zerocopy::FromBytes
3643
{

0 commit comments

Comments
 (0)