diff --git a/shared/util_macros/src/lib.rs b/shared/util_macros/src/lib.rs
--- a/shared/util_macros/src/lib.rs
+++ b/shared/util_macros/src/lib.rs
@@ -1,5 +1,6 @@
 extern crate proc_macro;
 use proc_macro::TokenStream;
+use syn::parse_macro_input;
 
 mod tag_aware_deserialize;
 
@@ -8,7 +9,7 @@
 /// totally ignores the `tag` attribute when deserializing enum variants.
 ///
 /// This derive requires two serde attributes to be present:
-/// `#[serde(tag = "type", remote = "Self")]`
+/// `#[serde(tag = "...", remote = "Self")]`
 ///
 /// ### Example
 /// ```
@@ -57,6 +58,7 @@
 /// ```
 #[proc_macro_derive(TagAwareDeserialize)]
 pub fn tag_aware_deserialize_macro_derive(input: TokenStream) -> TokenStream {
-    let ast = syn::parse(input).unwrap();
+    let ast = parse_macro_input!(input as syn::DeriveInput);
     tag_aware_deserialize::impl_tag_aware_deserialize_macro(&ast)
+        .unwrap_or_else(|err| err.into_compile_error().into())
 }
diff --git a/shared/util_macros/src/tag_aware_deserialize.rs b/shared/util_macros/src/tag_aware_deserialize.rs
--- a/shared/util_macros/src/tag_aware_deserialize.rs
+++ b/shared/util_macros/src/tag_aware_deserialize.rs
@@ -1,9 +1,26 @@
 use proc_macro::TokenStream;
 use quote::quote;
+use syn::{Attribute, Lit, Meta, MetaNameValue, NestedMeta, Result};
 
-pub fn impl_tag_aware_deserialize_macro(ast: &syn::DeriveInput) -> TokenStream {
+pub fn impl_tag_aware_deserialize_macro(
+    ast: &syn::DeriveInput,
+) -> Result<TokenStream> {
     let name = &ast.ident;
 
+    if !has_serde_remote_self(&ast.attrs) {
+        return Err(syn::Error::new(
+            name.span(),
+            format!("{} must have `#[serde(remote = \"Self\")]` directive", name),
+        ));
+    }
+
+    let Some(tag_value) = extract_tag_value(&ast.attrs) else {
+        return Err(syn::Error::new(
+            name.span(),
+            format!("{} must have `#[serde(tag = \"...\")]` directive", name),
+        ));
+    };
+
     let gen = quote! {
         impl<'de> serde::de::Deserialize<'de> for #name {
             fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
@@ -13,7 +30,7 @@
                 use serde::de::Error;
 
                 let this = serde_json::Value::deserialize(deserializer)?;
-                if let Some(found_tag) = this.get("type") {
+                if let Some(found_tag) = this.get(#tag_value) {
                     if found_tag == stringify!(#name) {
                         // now we can run _original_ deserialize
                         return match #name::deserialize(this) {
@@ -36,5 +53,42 @@
             }
         }
     };
-    gen.into()
+    Ok(gen.into())
+}
+
+/// Iterates over every `key = value` pair found in any `#[serde(...)]` attribute.
+///
+/// serde allows container options to be split across several attributes (e.g.
+/// `#[serde(tag = "type")]` followed by `#[serde(remote = "Self")]`), so every
+/// `serde` attribute is inspected rather than only the first one. Attributes
+/// that do not parse as a meta list are skipped.
+fn serde_name_values(attrs: &[Attribute]) -> impl Iterator<Item = MetaNameValue> + '_ {
+    attrs
+        .iter()
+        .filter(|attr| attr.path.is_ident("serde"))
+        .filter_map(|attr| match attr.parse_meta() {
+            Ok(Meta::List(meta_list)) => Some(meta_list.nested),
+            _ => None,
+        })
+        .flatten()
+        .filter_map(|nested| match nested {
+            NestedMeta::Meta(Meta::NameValue(name_value)) => Some(name_value),
+            _ => None,
+        })
+}
+
+/// Reads the string value of the `#[serde(tag = "...")]` directive, if any.
+fn extract_tag_value(attrs: &[Attribute]) -> Option<String> {
+    serde_name_values(attrs).find_map(|name_value| match name_value.lit {
+        Lit::Str(lit_str) if name_value.path.is_ident("tag") => Some(lit_str.value()),
+        _ => None,
+    })
+}
+
+/// Checks for the `#[serde(remote = "Self")]` directive.
+fn has_serde_remote_self(attrs: &[Attribute]) -> bool {
+    matches!(
+        serde_name_values(attrs).find(|name_value| name_value.path.is_ident("remote")),
+        Some(MetaNameValue { lit: Lit::Str(lit_str), .. }) if lit_str.value() == "Self"
+    )
 }