diff --git a/beacon_node/eth2-libp2p/src/rpc/methods.rs b/beacon_node/eth2-libp2p/src/rpc/methods.rs index a61fd8d26e..0123e5ed64 100644 --- a/beacon_node/eth2-libp2p/src/rpc/methods.rs +++ b/beacon_node/eth2-libp2p/src/rpc/methods.rs @@ -1,4 +1,4 @@ -use ssz::{Decodable, DecodeError, Encodable}; +use ssz::{impl_decode_via_from, impl_encode_via_from, Decodable, DecodeError, Encodable}; /// Available RPC methods types and ids. use ssz_derive::{Decode, Encode}; use types::{BeaconBlockBody, BeaconBlockHeader, Epoch, Hash256, Slot}; @@ -149,19 +149,8 @@ impl Into for GoodbyeReason { } } -impl Encodable for GoodbyeReason { - fn ssz_append(&self, s: &mut SszStream) { - let id: u64 = (*self).clone().into(); - id.ssz_append(s); - } -} - -impl Decodable for GoodbyeReason { - fn ssz_decode(bytes: &[u8], index: usize) -> Result<(Self, usize), DecodeError> { - let (id, index) = u64::ssz_decode(bytes, index)?; - Ok((Self::from(id), index)) - } -} +impl_encode_via_from!(GoodbyeReason, u64); +impl_decode_via_from!(GoodbyeReason, u64); /// Request a number of beacon block roots from a peer. #[derive(Encode, Decode, Clone, Debug, PartialEq)] diff --git a/beacon_node/eth2-libp2p/src/rpc/protocol.rs b/beacon_node/eth2-libp2p/src/rpc/protocol.rs index b5a695beab..02d774d9e1 100644 --- a/beacon_node/eth2-libp2p/src/rpc/protocol.rs +++ b/beacon_node/eth2-libp2p/src/rpc/protocol.rs @@ -1,6 +1,6 @@ use super::methods::*; use libp2p::core::{upgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo}; -use ssz::{ssz_encode, Decodable, DecodeError as SSZDecodeError, Encodable}; +use ssz::{impl_decode_via_from, impl_encode_via_from, ssz_encode, Decodable, Encodable}; use std::hash::{Hash, Hasher}; use std::io; use std::iter; @@ -72,18 +72,8 @@ impl Into for RequestId { } } -impl Encodable for RequestId { - fn ssz_append(&self, s: &mut SszStream) { - self.0.ssz_append(s); - } -} - -impl Decodable for RequestId { - fn ssz_decode(bytes: &[u8], index: usize) -> Result<(Self, usize), SSZDecodeError> { - let (id, index) = u64::ssz_decode(bytes, index)?; - Ok((Self::from(id), index)) - } -} +impl_encode_via_from!(RequestId, u64); +impl_decode_via_from!(RequestId, u64); /// The RPC types which are sent/received in this protocol. #[derive(Debug, Clone)] diff --git a/eth2/utils/ssz/src/lib.rs b/eth2/utils/ssz/src/lib.rs index 1d32e85013..e6e061e513 100644 --- a/eth2/utils/ssz/src/lib.rs +++ b/eth2/utils/ssz/src/lib.rs @@ -18,3 +18,90 @@ where { val.as_ssz_bytes() } + +#[macro_export] +macro_rules! impl_encode_via_from { + ($impl_type: ty, $from_type: ty) => { + impl Encodable for $impl_type { + fn is_ssz_fixed_len() -> bool { + <$from_type as Encodable>::is_ssz_fixed_len() + } + + fn ssz_fixed_len() -> usize { + <$from_type as Encodable>::ssz_fixed_len() + } + + fn ssz_append(&self, buf: &mut Vec) { + let conv: $from_type = self.clone().into(); + + conv.ssz_append(buf) + } + } + }; +} + +#[macro_export] +macro_rules! impl_decode_via_from { + ($impl_type: ty, $from_type: tt) => { + impl Decodable for $impl_type { + fn is_ssz_fixed_len() -> bool { + <$from_type as Decodable>::is_ssz_fixed_len() + } + + fn ssz_fixed_len() -> usize { + <$from_type as Decodable>::ssz_fixed_len() + } + + fn from_ssz_bytes(bytes: &[u8]) -> Result { + $from_type::from_ssz_bytes(bytes).and_then(|dec| Ok(dec.into())) + } + } + }; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate as ssz; + + #[derive(PartialEq, Debug, Clone, Copy)] + struct Wrapper(u64); + + impl From for Wrapper { + fn from(x: u64) -> Wrapper { + Wrapper(x) + } + } + + impl From for u64 { + fn from(x: Wrapper) -> u64 { + x.0 + } + } + + impl_encode_via_from!(Wrapper, u64); + impl_decode_via_from!(Wrapper, u64); + + #[test] + fn impl_encode_via_from() { + let check_encode = |a: u64, b: Wrapper| assert_eq!(a.as_ssz_bytes(), b.as_ssz_bytes()); + + check_encode(0, Wrapper(0)); + check_encode(1, Wrapper(1)); + check_encode(42, Wrapper(42)); + } + + #[test] + fn impl_decode_via_from() { + let check_decode = |bytes: Vec| { + let a = u64::from_ssz_bytes(&bytes).unwrap(); + let b = Wrapper::from_ssz_bytes(&bytes).unwrap(); + + assert_eq!(a, b.into()) + }; + + check_decode(vec![0, 0, 0, 0, 0, 0, 0, 0]); + check_decode(vec![1, 0, 0, 0, 0, 0, 0, 0]); + check_decode(vec![1, 0, 0, 0, 2, 0, 0, 0]); + } +}