diff --git a/Cargo.lock b/Cargo.lock --- a/Cargo.lock +++ b/Cargo.lock @@ -5639,6 +5639,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "util_macros" +version = "0.1.0" +dependencies = [ + "quote", + "serde", + "serde_json", + "syn 1.0.109", +] + [[package]] name = "uuid" version = "1.8.0" diff --git a/shared/util_macros/Cargo.toml b/shared/util_macros/Cargo.toml new file mode 100644 --- /dev/null +++ b/shared/util_macros/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "util_macros" +version = "0.1.0" +edition.workspace = true +license.workspace = true +homepage.workspace = true + +[lib] +proc-macro = true + +[dependencies] +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +syn = "1.0" +quote = "1.0" diff --git a/shared/util_macros/src/lib.rs b/shared/util_macros/src/lib.rs new file mode 100644 --- /dev/null +++ b/shared/util_macros/src/lib.rs @@ -0,0 +1,62 @@ +extern crate proc_macro; +use proc_macro::TokenStream; + +mod tag_aware_deserialize; + +/// Fixes [`serde::Deserialize`] implementation for deserializing +/// untagged enums of internally tagged structs. Original implementation +/// totally ignores the `tag` attribute when deserializing enum variants. +/// +/// This derive requires two serde attributes to be present: +/// `#[serde(tag = "type", remote = "Self")]` +/// +/// ### Example +/// ``` +/// use serde::{Deserialize, Serialize}; +/// use util_macros::TagAwareDeserialize; +/// +/// // Note that FirstStruct and SecondStruct have identical fields +/// // They're only distinguishable by 'tag' +/// #[derive(Debug, Serialize, Deserialize, TagAwareDeserialize)] +/// #[serde(tag = "type", remote = "Self")] +/// struct FirstStruct { +/// foo: String, +/// bar: String, +/// } +/// #[derive(Debug, Serialize, Deserialize, TagAwareDeserialize)] +/// #[serde(tag = "type", remote = "Self")] +/// struct SecondStruct { +/// foo: String, +/// bar: String, +/// } +/// +/// #[derive(Debug, Serialize, Deserialize)] +/// #[serde(untagged)] +/// enum EitherStruct { +/// // Note that FirstStruct is BEFORE SecondStruct +/// FirstStruct(FirstStruct), +/// SecondStruct(SecondStruct), +/// } +/// +/// let input = SecondStruct { +/// foo: "a".to_string(), +/// bar: "b".to_string(), +/// }; +/// let serialized = serde_json::to_string(&input).unwrap(); +/// +/// let deserialized: EitherStruct = serde_json::from_str(&serialized).unwrap(); +/// match deserialized { +/// EitherStruct::SecondStruct(result) => { +/// println!("Successfully deserialized {:?}", result); +/// }, +/// other => { +/// println!("Wrong variant was deserialized: {:?}", other); +/// } +/// }; +/// +/// ``` +#[proc_macro_derive(TagAwareDeserialize)] +pub fn tag_aware_deserialize_macro_derive(input: TokenStream) -> TokenStream { + let ast = syn::parse(input).unwrap(); + tag_aware_deserialize::impl_tag_aware_deserialize_macro(&ast) +} diff --git a/shared/util_macros/src/tag_aware_deserialize.rs b/shared/util_macros/src/tag_aware_deserialize.rs new file mode 100644 --- /dev/null +++ b/shared/util_macros/src/tag_aware_deserialize.rs @@ -0,0 +1,40 @@ +use proc_macro::TokenStream; +use quote::quote; + +pub fn impl_tag_aware_deserialize_macro(ast: &syn::DeriveInput) -> TokenStream { + let name = &ast.ident; + + let gen = quote! { + impl<'de> serde::de::Deserialize<'de> for #name { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + use serde::de::Error; + + let this = serde_json::Value::deserialize(deserializer)?; + if let Some(found_tag) = this.get("type") { + if found_tag == stringify!(#name) { + // now we can run _original_ deserialize + return match #name::deserialize(this) { + Ok(object) => Ok(object), + Err(err) => Err(Error::custom(err)), + }; + } + } + Err(Error::custom(format!("Deserialized object is not a '{}'", stringify!(#name)))) + } + } + + // We need a dummy `Serialize` impl for `remote = "Self"` to work + impl Serialize for #name { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + Self::serialize(self, serializer) + } + } + }; + gen.into() +}