From 9797ff88a6ee909db008a7764d1af7b134e537b5 Mon Sep 17 00:00:00 2001 From: Paul Hauner Date: Sun, 16 Oct 2022 20:44:21 -0500 Subject: [PATCH] Add transparent support --- Cargo.lock | 2 +- consensus/ssz/Cargo.toml | 2 +- consensus/ssz/tests/tests.rs | 43 ++++++++ consensus/ssz_derive/Cargo.toml | 4 +- consensus/ssz_derive/src/lib.rs | 171 +++++++++++++++++++++++++++++--- 5 files changed, 206 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d0e3622e77..4200bc27b0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1829,7 +1829,7 @@ dependencies = [ [[package]] name = "eth2_ssz_derive" -version = "0.3.0" +version = "0.3.1" dependencies = [ "darling", "proc-macro2", diff --git a/consensus/ssz/Cargo.toml b/consensus/ssz/Cargo.toml index a153c2efc1..e521853c21 100644 --- a/consensus/ssz/Cargo.toml +++ b/consensus/ssz/Cargo.toml @@ -10,7 +10,7 @@ license = "Apache-2.0" name = "ssz" [dev-dependencies] -eth2_ssz_derive = "0.3.0" +eth2_ssz_derive = "0.3.1" [dependencies] ethereum-types = "0.12.1" diff --git a/consensus/ssz/tests/tests.rs b/consensus/ssz/tests/tests.rs index e41fc15dd4..475e1b8f67 100644 --- a/consensus/ssz/tests/tests.rs +++ b/consensus/ssz/tests/tests.rs @@ -375,6 +375,7 @@ mod derive_macro { use ssz::{Decode, Encode}; use ssz_derive::{Decode, Encode}; use std::fmt::Debug; + use std::marker::PhantomData; fn assert_encode(item: &T, bytes: &[u8]) { assert_eq!(item.as_ssz_bytes(), bytes); @@ -511,4 +512,46 @@ mod derive_macro { assert_encode_decode(&TwoVecUnion::A(vec![0, 1]), &[0, 0, 1]); assert_encode_decode(&TwoVecUnion::B(vec![0, 1]), &[1, 0, 1]); } + + #[derive(PartialEq, Debug, Encode, Decode)] + #[ssz(transparent)] + struct TransparentStruct { + inner: Vec, + } + + impl TransparentStruct { + fn new(inner: u8) -> Self { + Self { inner: vec![inner] } + } + } + + #[test] + fn transparent_struct() { + assert_encode_decode(&TransparentStruct::new(42), &vec![42_u8].as_ssz_bytes()); + } + + #[derive(PartialEq, Debug, Encode, Decode)] + #[ssz(transparent)] + struct TransparentStructSkippedField { + inner: Vec, + #[ssz(skip_serializing, skip_deserializing)] + skipped: PhantomData, + } + + impl TransparentStructSkippedField { + fn new(inner: u8) -> Self { + Self { + inner: vec![inner], + skipped: PhantomData, + } + } + } + + #[test] + fn transparent_struct_skipped_field() { + assert_encode_decode( + &TransparentStructSkippedField::new(42), + &vec![42_u8].as_ssz_bytes(), + ); + } } diff --git a/consensus/ssz_derive/Cargo.toml b/consensus/ssz_derive/Cargo.toml index cac617d391..57d65e5573 100644 --- a/consensus/ssz_derive/Cargo.toml +++ b/consensus/ssz_derive/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "eth2_ssz_derive" -version = "0.3.0" +version = "0.3.1" authors = ["Paul Hauner "] edition = "2021" description = "Procedural derive macros to accompany the eth2_ssz crate." @@ -15,3 +15,5 @@ syn = "1.0.42" proc-macro2 = "1.0.23" quote = "1.0.7" darling = "0.13.0" + +[dev-dependencies] diff --git a/consensus/ssz_derive/src/lib.rs b/consensus/ssz_derive/src/lib.rs index a5a5a0dddf..60dea1e9fb 100644 --- a/consensus/ssz_derive/src/lib.rs +++ b/consensus/ssz_derive/src/lib.rs @@ -18,6 +18,8 @@ const MAX_UNION_SELECTOR: u8 = 127; struct StructOpts { #[darling(default)] enum_behaviour: Option, + #[darling(default)] + transparent: bool, } /// Field-level configuration. @@ -43,15 +45,20 @@ enum EnumBehaviour { } impl EnumBehaviour { - pub fn new(s: Option) -> Option { - s.map(|s| match s.as_ref() { - ENUM_TRANSPARENT => EnumBehaviour::Transparent, - ENUM_UNION => EnumBehaviour::Union, - other => panic!( - "{} is an invalid enum_behaviour, use either {:?}", - other, ENUM_VARIANTS - ), - }) + pub fn new(s: &StructOpts) -> Option { + s.enum_behaviour + .as_ref() + .map(|behaviour_string| match behaviour_string.as_ref() { + ENUM_TRANSPARENT => EnumBehaviour::Transparent, + ENUM_UNION if s.transparent => { + panic!("cannot use \"transparent\" and \"enum_behaviour(union)\" together") + } + ENUM_UNION => EnumBehaviour::Union, + other => panic!( + "{} is an invalid enum_behaviour, use either {:?}", + other, ENUM_VARIANTS + ), + }) } } @@ -94,14 +101,19 @@ fn parse_ssz_fields(struct_data: &syn::DataStruct) -> Vec<(&syn::Type, &syn::Ide pub fn ssz_encode_derive(input: TokenStream) -> TokenStream { let item = parse_macro_input!(input as DeriveInput); let opts = StructOpts::from_derive_input(&item).unwrap(); - let enum_opt = EnumBehaviour::new(opts.enum_behaviour); + let enum_opt = EnumBehaviour::new(&opts); match &item.data { syn::Data::Struct(s) => { if enum_opt.is_some() { panic!("enum_behaviour is invalid for structs"); } - ssz_encode_derive_struct(&item, s) + + if opts.transparent { + ssz_encode_derive_struct_transparent(&item, s) + } else { + ssz_encode_derive_struct(&item, s) + } } syn::Data::Enum(s) => match enum_opt.expect(NO_ENUM_BEHAVIOUR_ERROR) { EnumBehaviour::Transparent => ssz_encode_derive_enum_transparent(&item, s), @@ -219,6 +231,60 @@ fn ssz_encode_derive_struct(derive_input: &DeriveInput, struct_data: &DataStruct output.into() } +/// Derive `ssz::Encode` "transparently" for a struct which has exactly one non-skipped field. +/// +/// The single field is encoded directly, making the outermost `struct` transparent. +/// +/// ## Field attributes +/// +/// - `#[ssz(skip_serializing)]`: the field will not be serialized. +fn ssz_encode_derive_struct_transparent( + derive_input: &DeriveInput, + struct_data: &DataStruct, +) -> TokenStream { + let name = &derive_input.ident; + let (impl_generics, ty_generics, where_clause) = &derive_input.generics.split_for_impl(); + let ssz_fields = parse_ssz_fields(struct_data); + let num_fields = ssz_fields + .iter() + .filter(|(_, _, field_opts)| !field_opts.skip_deserializing) + .count(); + + if num_fields != 1 { + panic!( + "A \"transparent\" struct must have exactly one non-skipped field ({} fields found)", + num_fields + ); + } + + let (ty, ident, _field_opts) = ssz_fields + .iter() + .filter(|(_, _, field_opts)| !field_opts.skip_deserializing) + .next() + .unwrap(); + + let output = quote! { + impl #impl_generics ssz::Encode for #name #ty_generics #where_clause { + fn is_ssz_fixed_len() -> bool { + <#ty as ssz::Encode>::is_ssz_fixed_len() + } + + fn ssz_fixed_len() -> usize { + <#ty as ssz::Encode>::ssz_fixed_len() + } + + fn ssz_bytes_len(&self) -> usize { + self.#ident.ssz_bytes_len() + } + + fn ssz_append(&self, buf: &mut Vec) { + self.#ident.ssz_append(buf) + } + } + }; + output.into() +} + /// Derive `ssz::Encode` for an enum in the "transparent" method. /// /// The "transparent" method is distinct from the "union" method specified in the SSZ specification. @@ -368,14 +434,18 @@ fn ssz_encode_derive_enum_union(derive_input: &DeriveInput, enum_data: &DataEnum pub fn ssz_decode_derive(input: TokenStream) -> TokenStream { let item = parse_macro_input!(input as DeriveInput); let opts = StructOpts::from_derive_input(&item).unwrap(); - let enum_opt = EnumBehaviour::new(opts.enum_behaviour); + let enum_opt = EnumBehaviour::new(&opts); match &item.data { syn::Data::Struct(s) => { if enum_opt.is_some() { panic!("enum_behaviour is invalid for structs"); } - ssz_decode_derive_struct(&item, s) + if opts.transparent { + ssz_decode_derive_struct_transparent(&item, s) + } else { + ssz_decode_derive_struct(&item, s) + } } syn::Data::Enum(s) => match enum_opt.expect(NO_ENUM_BEHAVIOUR_ERROR) { EnumBehaviour::Transparent => panic!( @@ -545,6 +615,81 @@ fn ssz_decode_derive_struct(item: &DeriveInput, struct_data: &DataStruct) -> Tok output.into() } +/// Implements `ssz::Decode` "transparently" for a `struct` with exactly one non-skipped field. +/// +/// The bytes will be decoded as if they are the inner field, without the outmost struct. The +/// outermost struct will then be applied artificially. +/// +/// ## Field attributes +/// +/// - `#[ssz(skip_deserializing)]`: during de-serialization the field will be instantiated from a +/// `Default` implementation. The decoder will assume that the field was not serialized at all +/// (e.g., if it has been serialized, an error will be raised instead of `Default` overriding it). +fn ssz_decode_derive_struct_transparent( + item: &DeriveInput, + struct_data: &DataStruct, +) -> TokenStream { + let name = &item.ident; + let (impl_generics, ty_generics, where_clause) = &item.generics.split_for_impl(); + let ssz_fields = parse_ssz_fields(struct_data); + let num_fields = ssz_fields + .iter() + .filter(|(_, _, field_opts)| !field_opts.skip_deserializing) + .count(); + + if num_fields != 1 { + panic!( + "A \"transparent\" struct must have exactly one non-skipped field ({} fields found)", + num_fields + ); + } + + let mut field_names = vec![]; + let mut fields = vec![]; + let mut wrapped_type = None; + + for (ty, ident, field_opts) in ssz_fields { + field_names.push(quote! { + #ident + }); + + if field_opts.skip_deserializing { + fields.push(quote! { + #ident: <_>::default(), + }); + } else { + fields.push(quote! { + #ident: <_>::from_ssz_bytes(bytes)?, + }); + wrapped_type = Some(ty); + } + } + + let ty = wrapped_type.unwrap(); + + let output = quote! { + impl #impl_generics ssz::Decode for #name #ty_generics #where_clause { + fn is_ssz_fixed_len() -> bool { + <#ty as ssz::Decode>::is_ssz_fixed_len() + } + + fn ssz_fixed_len() -> usize { + <#ty as ssz::Decode>::ssz_fixed_len() + } + + fn from_ssz_bytes(bytes: &[u8]) -> std::result::Result { + Ok(Self { + #( + #fields + )* + + }) + } + } + }; + output.into() +} + /// Derive `ssz::Decode` for an `enum` following the "union" SSZ spec. fn ssz_decode_derive_enum_union(derive_input: &DeriveInput, enum_data: &DataEnum) -> TokenStream { let name = &derive_input.ident;