use serde::{
de::{
DeserializeOwned, DeserializeSeed, Deserializer, Error as SerdeError, MapAccess, SeqAccess, Visitor,
},
forward_to_deserialize_any,
};
use std::{
cmp::{Eq, Ordering},
collections::BTreeMap,
hash::{Hash, Hasher},
};
use crate::de::{Error as RonError, Result};
/// A finite `f64` that additionally provides total equality, total ordering
/// and hashing, so it can be used as a key in ordered or hashed collections.
///
/// The constructor rejects NaN and the infinities; the `Eq` and `Ord`
/// implementations below rely on that invariant.
#[derive(Copy, Clone, Debug, PartialOrd, PartialEq)]
pub struct Number(f64);

impl Number {
    /// Wraps `v` in a `Number`.
    ///
    /// # Panics
    ///
    /// Panics if `v` is NaN or infinite.
    pub fn new(v: f64) -> Self {
        assert!(
            v.is_finite(),
            "Tried to create Number with a NaN / infinity"
        );
        Number(v)
    }

    /// Returns the wrapped `f64`.
    pub fn get(self) -> f64 {
        self.0
    }
}

// Sound because `Number::new` never admits NaN, the only `f64` that is not
// equal to itself.
impl Eq for Number {}

impl Hash for Number {
    /// Feeds the IEEE-754 bit pattern of the number into `state`.
    ///
    /// `+0.0` and `-0.0` compare equal but have different bit patterns, so
    /// both are hashed as `+0.0` to keep `a == b` implying
    /// `hash(a) == hash(b)`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let canonical = if self.0 == 0.0 { 0.0_f64 } else { self.0 };
        state.write_u64(canonical.to_bits());
    }
}

impl Ord for Number {
    /// Total ordering over the wrapped floats.
    ///
    /// # Panics
    ///
    /// Panics if the two values are not comparable, which can only happen
    /// when a `Number` holds NaN, i.e. when the constructor was bypassed.
    fn cmp(&self, other: &Self) -> Ordering {
        self.partial_cmp(other).expect("Bug: Contract violation")
    }
}
/// A dynamically typed value that can be deserialized into a concrete Rust
/// type through its `Deserializer` implementation.
///
/// The derived `Ord`/`PartialOrd` compare variants by declaration order first,
/// so reordering the variants changes how values sort (including the key
/// order of `Value::Map`).
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Value {
    /// A boolean.
    Bool(bool),
    /// A single character.
    Char(char),
    /// A map from values to values, kept sorted by key.
    Map(BTreeMap<Value, Value>),
    /// A finite floating point number.
    Number(Number),
    /// An optional value; the payload is boxed because the type is recursive.
    Option(Option<Box<Value>>),
    /// An owned string.
    String(String),
    /// An ordered sequence of values.
    Seq(Vec<Value>),
    /// The unit value.
    Unit,
}
impl Value {
pub fn into_rust<T>(self) -> Result<T>
where
T: DeserializeOwned,
{
T::deserialize(self)
}
}
/// Lets an owned [`Value`] drive any serde `Visitor`.
///
/// Most requests are answered purely from the variant stored in the value
/// (see `deserialize_any`). Integer requests are handled separately because
/// numbers are stored as `f64` and have to be converted.
impl<'de> Deserializer<'de> for Value {
    type Error = RonError;

    // Everything listed here ignores the type hint and dispatches on the
    // variant held by `self`.
    forward_to_deserialize_any! {
        bool f32 f64 char str string bytes
        byte_buf option unit unit_struct newtype_struct seq tuple
        tuple_struct map struct enum identifier ignored_any
    }

    /// Calls the visitor method that corresponds to the stored variant.
    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        match self {
            Value::Bool(b) => visitor.visit_bool(b),
            Value::Char(c) => visitor.visit_char(c),
            Value::Map(m) => {
                // `self` is owned, so the entries are moved out instead of
                // being deep-cloned. They are stored back to front because
                // the accessor hands them out with `Vec::pop`.
                let (keys, values): (Vec<Value>, Vec<Value>) = m.into_iter().rev().unzip();
                visitor.visit_map(Map { keys, values })
            }
            Value::Number(n) => visitor.visit_f64(n.get()),
            Value::Option(Some(o)) => visitor.visit_some(*o),
            Value::Option(None) => visitor.visit_none(),
            Value::String(s) => visitor.visit_string(s),
            Value::Seq(mut seq) => {
                // Reversed for the same reason as the map entries above.
                seq.reverse();
                visitor.visit_seq(Seq { seq })
            }
            Value::Unit => visitor.visit_unit(),
        }
    }

    /// Forwards to `deserialize_i64`.
    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_i64(visitor)
    }

    /// Forwards to `deserialize_i64`.
    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_i64(visitor)
    }

    /// Forwards to `deserialize_i64`.
    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_i64(visitor)
    }

    /// Visits the stored number as an `i64`.
    ///
    /// # Errors
    ///
    /// Fails if the value is not a `Value::Number`, or if the number has a
    /// fractional part or lies outside the `i64` range. A bare `as` cast
    /// would silently truncate or saturate in those cases.
    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        match self {
            Value::Number(n) => {
                let raw = n.get();
                // The bounds are -2^63 (inclusive) and 2^63 (exclusive); both
                // are exactly representable as `f64`, unlike `i64::MAX`.
                let in_range = raw >= -9_223_372_036_854_775_808.0_f64
                    && raw < 9_223_372_036_854_775_808.0_f64;
                if in_range && raw.fract() == 0.0 {
                    visitor.visit_i64(raw as i64)
                } else {
                    Err(RonError::custom(format!(
                        "Expected a signed 64-bit integer, got {}",
                        raw
                    )))
                }
            }
            other => Err(RonError::custom(format!(
                "Expected a number, got {:?}",
                other
            ))),
        }
    }

    /// Forwards to `deserialize_u64`.
    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_u64(visitor)
    }

    /// Forwards to `deserialize_u64`.
    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_u64(visitor)
    }

    /// Forwards to `deserialize_u64`.
    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_u64(visitor)
    }

    /// Visits the stored number as a `u64`.
    ///
    /// # Errors
    ///
    /// Fails if the value is not a `Value::Number`, or if the number is
    /// negative, has a fractional part, or lies outside the `u64` range.
    /// A bare `as` cast would turn e.g. `-5` into `0` without any error.
    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        match self {
            Value::Number(n) => {
                let raw = n.get();
                // The upper bound is 2^64 (exclusive), exactly representable
                // as `f64`, unlike `u64::MAX`.
                let in_range = raw >= 0.0 && raw < 18_446_744_073_709_551_616.0_f64;
                if in_range && raw.fract() == 0.0 {
                    visitor.visit_u64(raw as u64)
                } else {
                    Err(RonError::custom(format!(
                        "Expected an unsigned 64-bit integer, got {}",
                        raw
                    )))
                }
            }
            other => Err(RonError::custom(format!(
                "Expected a number, got {:?}",
                other
            ))),
        }
    }
}
/// `MapAccess` state for a `Value::Map`.
///
/// Keys and values are held in two parallel vectors and consumed from the
/// back with `Vec::pop`, so they are stored in reverse of the order in which
/// they should be yielded.
struct Map {
    /// Keys still to be yielded; the next one is the last element.
    keys: Vec<Value>,
    /// Values still to be yielded; the next one is the last element.
    values: Vec<Value>,
}
/// Hands out the stored keys and values one pair at a time.
impl<'de> MapAccess<'de> for Map {
    type Error = RonError;

    /// Deserializes the next key, or returns `Ok(None)` when no keys remain.
    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
    where
        K: DeserializeSeed<'de>,
    {
        match self.keys.pop() {
            Some(key) => seed.deserialize(key).map(Some),
            None => Ok(None),
        }
    }

    /// Deserializes the next value.
    ///
    /// # Panics
    ///
    /// Panics if no values remain.
    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
    where
        V: DeserializeSeed<'de>,
    {
        let value = self.values.pop().expect("Contract violation");
        seed.deserialize(value)
    }
}
/// `SeqAccess` state for a `Value::Seq`.
///
/// Elements are consumed from the back with `Vec::pop`, so they are stored in
/// reverse of the order in which they should be yielded.
struct Seq {
    /// Elements still to be yielded; the next one is the last element.
    seq: Vec<Value>,
}
/// Hands out the stored elements one at a time.
impl<'de> SeqAccess<'de> for Seq {
    type Error = RonError;

    /// Deserializes the next element, or returns `Ok(None)` once the
    /// sequence is exhausted.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: DeserializeSeed<'de>,
    {
        match self.seq.pop() {
            Some(element) => seed.deserialize(element).map(Some),
            None => Ok(None),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::de::from_str;
    use serde::Deserialize;
    use std::fmt::Debug;

    /// Parses `s` twice, once straight into `T` and once into a `Value` that
    /// is then deserialized into `T`, and asserts both routes agree.
    fn assert_same<'de, T>(s: &'de str)
    where
        T: Debug + Deserialize<'de> + PartialEq,
    {
        let expected = from_str::<T>(s).unwrap();
        let intermediate = from_str::<Value>(s).unwrap();
        let actual = T::deserialize(intermediate).unwrap();

        assert_eq!(
            expected, actual,
            "Deserialization for {:?} is not the same",
            s
        );
    }

    // Both boolean literals.
    #[test]
    fn boolean() {
        assert_same::<bool>("true");
        assert_same::<bool>("false");
    }

    // A positive and a negative fractional number.
    #[test]
    fn float() {
        assert_same::<f64>("0.123");
        assert_same::<f64>("-4.19");
    }

    // An unsigned and a signed integer.
    #[test]
    fn int() {
        assert_same::<u32>("626");
        assert_same::<i32>("-50");
    }

    // A digit character and a letter.
    #[test]
    fn char() {
        assert_same::<char>("'4'");
        assert_same::<char>("'c'");
    }

    // A two-entry map with a trailing comma, spread over several lines.
    #[test]
    fn map() {
        assert_same::<BTreeMap<char, String>>("{\n'a': \"Hello\",\n'b': \"Bye\",\n}");
    }

    // Both option states.
    #[test]
    fn option() {
        assert_same::<Option<char>>("Some('a')");
        assert_same::<Option<char>>("None");
    }

    // A sequence of floats.
    #[test]
    fn seq() {
        assert_same::<Vec<f64>>("[1.0, 2.0, 3.0, 4.0]");
    }

    // The unit value.
    #[test]
    fn unit() {
        assert_same::<()>("()");
    }
}