From 83689af31e3fdf19e6d9ef57100fa2872870de3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=BCller?= Date: Fri, 3 Feb 2023 16:40:10 +0100 Subject: [PATCH] Implement the library MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Marcel Müller --- Cargo.lock | 79 +++++++++++++ Cargo.toml | 6 + src/error.rs | 18 +++ src/lib.rs | 215 +++++++++++++++++++++++++++++++-- src/value.rs | 329 +++++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 640 insertions(+), 7 deletions(-) create mode 100644 src/error.rs create mode 100644 src/value.rs diff --git a/Cargo.lock b/Cargo.lock index 4e607b4..fbf7ca1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5,3 +5,82 @@ version = 3 [[package]] name = "envious" version = "0.1.0" +dependencies = [ + "serde", + "thiserror", +] + +[[package]] +name = "proc-macro2" +version = "1.0.50" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ef7d57beacfaf2d8aee5937dab7b7f28de3cb8b1828479bb5de2a7106f2bae2" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "serde" +version = "1.0.152" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb7d1f0d3021d347a83e556fc4683dea2ea09d87bccdf88ff5c12545d89d5efb" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.152" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "syn" +version = "1.0.107" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f4064b5b16e03ae50984a5a8ed5d4f8803e6bc1fd170a3cda91a1be4b18e3f5" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thiserror" +version = "1.0.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a9cd18aa97d5c45c6603caea1da6628790b37f7a34b6ca89522331c5180fed0" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fb327af4685e4d03fa8cbcf1716380da910eeb2bb8be417e7f9fd3fb164f36f" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "unicode-ident" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" diff --git a/Cargo.toml b/Cargo.toml index 4fd9d9e..a3b4314 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,13 @@ version = "0.1.0" edition = "2021" license = "MIT OR Apache-2.0" description = "Deserialize (potentially nested) environment variables into your custom structs" +resolver = "2" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +serde = "1.0.152" +thiserror = "1.0.38" + +[dev-dependencies] +serde = { version = "1.0.152", features = ["derive"] } diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..ffa87bf --- /dev/null +++ b/src/error.rs @@ -0,0 +1,18 @@ +#[derive(Debug, PartialEq, thiserror::Error)] +pub enum EnvDeserializationError { + #[error("An error occured during deserialization: {}", .0)] + GenericDeserialization(String), + #[error("An unsupported variant was tried to be deserialized. Only structs and maps are currently supported.")] + UnsupportedValue, + #[error("Tried to nest values while a simple value was expected")] + InvalidNestedValues, +} + +impl serde::de::Error for EnvDeserializationError { + fn custom(msg: T) -> Self + where + T: std::fmt::Display, + { + EnvDeserializationError::GenericDeserialization(msg.to_string()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 7d12d9a..01bbc1c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,215 @@ -pub fn add(left: usize, right: usize) -> usize { - left + right +use error::EnvDeserializationError; +use serde::de::{value::StringDeserializer, DeserializeOwned, IntoDeserializer}; +use value::Value; + +mod error; +mod value; + +#[derive(Debug, PartialEq)] +struct Key { + original: String, + current: String, +} + +impl AsRef for Key { + fn as_ref(&self) -> &str { + &self.current + } +} + +impl Key { + fn new(original: String) -> Self { + Self { + current: original.clone(), + original, + } + } +} + +impl<'de> IntoDeserializer<'de, EnvDeserializationError> for Key { + type Deserializer = StringDeserializer; + fn into_deserializer(self) -> Self::Deserializer { + self.current.into_deserializer() + } +} + +pub enum Prefix<'a> { + None, + Some(&'a str), +} + +pub fn from_env( + prefix: Prefix<'_>, +) -> Result { + let env_values = std::env::vars(); + + from_primitive(env_values.flat_map(|(key, value)| { + if let Prefix::Some(prefix) = prefix { + let stripped_key = key.strip_prefix(prefix).map(|s| s.to_string())?; + Some((Key::new(stripped_key), value)) + } else { + Some((Key::new(key), value)) + } + })) +} + +fn from_primitive>( + values: I, +) -> Result { + let deserializer = + Value::from_list(values.map(|(key, val)| (key, Value::Simple(val))).collect()).unwrap(); + T::deserialize(deserializer) } #[cfg(test)] -mod tests { - use super::*; +mod test { + use serde::Deserialize; + + use crate::{from_primitive, Key}; + + #[test] + fn check_simple_struct() { + #[derive(Debug, PartialEq, Deserialize)] + struct Simple { + allowed: bool, + } + + let expected = Simple { allowed: true }; + + let actual: Simple = + from_primitive([(Key::new(String::from("allowed")), String::from("true"))].into_iter()) + .unwrap(); + + assert_eq!(actual, expected); + } #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); + fn check_double_nested_struct() { + #[derive(Debug, PartialEq, Deserialize)] + struct InnerExtraConfig { + allowed: bool, + } + + #[derive(Debug, PartialEq, Deserialize)] + struct InnerConfig { + smoothness: f32, + extra: InnerExtraConfig, + } + + #[derive(Debug, PartialEq, Deserialize)] + struct Nested { + temp: u64, + inner: InnerConfig, + } + + let expected = Nested { + temp: 15, + inner: InnerConfig { + smoothness: 32.0, + extra: InnerExtraConfig { allowed: false }, + }, + }; + + let actual: Nested = from_primitive( + [ + (Key::new(String::from("temp")), String::from("15")), + ( + Key::new(String::from("inner__smoothness")), + String::from("32.0"), + ), + ( + Key::new(String::from("inner__extra__allowed")), + String::from("false"), + ), + ] + .into_iter(), + ) + .unwrap(); + + assert_eq!(actual, expected); + } + + #[test] + fn check_renamed_struct() { + #[derive(Debug, PartialEq, Deserialize)] + #[serde(rename_all = "SCREAMING-KEBAB-CASE")] + struct Simple { + allowed_simply: bool, + } + + let expected = Simple { + allowed_simply: true, + }; + + let actual: Simple = from_primitive( + [( + Key::new(String::from("ALLOWED-SIMPLY")), + String::from("true"), + )] + .into_iter(), + ) + .unwrap(); + + assert_eq!(actual, expected); + } + + #[test] + fn check_simple_enum() { + #[derive(Debug, PartialEq, Deserialize)] + enum Simple { + Yes, + No, + } + + #[derive(Debug, PartialEq, Deserialize)] + struct SimpleEnum { + simple: Simple, + } + + let expected = SimpleEnum { simple: Simple::No }; + + let actual: SimpleEnum = + from_primitive([(Key::new(String::from("simple")), String::from("No"))].into_iter()) + .unwrap(); + + assert_eq!(actual, expected); + } + + #[test] + fn check_complex_enum() { + #[derive(Debug, PartialEq, Deserialize)] + enum Complex { + Access { password: String, foo: f32 }, + No, + } + + #[derive(Debug, PartialEq, Deserialize)] + struct ComplexEnum { + complex: Complex, + } + + let expected = ComplexEnum { + complex: Complex::Access { + password: String::from("hunter2"), + foo: 42.0, + }, + }; + + let actual: ComplexEnum = from_primitive( + [ + ( + Key::new(String::from("complex__Access__password")), + String::from("hunter2"), + ), + ( + Key::new(String::from("complex__Access__foo")), + String::from("42.0"), + ), + ] + .into_iter(), + ) + .unwrap(); + + assert_eq!(actual, expected); } } diff --git a/src/value.rs b/src/value.rs new file mode 100644 index 0000000..c25fdb8 --- /dev/null +++ b/src/value.rs @@ -0,0 +1,329 @@ +use serde::de::value::{MapAccessDeserializer, MapDeserializer, SeqDeserializer}; +use serde::de::IntoDeserializer; +use serde::Deserializer; + +use crate::error::EnvDeserializationError; +use crate::Key; + +#[derive(Debug, PartialEq)] +pub(crate) enum Value { + Simple(String), + Map(Vec<(Key, Value)>), +} + +const SEPERATOR: &str = "__"; + +impl Value { + fn insert_at(&mut self, path: &[&str], value: Value) -> Result<(), ()> { + match self { + Value::Simple(_) => Err(()), + Value::Map(values) => { + let val = if let Some((_key, val)) = + values.iter_mut().find(|(key, _)| key.as_ref() == path[0]) + { + match val { + Value::Simple(_) => return Err(()), + Value::Map(_) => val, + } + } else { + let val = Value::Map(vec![]); + values.push((Key::new(path[0].to_string()), val)); + &mut values.last_mut().unwrap().1 + }; + + let path = &path[1..]; + + if path.len() > 1 { + val.insert_at(path, value) + } else { + match val { + Value::Simple(_) => return Err(()), + Value::Map(values) => values.push((Key::new(path[0].to_string()), value)), + } + Ok(()) + } + } + } + } + + pub(crate) fn from_list(list: Vec<(Key, Value)>) -> Result { + let mut base = Value::Map(vec![]); + + for (key, value) in list { + let key_str = key.as_ref(); + let path = key_str.split(SEPERATOR).collect::>(); + + if path.len() == 1 { + if let Value::Map(base) = &mut base { + base.push((key, value)); + } else { + unreachable!() + } + } else { + base.insert_at(&path, value)?; + } + } + + Ok(base) + } +} + +macro_rules! forward_to_deserializer { + ($($ty:ident => $method:ident),* $(,)?) => { + $( + fn $method(self, visitor: V) -> Result + where V: serde::de::Visitor<'de> + { + match self { + Value::Simple(val) => { + match val.parse::<$ty>() { + Ok(val) => val.into_deserializer().$method(visitor), + Err(e) => Err(crate::error::EnvDeserializationError::GenericDeserialization(format!("'{}' could not be deserialized due to: {}", val, e))), + } + } + Value::Map(_) => Err(crate::error::EnvDeserializationError::InvalidNestedValues) + } + } + )* + }; +} + +impl<'de> IntoDeserializer<'de, EnvDeserializationError> for Value { + type Deserializer = Self; + + fn into_deserializer(self) -> Self::Deserializer { + self + } +} + +impl<'de> Deserializer<'de> for Value { + type Error = crate::error::EnvDeserializationError; + + fn deserialize_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + Value::Simple(val) => val.into_deserializer().deserialize_any(visitor), + Value::Map(_) => self.deserialize_map(visitor), + } + } + + fn deserialize_seq(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + val @ Value::Simple(_) => { + SeqDeserializer::new(std::iter::once(val)).deserialize_seq(visitor) + } + Value::Map(values) => { + let values = values.into_iter().map(|(_, val)| val); + + SeqDeserializer::new(values).deserialize_seq(visitor) + } + } + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_some(self) + } + + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + Value::Simple(val) => visitor.visit_enum(val.into_deserializer()), + Value::Map(values) => visitor.visit_enum(MapAccessDeserializer::new( + MapDeserializer::new(values.into_iter()), + )), + } + } + + fn deserialize_map(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self { + Value::Simple(_) => Err(EnvDeserializationError::UnsupportedValue), + Value::Map(values) => visitor.visit_map(MapDeserializer::new(values.into_iter())), + } + } + + fn deserialize_struct( + self, + _name: &'static str, + _fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + self.deserialize_map(visitor) + } + + forward_to_deserializer! { + u8 => deserialize_u8, + i8 => deserialize_i8, + u16 => deserialize_u16, + i16 => deserialize_i16, + u32 => deserialize_u32, + i32 => deserialize_i32, + u64 => deserialize_u64, + i64 => deserialize_i64, + f32 => deserialize_f32, + f64 => deserialize_f64, + bool => deserialize_bool, + } + + serde::forward_to_deserialize_any! { + char str string bytes byte_buf unit unit_struct tuple_struct + identifier tuple ignored_any + } +} + +#[cfg(test)] +mod tests { + use crate::Key; + + use super::Value; + use serde::Deserialize; + + #[test] + fn simple_values() { + assert_eq!( + Ok(true), + <_>::deserialize(Value::Simple(String::from("true"))) + ); + + assert_eq!( + Ok(25u32), + <_>::deserialize(Value::Simple(String::from("25"))) + ); + assert_eq!( + Ok(String::from("foobar")), + <_>::deserialize(Value::Simple(String::from("foobar"))) + ); + assert_eq!( + Ok(Some(String::from("foobar"))), + <_>::deserialize(Value::Simple(String::from("foobar"))) + ); + } + + #[test] + fn simple_sequence() { + assert_eq!( + Ok(vec![125u32]), + <_>::deserialize(Value::Simple(String::from("125"))) + ); + assert_eq!( + Ok(vec![125u32, 200, 300]), + <_>::deserialize(Value::Map(vec![ + (Key::new(String::new()), Value::Simple(String::from("125"))), + (Key::new(String::new()), Value::Simple(String::from("200"))), + (Key::new(String::new()), Value::Simple(String::from("300"))) + ])) + ); + } + + #[test] + fn simple_map() { + assert_eq!( + Ok(std::collections::HashMap::from([( + String::from("foo"), + 123 + )])), + <_>::deserialize(Value::Map(vec![( + Key::new(String::from("foo")), + Value::Simple(String::from("123")) + ),])) + ); + + assert_eq!( + Ok(std::collections::HashMap::from([( + String::from("foo"), + std::collections::HashMap::from([(String::from("bar"), 123)]), + )])), + <_>::deserialize(Value::Map(vec![( + Key::new(String::from("foo")), + Value::Map(vec![( + Key::new(String::from("bar")), + Value::Simple(String::from("123")) + ),]) + ),])) + ); + } + + #[test] + fn convert_list_of_key_vals_to_tree() { + let input = vec![ + ( + Key::new(String::from("FOO")), + Value::Simple(String::from("bar")), + ), + ( + Key::new(String::from("BAZ")), + Value::Simple(String::from("124")), + ), + ( + Key::new(String::from("NESTED__FOO")), + Value::Simple(String::from("true")), + ), + ( + Key::new(String::from("NESTED__BAZ")), + Value::Simple(String::from("Hello")), + ), + ]; + + let expected = Value::Map(vec![ + ( + Key::new(String::from("FOO")), + Value::Simple(String::from("bar")), + ), + ( + Key::new(String::from("BAZ")), + Value::Simple(String::from("124")), + ), + ( + Key::new(String::from("NESTED")), + Value::Map(vec![ + ( + Key::new(String::from("FOO")), + Value::Simple(String::from("true")), + ), + ( + Key::new(String::from("BAZ")), + Value::Simple(String::from("Hello")), + ), + ]), + ), + ]); + + let actual = Value::from_list(input).unwrap(); + + dbg!(&expected); + dbg!(&actual); + + assert_eq!(actual, expected); + } +}