diff --git a/consensus/ssz/tests/tests.rs b/consensus/ssz/tests/tests.rs index 475e1b8f67..72b71fcbf9 100644 --- a/consensus/ssz/tests/tests.rs +++ b/consensus/ssz/tests/tests.rs @@ -429,6 +429,13 @@ mod derive_macro { B(VariableB), } + #[derive(PartialEq, Debug, Encode)] + #[ssz(transparent)] + enum TwoVariableTransDirectTag { + A(VariableA), + B(VariableB), + } + #[derive(PartialEq, Debug, Encode)] struct TwoVariableTransStruct { a: TwoVariableTrans, @@ -441,6 +448,13 @@ mod derive_macro { B(VariableB), } + #[derive(PartialEq, Debug, Encode, Decode)] + #[ssz(union)] + enum TwoVariableUnionDirectTag { + A(VariableA), + B(VariableB), + } + #[derive(PartialEq, Debug, Encode, Decode)] struct TwoVariableUnionStruct { a: TwoVariableUnion, diff --git a/consensus/ssz_derive/src/lib.rs b/consensus/ssz_derive/src/lib.rs index 60dea1e9fb..82f15e4cfb 100644 --- a/consensus/ssz_derive/src/lib.rs +++ b/consensus/ssz_derive/src/lib.rs @@ -13,6 +13,12 @@ use syn::{parse_macro_input, DataEnum, DataStruct, DeriveInput, Ident}; /// extensions). const MAX_UNION_SELECTOR: u8 = 127; +const ENUM_TRANSPARENT: &str = "transparent"; +const ENUM_UNION: &str = "union"; +const ENUM_VARIANTS: &[&str] = &[ENUM_TRANSPARENT, ENUM_UNION]; +const NO_ENUM_BEHAVIOUR_ERROR: &str = "enums require a \"transparent\" or \"union\" attribute, \ + e.g., #[ssz(transparent)]"; + #[derive(Debug, FromDeriveInput)] #[darling(attributes(ssz))] struct StructOpts { @@ -20,6 +26,8 @@ struct StructOpts { enum_behaviour: Option, #[darling(default)] transparent: bool, + #[darling(default)] + union: bool, } /// Field-level configuration. @@ -33,35 +41,62 @@ struct FieldOpts { skip_deserializing: bool, } -const ENUM_TRANSPARENT: &str = "transparent"; -const ENUM_UNION: &str = "union"; -const ENUM_VARIANTS: &[&str] = &[ENUM_TRANSPARENT, ENUM_UNION]; -const NO_ENUM_BEHAVIOUR_ERROR: &str = "enums require an \"enum_behaviour\" attribute, \ - e.g., #[ssz(enum_behaviour = \"transparent\")]"; - -enum EnumBehaviour { - Transparent, - Union, +struct Config { + struct_behaviour: StructBehaviour, + enum_behaviour: EnumBehaviour, } -impl EnumBehaviour { - pub fn new(s: &StructOpts) -> Option { - s.enum_behaviour - .as_ref() - .map(|behaviour_string| match behaviour_string.as_ref() { +impl Config { + fn read(item: &DeriveInput) -> Self { + let opts = StructOpts::from_derive_input(item).unwrap(); + + let enum_behaviour = match (opts.transparent, opts.union, &opts.enum_behaviour) { + (false, false, None) => EnumBehaviour::Unspecified, + (true, false, None) => EnumBehaviour::Transparent, + (false, true, None) => EnumBehaviour::Union, + (false, false, Some(behaviour_string)) => match behaviour_string.as_ref() { ENUM_TRANSPARENT => EnumBehaviour::Transparent, - ENUM_UNION if s.transparent => { - panic!("cannot use \"transparent\" and \"enum_behaviour(union)\" together") - } ENUM_UNION => EnumBehaviour::Union, other => panic!( "{} is an invalid enum_behaviour, use either {:?}", other, ENUM_VARIANTS ), - }) + }, + (true, true, _) => panic!("cannot provide both \"transparent\" and \"union\""), + (_, _, Some(_)) => { + panic!("\"enum_behaviour\" cannot be used with \"transparent\" or \"union\"") + } + }; + + // Don't allow `enum_behaviour` for structs. + if matches!(item.data, syn::Data::Struct(_)) && opts.enum_behaviour.is_some() { + panic!("cannot provide \"enum_behaviour\" for a struct") + } + + let struct_behaviour = if opts.transparent { + StructBehaviour::Transparent + } else { + StructBehaviour::Container + }; + + Self { + struct_behaviour, + enum_behaviour, + } } } +enum StructBehaviour { + Transparent, + Container, +} + +enum EnumBehaviour { + Transparent, + Union, + Unspecified, +} + fn parse_ssz_fields(struct_data: &syn::DataStruct) -> Vec<(&syn::Type, &syn::Ident, FieldOpts)> { struct_data .fields @@ -100,24 +135,17 @@ fn parse_ssz_fields(struct_data: &syn::DataStruct) -> Vec<(&syn::Type, &syn::Ide #[proc_macro_derive(Encode, attributes(ssz))] pub fn ssz_encode_derive(input: TokenStream) -> TokenStream { let item = parse_macro_input!(input as DeriveInput); - let opts = StructOpts::from_derive_input(&item).unwrap(); - let enum_opt = EnumBehaviour::new(&opts); + let config = Config::read(&item); match &item.data { - syn::Data::Struct(s) => { - if enum_opt.is_some() { - panic!("enum_behaviour is invalid for structs"); - } - - if opts.transparent { - ssz_encode_derive_struct_transparent(&item, s) - } else { - ssz_encode_derive_struct(&item, s) - } - } - syn::Data::Enum(s) => match enum_opt.expect(NO_ENUM_BEHAVIOUR_ERROR) { + syn::Data::Struct(s) => match config.struct_behaviour { + StructBehaviour::Transparent => ssz_encode_derive_struct_transparent(&item, s), + StructBehaviour::Container => ssz_encode_derive_struct(&item, s), + }, + syn::Data::Enum(s) => match config.enum_behaviour { EnumBehaviour::Transparent => ssz_encode_derive_enum_transparent(&item, s), EnumBehaviour::Union => ssz_encode_derive_enum_union(&item, s), + EnumBehaviour::Unspecified => panic!("{}", NO_ENUM_BEHAVIOUR_ERROR), }, _ => panic!("ssz_derive only supports structs and enums"), } @@ -259,9 +287,8 @@ fn ssz_encode_derive_struct_transparent( let (ty, ident, _field_opts) = ssz_fields .iter() - .filter(|(_, _, field_opts)| !field_opts.skip_deserializing) - .next() - .unwrap(); + .find(|(_, _, field_opts)| !field_opts.skip_deserializing) + .expect("\"transparent\" struct must have at least one non-skipped field"); let output = quote! { impl #impl_generics ssz::Encode for #name #ty_generics #where_clause { @@ -433,26 +460,20 @@ fn ssz_encode_derive_enum_union(derive_input: &DeriveInput, enum_data: &DataEnum #[proc_macro_derive(Decode, attributes(ssz))] pub fn ssz_decode_derive(input: TokenStream) -> TokenStream { let item = parse_macro_input!(input as DeriveInput); - let opts = StructOpts::from_derive_input(&item).unwrap(); - let enum_opt = EnumBehaviour::new(&opts); + let config = Config::read(&item); match &item.data { - syn::Data::Struct(s) => { - if enum_opt.is_some() { - panic!("enum_behaviour is invalid for structs"); - } - if opts.transparent { - ssz_decode_derive_struct_transparent(&item, s) - } else { - ssz_decode_derive_struct(&item, s) - } - } - syn::Data::Enum(s) => match enum_opt.expect(NO_ENUM_BEHAVIOUR_ERROR) { + syn::Data::Struct(s) => match config.struct_behaviour { + StructBehaviour::Transparent => ssz_decode_derive_struct_transparent(&item, s), + StructBehaviour::Container => ssz_decode_derive_struct(&item, s), + }, + syn::Data::Enum(s) => match config.enum_behaviour { + EnumBehaviour::Union => ssz_decode_derive_enum_union(&item, s), EnumBehaviour::Transparent => panic!( "Decode cannot be derived for enum_behaviour \"{}\", only \"{}\" is valid.", ENUM_TRANSPARENT, ENUM_UNION ), - EnumBehaviour::Union => ssz_decode_derive_enum_union(&item, s), + EnumBehaviour::Unspecified => panic!("{}", NO_ENUM_BEHAVIOUR_ERROR), }, _ => panic!("ssz_derive only supports structs and enums"), }