diff --git a/shared/util_macros/src/lib.rs b/shared/util_macros/src/lib.rs index 997b64d84..ed3bb2109 100644 --- a/shared/util_macros/src/lib.rs +++ b/shared/util_macros/src/lib.rs @@ -1,62 +1,64 @@ extern crate proc_macro; use proc_macro::TokenStream; +use syn::parse_macro_input; mod tag_aware_deserialize; /// Fixes [`serde::Deserialize`] implementation for deserializing /// untagged enums of internally tagged structs. Original implementation /// totally ignores the `tag` attribute when deserializing enum variants. /// /// This derive requires two serde attributes to be present: -/// `#[serde(tag = "type", remote = "Self")]` +/// `#[serde(tag = "<>", remote = "Self")]` /// /// ### Example /// ``` /// use serde::{Deserialize, Serialize}; /// use util_macros::TagAwareDeserialize; /// /// // Note that FirstStruct and SecondStruct have identical fields /// // They're only distinguishable by 'tag' /// #[derive(Debug, Serialize, Deserialize, TagAwareDeserialize)] /// #[serde(tag = "type", remote = "Self")] /// struct FirstStruct { /// foo: String, /// bar: String, /// } /// #[derive(Debug, Serialize, Deserialize, TagAwareDeserialize)] /// #[serde(tag = "type", remote = "Self")] /// struct SecondStruct { /// foo: String, /// bar: String, /// } /// /// #[derive(Debug, Serialize, Deserialize)] /// #[serde(untagged)] /// enum EitherStruct { /// // Note that FirstStruct is BEFORE SecondStruct /// FirstStruct(FirstStruct), /// SecondStruct(SecondStruct), /// } /// /// let input = SecondStruct { /// foo: "a".to_string(), /// bar: "b".to_string(), /// }; /// let serialized = serde_json::to_string(&input).unwrap(); /// /// let deserialized: EitherStruct = serde_json::from_str(&serialized).unwrap(); /// match deserialized { /// EitherStruct::SecondStruct(result) => { /// println!("Successfully deserialized {:?}", result); /// }, /// other => { /// println!("Wrong variant was deserialized: {:?}", other); /// } /// }; /// /// ``` #[proc_macro_derive(TagAwareDeserialize)] pub fn tag_aware_deserialize_macro_derive(input: TokenStream) -> TokenStream { - let ast = syn::parse(input).unwrap(); + let ast = parse_macro_input!(input as syn::DeriveInput); tag_aware_deserialize::impl_tag_aware_deserialize_macro(&ast) + .unwrap_or_else(|err| err.into_compile_error().into()) } diff --git a/shared/util_macros/src/tag_aware_deserialize.rs b/shared/util_macros/src/tag_aware_deserialize.rs index 8a22899af..90500d967 100644 --- a/shared/util_macros/src/tag_aware_deserialize.rs +++ b/shared/util_macros/src/tag_aware_deserialize.rs @@ -1,40 +1,92 @@ use proc_macro::TokenStream; use quote::quote; +use syn::{Attribute, Lit, Meta, NestedMeta, Result}; -pub fn impl_tag_aware_deserialize_macro(ast: &syn::DeriveInput) -> TokenStream { +pub fn impl_tag_aware_deserialize_macro( + ast: &syn::DeriveInput, +) -> Result { let name = &ast.ident; + if !has_serde_remote_self(&ast.attrs) { + return Err(syn::Error::new( + name.span(), + format!("{} must have `#[serde(remote = \"Self\")]` directive", name), + )); + } + + let Some(tag_value) = extract_tag_value(&ast.attrs) else { + return Err(syn::Error::new( + name.span(), + format!("{} must have `#[serde(tag = \"...\")]` directive", name), + )); + }; + let gen = quote! { impl<'de> serde::de::Deserialize<'de> for #name { fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { use serde::de::Error; let this = serde_json::Value::deserialize(deserializer)?; - if let Some(found_tag) = this.get("type") { + if let Some(found_tag) = this.get(#tag_value) { if found_tag == stringify!(#name) { // now we can run _original_ deserialize return match #name::deserialize(this) { Ok(object) => Ok(object), Err(err) => Err(Error::custom(err)), }; } } Err(Error::custom(format!("Deserialized object is not a '{}'", stringify!(#name)))) } } // We need a dummy `Serialize` impl for `remote = "Self"` to work impl Serialize for #name { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { Self::serialize(self, serializer) } } }; - gen.into() + Ok(gen.into()) +} + +/// Reads the `#[serde(tag = "...")]` value +fn extract_tag_value(attrs: &[Attribute]) -> Option { + let serde_attr = attrs.iter().find(|attr| attr.path.is_ident("serde"))?; + if let Ok(Meta::List(meta_list)) = serde_attr.parse_meta() { + for nested in meta_list.nested { + if let NestedMeta::Meta(Meta::NameValue(name_value)) = nested { + if name_value.path.is_ident("tag") { + if let Lit::Str(lit_str) = name_value.lit { + return Some(lit_str.value()); + } + } + } + } + } + None +} + +/// Checks for the `#[serde(remote = "Self")]` attribute +fn has_serde_remote_self(attrs: &[Attribute]) -> bool { + let Some(serde_attr) = attrs.iter().find(|attr| attr.path.is_ident("serde")) + else { + return false; + }; + if let Ok(Meta::List(meta_list)) = serde_attr.parse_meta() { + for nested in meta_list.nested { + if let NestedMeta::Meta(Meta::NameValue(name_value)) = nested { + if name_value.path.is_ident("remote") { + return matches!(name_value.lit, Lit::Str(lit_str) if lit_str.value() == "Self"); + } + } + } + } + false }