diff --git a/Cargo.toml b/Cargo.toml
index 866c31318..a7c26d9d2 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -17,6 +17,7 @@ itoa = "1.0"
 memchr = { version = "2", default-features = false }
 ryu = "1.0"
 serde = { version = "1.0.194", default-features = false }
+serde_spanned = { version = "0.6.8", features = ["serde"], optional = true }
 
 [dev-dependencies]
 automod = "1.0.11"
@@ -89,3 +90,7 @@ raw_value = []
 # overflow the stack after deserialization has completed, including, but not
 # limited to, Display and Debug and Drop impls.
 unbounded_depth = []
+
+# Provide deserialization of serde_spanned::Spanned, which tracks the position
+# of the value in the original JSON document
+spanned = ["serde_spanned"]
diff --git a/src/de.rs b/src/de.rs
index 4080c54ac..b8b89c2b4 100644
--- a/src/de.rs
+++ b/src/de.rs
@@ -67,6 +67,13 @@ where
             disable_recursion_limit: false,
         }
     }
+
+    /// Returns the byte offset reported by the underlying reader; the
+    /// `spanned` module uses it to record where a value starts and ends.
+    #[cfg(feature = "spanned")]
+    pub(crate) fn byte_offset(&self) -> usize {
+        self.read.byte_offset()
+    }
 }
 
 #[cfg(feature = "std")]
@@ -1817,13 +1824,17 @@ impl<'de, R: Read<'de>> de::Deserializer<'de> for &mut Deserializer<R> {
 
     fn deserialize_struct<V>(
         self,
-        _name: &'static str,
-        _fields: &'static [&'static str],
+        name: &'static str,
+        fields: &'static [&'static str],
         visitor: V,
     ) -> Result<V::Value>
     where
         V: de::Visitor<'de>,
     {
+        // `name` and `fields` are only read when the "spanned" feature is on.
+        let _ = name;
+        let _ = fields;
+
         let peek = match tri!(self.parse_whitespace()) {
             Some(b) => b,
             None => {
@@ -1831,6 +1842,13 @@ impl<'de, R: Read<'de>> de::Deserializer<'de> for &mut Deserializer<R> {
             }
         };
 
+        // A `Spanned` wrapper is not parsed from a JSON object; its visitor is
+        // given a synthetic map holding the value and its byte offsets.
+        #[cfg(feature = "spanned")]
+        if serde_spanned::__unstable::is_spanned(name, fields) {
+            return visitor.visit_map(crate::spanned::SpannedDeserializer::new(self));
+        }
+
         let value = match peek {
             b'[' => {
                 check_recursion! {
diff --git a/src/lib.rs b/src/lib.rs
index 5f066de26..400763ac1 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -436,3 +436,6 @@ mod read;
 
 #[cfg(feature = "raw_value")]
 mod raw;
+
+#[cfg(feature = "spanned")]
+pub(crate) mod spanned;
diff --git a/src/spanned.rs b/src/spanned.rs
new file mode 100644
index 000000000..19373292c
--- /dev/null
+++ b/src/spanned.rs
@@ -0,0 +1,100 @@
+//! Deserialization support for `serde_spanned::Spanned`.
+//!
+//! `Deserializer::deserialize_struct` hands a `Spanned` visitor a
+//! [`SpannedDeserializer`], which presents the wrapped value as a three-entry
+//! map: the byte offset before the value, the value itself, and the byte
+//! offset after it.
+
+use crate::de::{Deserializer, Read};
+use core::mem;
+use serde::de::value::BorrowedStrDeserializer;
+use serde::de::IntoDeserializer as _;
+
+/// State machine backing the map view of a `Spanned` value.
+///
+/// Entries are produced in the order start, value, end; each call to
+/// `next_value_seed` advances to the next state.
+pub(crate) enum SpannedDeserializer<'d, R> {
+    /// The start offset has not been emitted yet.
+    Start {
+        value_deserializer: &'d mut Deserializer<R>,
+    },
+    /// The start offset was emitted; the wrapped value is next.
+    Value {
+        value_deserializer: &'d mut Deserializer<R>,
+    },
+    /// The value was deserialized; `end_pos` is the offset recorded after it.
+    End {
+        end_pos: usize,
+    },
+    /// Every entry has been emitted.
+    Done,
+}
+
+impl<'d, R> SpannedDeserializer<'d, R> {
+    /// Creates a map view that will read the wrapped value from
+    /// `value_deserializer`.
+    pub fn new(value_deserializer: &'d mut Deserializer<R>) -> Self {
+        Self::Start { value_deserializer }
+    }
+}
+
+impl<'d, 'de, R> serde::de::MapAccess<'de> for SpannedDeserializer<'d, R>
+where
+    R: Read<'de>,
+{
+    type Error = <&'d mut Deserializer<R> as serde::de::Deserializer<'de>>::Error;
+
+    /// Yields the field name for the current state, or `None` once all three
+    /// entries have been produced.
+    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
+    where
+        K: serde::de::DeserializeSeed<'de>,
+    {
+        let key = match self {
+            Self::Start { .. } => serde_spanned::__unstable::START_FIELD,
+            Self::Value { .. } => serde_spanned::__unstable::VALUE_FIELD,
+            Self::End { .. } => serde_spanned::__unstable::END_FIELD,
+            Self::Done => return Ok(None),
+        };
+
+        seed.deserialize(BorrowedStrDeserializer::new(key))
+            .map(Some)
+    }
+
+    /// Produces the value for the current state and advances to the next one.
+    ///
+    /// # Panics
+    ///
+    /// Panics if called after all three entries have been consumed.
+    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::DeserializeSeed<'de>,
+    {
+        // Take the state by value so the deserializer borrow can move into
+        // the next state; `Done` is the placeholder until it is reassigned.
+        match mem::replace(self, Self::Done) {
+            Self::Start { value_deserializer } => {
+                let start = value_deserializer.byte_offset();
+                *self = Self::Value { value_deserializer };
+                seed.deserialize(start.into_deserializer())
+            }
+
+            Self::Value { value_deserializer } => {
+                let val = seed.deserialize(&mut *value_deserializer);
+                // The end offset is recorded even if the value failed to
+                // deserialize; the error is still returned to the caller.
+                *self = Self::End {
+                    end_pos: value_deserializer.byte_offset(),
+                };
+                val
+            }
+
+            Self::End { end_pos } => seed.deserialize(end_pos.into_deserializer()),
+
+            Self::Done => {
+                panic!("should not get here");
+            }
+        }
+    }
+}
diff --git a/tests/crate/Cargo.toml b/tests/crate/Cargo.toml
index e13df6a8e..9f6c48277 100644
--- a/tests/crate/Cargo.toml
+++ b/tests/crate/Cargo.toml
@@ -20,3 +20,4 @@ float_roundtrip = ["serde_json/float_roundtrip"]
 arbitrary_precision = ["serde_json/arbitrary_precision"]
 raw_value = ["serde_json/raw_value"]
 unbounded_depth = ["serde_json/unbounded_depth"]
+spanned = ["serde_json/spanned"]
diff --git a/tests/test.rs b/tests/test.rs
index d41a2336a..ef57c0413 100644
--- a/tests/test.rs
+++ b/tests/test.rs
@@ -33,7 +33,7 @@ use serde_json::{
     to_vec, Deserializer, Number, Value,
 };
 use std::collections::BTreeMap;
-#[cfg(feature = "raw_value")]
+#[cfg(any(feature = "raw_value", feature = "spanned"))]
 use std::collections::HashMap;
 use std::fmt::{self, Debug};
 use std::hash::BuildHasher;
@@ -43,6 +43,8 @@ use std::io;
 use std::iter;
 use std::marker::PhantomData;
 use std::mem;
+#[cfg(feature = "spanned")]
+use std::ops::Range;
 use std::str::FromStr;
 use std::{f32, f64};
 
@@ -2558,3 +2560,142 @@ fn test_control_character_search() {
         "control character (\\u0000-\\u001F) found while parsing a string at line 1 column 2",
     )]);
 }
+
+/// Renders `json` followed by a line of `^` markers under the bytes in `span`.
+#[cfg(feature = "spanned")]
+fn format_span(json: &str, span: Range<usize>) -> String {
+    format!(
+        "{}\n{}{}",
+        json,
+        " ".repeat(span.start),
+        "^".repeat(span.end - span.start)
+    )
+}
+
+/// Asserts that `actual` equals `expected`; on failure both spans are shown
+/// marked up against `json`.
+#[cfg(feature = "spanned")]
+#[track_caller]
+fn assert_span_eq(json: &str, expected: Range<usize>, actual: Range<usize>) {
+    let expected_str = format_span(json, expected.clone());
+    let actual_str = format_span(json, actual.clone());
+
+    assert_eq!(
+        expected, actual,
+        "Expected span:\n{}\nActual span:\n{}",
+        expected_str, actual_str
+    );
+}
+
+#[cfg(feature = "spanned")]
+#[test]
+fn test_spanned_string() {
+    use serde_spanned::Spanned;
+
+    #[derive(Deserialize)]
+    struct SpannedStruct {
+        field: Spanned<String>,
+    }
+
+    let json = r#"{"field": "value"}"#;
+    let result: SpannedStruct = serde_json::from_str(json).unwrap();
+    assert_eq!(result.field.as_ref(), "value");
+    assert_span_eq(json, 10..17, result.field.span());
+}
+
+#[cfg(feature = "spanned")]
+#[test]
+fn test_spanned_number() {
+    use serde_spanned::Spanned;
+
+    #[derive(Deserialize)]
+    struct SpannedStruct {
+        field: Spanned<f64>,
+    }
+
+    let json = r#"{"field": -2.718e28}"#;
+    let result: SpannedStruct = serde_json::from_str(json).unwrap();
+    assert_eq!(*result.field.as_ref(), -2.718e28);
+    assert_span_eq(json, 10..19, result.field.span());
+}
+
+#[cfg(feature = "spanned")]
+#[test]
+fn test_spanned_whole_array() {
+    use serde_spanned::Spanned;
+
+    #[derive(Deserialize)]
+    struct SpannedStruct {
+        field: Spanned<Vec<i32>>,
+    }
+
+    let json = r#"{"field": [1, 2, 3, 4]}"#;
+    let result: SpannedStruct = serde_json::from_str(json).unwrap();
+    assert_eq!(result.field.as_ref(), &[1, 2, 3, 4]);
+    assert_span_eq(json, 10..22, result.field.span());
+}
+
+#[cfg(feature = "spanned")]
+#[test]
+fn test_spanned_array_items() {
+    use serde_spanned::Spanned;
+
+    #[derive(Deserialize)]
+    struct SpannedStruct {
+        field: Vec<Spanned<i32>>,
+    }
+
+    let json = r#"{"field": [1, 2, 3, 4]}"#;
+    let result: SpannedStruct = serde_json::from_str(json).unwrap();
+    assert_eq!(result.field.len(), 4);
+
+    assert_eq!(*result.field[0].as_ref(), 1);
+    assert_span_eq(json, 11..12, result.field[0].span());
+
+    assert_eq!(*result.field[1].as_ref(), 2);
+    assert_span_eq(json, 14..15, result.field[1].span());
+
+    assert_eq!(*result.field[2].as_ref(), 3);
+    assert_span_eq(json, 17..18, result.field[2].span());
+
+    assert_eq!(*result.field[3].as_ref(), 4);
+    assert_span_eq(json, 20..21, result.field[3].span());
+}
+
+#[cfg(feature = "spanned")]
+#[test]
+fn test_spanned_whole_map() {
+    use serde_spanned::Spanned;
+
+    #[derive(Deserialize)]
+    struct SpannedStruct {
+        field: Spanned<HashMap<i32, String>>,
+    }
+
+    let json = r#"{"field": {"1": "one", "2": "two"}}"#;
+    let result: SpannedStruct = serde_json::from_str(json).unwrap();
+    assert_span_eq(json, 10..34, result.field.span());
+    let mut map = result.field.into_inner();
+    let one = map.remove(&1).unwrap();
+    assert_eq!(one, "one");
+    let two = map.remove(&2).unwrap();
+    assert_eq!(two, "two");
+    assert!(map.is_empty());
+}
+
+#[cfg(feature = "spanned")]
+#[test]
+fn test_spanned_whitespace() {
+    use serde_spanned::Spanned;
+
+    #[derive(Deserialize)]
+    struct SpannedStruct {
+        field: Spanned<f64>,
+    }
+
+    // Six spaces precede the number, so it starts at byte 15 rather than 10.
+    let json = r#"{"field":      -2.718e28     }"#;
+    let result: SpannedStruct = serde_json::from_str(json).unwrap();
+    assert_eq!(*result.field.as_ref(), -2.718e28);
+    assert_span_eq(json, 15..24, result.field.span());
+}