use proc_macro2::{Span, TokenStream, TokenTree};
use quote::{quote, ToTokens};
use syn::parse_quote;
use syn::{Data, DeriveInput, Fields};
use crate::helpers::{non_enum_error, strum_discriminants_passthrough_error, HasTypeProperties};
/// Names of variant-level attributes that are carried over to the corresponding
/// variant of the generated discriminants enum. A `strum_discriminants(...)`
/// attribute is unwrapped into its inner attribute rather than copied verbatim.
const ATTRIBUTES_TO_COPY: &[&str] = &[
    "doc",
    "cfg",
    "allow",
    "deny",
    "strum_discriminants",
];
/// Expands the discriminants derive for the enum described by `ast`.
///
/// Generates a field-less enum with one variant per variant of the input enum,
/// together with `From<Enum>` and `From<&Enum>` conversions into it.
///
/// - The generated enum is named `<Enum>Discriminants` unless the
///   `discriminant_name` type property overrides it. It uses the input enum's
///   visibility unless `discriminant_vis` overrides that.
/// - It always derives `Clone, Copy, Debug, PartialEq, Eq`, plus every entry of
///   the `discriminant_derives` type property. Each entry of `discriminant_others`
///   is emitted as an extra outer attribute `#[...]` on the generated enum.
/// - Each generated variant gets copies of the source variant's attributes whose
///   names appear in `ATTRIBUTES_TO_COPY`. A variant-level
///   `#[strum_discriminants(attr)]` is unwrapped and emitted as `#[attr]`.
///
/// # Errors
///
/// Returns an error if:
/// - `ast` is not an enum;
/// - the type-level properties cannot be read;
/// - a variant-level `strum_discriminants` attribute does not start with a
///   non-empty delimited group.
pub fn enum_discriminants_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
    let name = &ast.ident;
    let vis = &ast.vis;

    let variants = match &ast.data {
        Data::Enum(v) => &v.variants,
        _ => return Err(non_enum_error()),
    };

    let type_properties = ast.get_type_properties()?;

    let derives = type_properties.discriminant_derives;
    let derives = quote! {
        #[derive(Clone, Copy, Debug, PartialEq, Eq, #(#derives),*)]
    };

    let default_name = syn::Ident::new(&format!("{}Discriminants", name), Span::call_site());
    let discriminants_name = type_properties.discriminant_name.unwrap_or(default_name);
    let discriminants_vis = type_properties
        .discriminant_vis
        .unwrap_or_else(|| vis.clone());

    // Extra outer attributes to place on the generated enum itself.
    let pass_through_attributes = type_properties.discriminant_others;

    // One generated variant per source variant, preceded by its copied attributes.
    let mut discriminants = Vec::new();
    for variant in variants {
        let ident = &variant.ident;
        let attrs = variant
            .attrs
            .iter()
            .filter(|attr| {
                ATTRIBUTES_TO_COPY
                    .iter()
                    .any(|attr_whitelisted| attr.path.is_ident(attr_whitelisted))
            })
            .map(|attr| {
                if attr.path.is_ident("strum_discriminants") {
                    // `#[strum_discriminants(inner)]` is re-emitted as `#[inner]`, using
                    // the contents of the first delimited group after the path.
                    let passthrough_group = attr
                        .tokens
                        .clone()
                        .into_iter()
                        .next()
                        .ok_or_else(|| strum_discriminants_passthrough_error(attr))?;
                    let passthrough_attribute = match passthrough_group {
                        TokenTree::Group(ref group) => group.stream(),
                        _ => {
                            return Err(strum_discriminants_passthrough_error(&passthrough_group));
                        }
                    };
                    if passthrough_attribute.is_empty() {
                        return Err(strum_discriminants_passthrough_error(&passthrough_group));
                    }
                    Ok(quote! { #[#passthrough_attribute] })
                } else {
                    Ok(attr.to_token_stream())
                }
            })
            .collect::<Result<Vec<_>, _>>()?;
        discriminants.push(quote! { #(#attrs)* #ident });
    }

    // Match arms mapping each source variant (fields ignored) to its discriminant.
    let arms = variants
        .iter()
        .map(|variant| {
            let ident = &variant.ident;
            let params = match &variant.fields {
                Fields::Unit => quote! {},
                Fields::Unnamed(_) => quote! { (..) },
                Fields::Named(_) => quote! { { .. } },
            };
            // When a `#[cfg]` is false, both the source variant and its generated
            // counterpart (which received a copy of the `cfg`) are removed. The arm
            // must be removed as well, or it would name a variant that does not exist.
            let cfgs = variant.attrs.iter().filter(|attr| attr.path.is_ident("cfg"));
            quote! { #(#cfgs)* #name::#ident #params => #discriminants_name::#ident }
        })
        .collect::<Vec<_>>();

    let from_fn_body = quote! { match val { #(#arms),* } };
    // The arm patterns bind nothing, so matching on `*val` never moves out of the
    // borrow. Matching on the dereferenced place also keeps an arm-less match (an
    // enum with no variants) exhaustive. `match val {}` on a `&Enum` is rejected,
    // because references are not treated as uninhabited.
    let from_ref_fn_body = quote! { match *val { #(#arms),* } };

    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();

    let impl_from = quote! {
        impl #impl_generics ::core::convert::From< #name #ty_generics > for #discriminants_name #where_clause {
            fn from(val: #name #ty_generics) -> #discriminants_name {
                #from_fn_body
            }
        }
    };

    let impl_from_ref = {
        // Add a fresh lifetime for the borrowed source enum. `split_for_impl` prints
        // lifetimes ahead of type/const parameters, so pushing it last is fine.
        let mut generics = ast.generics.clone();
        let lifetime = parse_quote!('_enum);
        let enum_life = quote! { & #lifetime };
        generics.params.push(lifetime);
        let (impl_generics, _, _) = generics.split_for_impl();

        quote! {
            impl #impl_generics ::core::convert::From< #enum_life #name #ty_generics > for #discriminants_name #where_clause {
                fn from(val: #enum_life #name #ty_generics) -> #discriminants_name {
                    #from_ref_fn_body
                }
            }
        }
    };

    Ok(quote! {
        #derives
        #(#[ #pass_through_attributes ])*
        #discriminants_vis enum #discriminants_name {
            #(#discriminants),*
        }

        #impl_from
        #impl_from_ref
    })
}