use std::iter;
use proc_macro2::Ident;
use syn::{
spanned::Spanned,
visit::{self, Visit},
Generics, Result, Type, TypePath,
};
use crate::utils::{self, CustomTraitBound};
/// Syntax-tree visitor that flags whether any identifier from `idents` appears anywhere in the
/// visited node.
struct ContainIdents<'a> {
	/// Set to `true` as soon as one of `idents` is encountered; never reset.
	result: bool,
	/// Identifiers to look for.
	idents: &'a [Ident],
}

impl<'ast> Visit<'ast> for ContainIdents<'_> {
	fn visit_ident(&mut self, ident: &'ast Ident) {
		if self.idents.contains(ident) {
			self.result = true;
		}
	}
}

/// Returns `true` if any identifier in `idents` occurs somewhere inside `ty`.
fn type_contain_idents(ty: &Type, idents: &[Ident]) -> bool {
	let mut finder = ContainIdents { idents, result: false };
	finder.visit_type(ty);
	finder.result
}
/// Syntax-tree visitor that flags whether it meets a type path whose first segment is `ident`.
///
/// A matching path is not descended into. Any other path is walked with syn's default
/// traversal, so type paths nested inside it are checked as well.
struct TypePathStartsWithIdent<'a> {
	/// Set to `true` once a matching type path is found; never reset.
	result: bool,
	/// Identifier the first path segment is compared against.
	ident: &'a Ident,
}

impl<'ast> Visit<'ast> for TypePathStartsWithIdent<'_> {
	fn visit_type_path(&mut self, type_path: &'ast TypePath) {
		let first_segment_matches = type_path
			.path
			.segments
			.first()
			.map_or(false, |segment| segment.ident == *self.ident);

		if first_segment_matches {
			self.result = true;
		} else {
			visit::visit_type_path(self, type_path);
		}
	}
}

/// Returns `true` if `ty` itself, or any type path nested inside it, starts with `ident`.
fn type_path_or_sub_starts_with_ident(ty: &TypePath, ident: &Ident) -> bool {
	let mut checker = TypePathStartsWithIdent { ident, result: false };
	checker.visit_type_path(ty);
	checker.result
}

/// Returns `true` if any type path found inside `ty` starts with `ident`.
fn type_or_sub_type_path_starts_with_ident(ty: &Type, ident: &Ident) -> bool {
	let mut checker = TypePathStartsWithIdent { ident, result: false };
	checker.visit_type(ty);
	checker.result
}
/// Syntax-tree visitor that collects the outermost type paths in which no type path
/// (the path itself or any nested one) starts with `ident`.
struct FindTypePathsNotStartOrContainIdent<'a> {
	/// Collected type paths, in visiting order.
	result: Vec<TypePath>,
	/// Identifier that disqualifies a path when some (sub-)path starts with it.
	ident: &'a Ident,
}

impl<'ast> Visit<'ast> for FindTypePathsNotStartOrContainIdent<'_> {
	fn visit_type_path(&mut self, type_path: &'ast TypePath) {
		if !type_path_or_sub_starts_with_ident(type_path, self.ident) {
			// No path in here starts with `ident`: keep the whole path and stop descending.
			self.result.push(type_path.clone());
			return
		}
		// Somewhere inside a path starts with `ident`; look for qualifying nested paths.
		visit::visit_type_path(self, type_path);
	}
}

/// Collects the outermost type paths inside `ty` in which no (sub-)path starts with `ident`.
fn find_type_paths_not_start_or_contain_ident(ty: &Type, ident: &Ident) -> Vec<TypePath> {
	let mut collector = FindTypePathsNotStartOrContainIdent { ident, result: Vec::new() };
	collector.visit_type(ty);
	collector.result
}
pub fn add<N>(
input_ident: &Ident,
generics: &mut Generics,
data: &syn::Data,
custom_trait_bound: Option<CustomTraitBound<N>>,
codec_bound: syn::Path,
codec_skip_bound: Option<syn::Path>,
dumb_trait_bounds: bool,
crate_path: &syn::Path,
) -> Result<()> {
let skip_type_params = match custom_trait_bound {
Some(CustomTraitBound::SpecifiedBounds { bounds, .. }) => {
generics.make_where_clause().predicates.extend(bounds);
return Ok(())
},
Some(CustomTraitBound::SkipTypeParams { type_names, .. }) =>
type_names.into_iter().collect::<Vec<_>>(),
None => Vec::new(),
};
let ty_params = generics
.type_params()
.filter_map(|tp| {
skip_type_params.iter().all(|skip| skip != &tp.ident).then(|| tp.ident.clone())
})
.collect::<Vec<_>>();
if ty_params.is_empty() {
return Ok(())
}
let codec_types =
get_types_to_add_trait_bound(input_ident, data, &ty_params, dumb_trait_bounds)?;
let compact_types = collect_types(&data, utils::is_compact)?
.into_iter()
.filter(|ty| type_contain_idents(ty, &ty_params))
.collect::<Vec<_>>();
let skip_types = if codec_skip_bound.is_some() {
let needs_default_bound = |f: &syn::Field| utils::should_skip(&f.attrs);
collect_types(&data, needs_default_bound)?
.into_iter()
.filter(|ty| type_contain_idents(ty, &ty_params))
.collect::<Vec<_>>()
} else {
Vec::new()
};
if !codec_types.is_empty() || !compact_types.is_empty() || !skip_types.is_empty() {
let where_clause = generics.make_where_clause();
codec_types
.into_iter()
.for_each(|ty| where_clause.predicates.push(parse_quote!(#ty : #codec_bound)));
let has_compact_bound: syn::Path = parse_quote!(#crate_path::HasCompact);
compact_types
.into_iter()
.for_each(|ty| where_clause.predicates.push(parse_quote!(#ty : #has_compact_bound)));
skip_types.into_iter().for_each(|ty| {
let codec_skip_bound = codec_skip_bound.as_ref();
where_clause.predicates.push(parse_quote!(#ty : #codec_skip_bound))
});
}
Ok(())
}
/// Collects the types that `add` bounds with the codec bound.
///
/// When `dumb_trait_bound` is set, the result is simply every identifier in `ty_params`,
/// turned into a type.
///
/// Otherwise the candidates come from the types of fields that are not compact
/// (`utils::is_compact`), have no `encoded_as` type (`utils::get_encoded_as_type`) and are not
/// skipped (`utils::should_skip`). Of these, only types mentioning one of `ty_params` are kept.
/// Each kept field type contributes, in order:
/// 1. its outermost sub type paths in which no type path starts with `input_ident`, if they
///    mention one of `ty_params`;
/// 2. the field type itself.
///
/// Finally, every candidate in which some type path starts with `input_ident` is dropped.
/// Presumably this avoids bounds that refer back to the type being derived (e.g. recursive
/// types); confirm before relying on it.
///
/// # Errors
///
/// Propagates the error returned by `collect_types` (e.g. for union types).
fn get_types_to_add_trait_bound(
	input_ident: &Ident,
	data: &syn::Data,
	ty_params: &[Ident],
	dumb_trait_bound: bool,
) -> Result<Vec<Type>> {
	if dumb_trait_bound {
		return Ok(ty_params.iter().map(|t| parse_quote!( #t )).collect())
	}

	let needs_codec_bound = |f: &syn::Field| {
		!utils::is_compact(f) &&
			utils::get_encoded_as_type(f).is_none() &&
			!utils::should_skip(&f.attrs)
	};

	let res = collect_types(data, needs_codec_bound)?
		.into_iter()
		.filter(|ty| type_contain_idents(ty, ty_params))
		.flat_map(|ty| {
			find_type_paths_not_start_or_contain_ident(&ty, input_ident)
				.into_iter()
				// The collected paths are already owned, so wrap them without cloning.
				.map(Type::Path)
				.filter(|ty| type_contain_idents(ty, ty_params))
				.chain(iter::once(ty))
		})
		.filter(|ty| !type_or_sub_type_path_starts_with_ident(ty, input_ident))
		.collect();
	Ok(res)
}
fn collect_types(data: &syn::Data, type_filter: fn(&syn::Field) -> bool) -> Result<Vec<syn::Type>> {
use syn::*;
let types = match *data {
Data::Struct(ref data) => match &data.fields {
| Fields::Named(FieldsNamed { named: fields, .. }) |
Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) =>
fields.iter().filter(|f| type_filter(f)).map(|f| f.ty.clone()).collect(),
Fields::Unit => Vec::new(),
},
Data::Enum(ref data) => data
.variants
.iter()
.filter(|variant| !utils::should_skip(&variant.attrs))
.flat_map(|variant| match &variant.fields {
| Fields::Named(FieldsNamed { named: fields, .. }) |
Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) =>
fields.iter().filter(|f| type_filter(f)).map(|f| f.ty.clone()).collect(),
Fields::Unit => Vec::new(),
})
.collect(),
Data::Union(ref data) =>
return Err(Error::new(data.union_token.span(), "Union types are not supported.")),
};
Ok(types)
}