// https://github.com/rust-lang/rust/issues/13101

use proc_macro2;

use ast;
use attr;
use matcher;
use syn;
use utils;

/// Generate an `Eq` implementation for `input`.
///
/// The emitted `impl` block has an empty body. Its generics and `where`
/// clause come from `utils::build_impl_generics`, which is handed the `Eq`
/// trait path, the `needs_eq_bound` predicate, and the field-level and
/// input-level `eq_bound()` accessors.
pub fn derive_eq(input: &ast::Input) -> proc_macro2::TokenStream {
    let trait_path = eq_trait_path();

    let generics = utils::build_impl_generics(
        input,
        &trait_path,
        needs_eq_bound,
        |field| field.eq_bound(),
        |input| input.eq_bound(),
    );
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

    let type_name = &input.ident;

    quote! {
        #[allow(unused_qualifications)]
        impl #impl_generics #trait_path for #type_name #ty_generics #where_clause {}
    }
}

/// Generate a `PartialEq` implementation for `input`.
///
/// # Errors
///
/// Returns an error message when `input` is an enumeration and
/// `input.attrs.partial_eq_on_enum()` is `false`.
///
/// # Generated code
///
/// The emitted `eq` method is a `match *self` whose arms each contain a
/// nested `match *other`. Where the outer and inner arm names are equal, the
/// arm evaluates to `true` followed by one `&& ...` term per compared field;
/// every other combination evaluates to `false`.
pub fn derive_partial_eq(input: &ast::Input) -> Result<proc_macro2::TokenStream, String> {
    // Enumerations are rejected unless the attributes explicitly allow them.
    match input.body {
        ast::Body::Enum(_) if !input.attrs.partial_eq_on_enum() => {
            return Err(
                "can't use `#[derivative(PartialEq)]` on an enumeration without \
                 `feature_allow_slow_enum`; see the documentation for more details"
                    .into(),
            );
        }
        _ => {}
    }

    let self_arms = matcher::Matcher::new(matcher::BindingStyle::Ref)
        .with_name("__self".into())
        .build_arms(input, |_, outer_arm_name, _, _, outer_bis| {
            let other_arms = matcher::Matcher::new(matcher::BindingStyle::Ref)
                .with_name("__other".into())
                .build_arms(input, |_, inner_arm_name, _, _, inner_bis| {
                    // Arms with different names can never compare equal.
                    if outer_arm_name != inner_arm_name {
                        return quote!(false);
                    }

                    // One `&& ...` term per field that takes part in the
                    // comparison; ignored fields contribute no tokens.
                    let terms = outer_bis.iter().zip(inner_bis).filter_map(|(o, i)| {
                        let field_attrs = &o.field.attrs;
                        if field_attrs.ignore_partial_eq() {
                            return None;
                        }

                        let lhs = &o.ident;
                        let rhs = &i.ident;

                        Some(match field_attrs.partial_eq_compare_with() {
                            Some(compare_fn) => quote!(&& #compare_fn(#lhs, #rhs)),
                            None => quote!(&& #lhs == #rhs),
                        })
                    });

                    quote!(true #(#terms)*)
                });

            quote! {
                match *other {
                    #other_arms
                }
            }
        });

    let trait_path = partial_eq_trait_path();
    let generics = utils::build_impl_generics(
        input,
        &trait_path,
        needs_partial_eq_bound,
        |field| field.partial_eq_bound(),
        |input| input.partial_eq_bound(),
    );
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

    let type_name = &input.ident;

    Ok(quote! {
        #[allow(unused_qualifications)]
        impl #impl_generics #trait_path for #type_name #ty_generics #where_clause {
            fn eq(&self, other: &Self) -> bool {
                match *self {
                    #self_arms
                }
            }
        }
    })
}

/// Predicate passed to `utils::build_impl_generics` when deriving `PartialEq`.
///
/// Returns `true` only for fields that are not ignored for `PartialEq` and
/// for which `partial_eq_bound()` yields nothing.
fn needs_partial_eq_bound(attrs: &attr::Field) -> bool {
    if attrs.ignore_partial_eq() {
        return false;
    }

    attrs.partial_eq_bound().is_none()
}

/// Predicate passed to `utils::build_impl_generics` when deriving `Eq`.
///
/// Returns `true` only for fields that are not ignored for `PartialEq` (the
/// same ignore flag is consulted for `Eq`) and for which `eq_bound()` yields
/// nothing.
fn needs_eq_bound(attrs: &attr::Field) -> bool {
    if attrs.ignore_partial_eq() {
        return false;
    }

    attrs.eq_bound().is_none()
}

/// Return the path of the `Eq` trait: `::core::cmp::Eq` when the `use_core`
/// feature is enabled, `::std::cmp::Eq` otherwise.
fn eq_trait_path() -> syn::Path {
    if !cfg!(feature = "use_core") {
        return parse_quote!(::std::cmp::Eq);
    }

    parse_quote!(::core::cmp::Eq)
}

/// Return the path of the `PartialEq` trait: `::core::cmp::PartialEq` when
/// the `use_core` feature is enabled, `::std::cmp::PartialEq` otherwise.
fn partial_eq_trait_path() -> syn::Path {
    if !cfg!(feature = "use_core") {
        return parse_quote!(::std::cmp::PartialEq);
    }

    parse_quote!(::core::cmp::PartialEq)
}