diff --git a/beacon_node/beacon_chain/src/attestation_aggregation.rs b/beacon_node/beacon_chain/src/attestation_aggregation.rs new file mode 100644 index 0000000000..4cac5056cb --- /dev/null +++ b/beacon_node/beacon_chain/src/attestation_aggregation.rs @@ -0,0 +1,129 @@ +use std::collections::{HashMap, HashSet}; +use types::{ + AggregateSignature, Attestation, AttestationData, BeaconState, Bitfield, ChainSpec, Signature, +}; + +const PHASE_0_CUSTODY_BIT: bool = false; + +pub struct AttestationAggregator { + store: HashMap, Attestation>, +} + +pub enum ProcessOutcome { + AggregationNotRequired, + Aggregated, + NewAttestationCreated, +} + +pub enum ProcessError { + BadValidatorIndex, + BadSignature, +} + +impl AttestationAggregator { + pub fn new() -> Self { + Self { + store: HashMap::new(), + } + } + + pub fn process_free_attestation( + &mut self, + state: &BeaconState, + attestation_data: &AttestationData, + signature: &Signature, + validator_index: u64, + ) -> Result { + let validator_index = validator_index as usize; + + let signable_message = attestation_data.signable_message(PHASE_0_CUSTODY_BIT); + let validator_pubkey = &state + .validator_registry + .get(validator_index) + .ok_or_else(|| ProcessError::BadValidatorIndex)? + .pubkey; + + if !signature.verify(&signable_message, &validator_pubkey) { + return Err(ProcessError::BadSignature); + } + + if let Some(existing_attestation) = self.store.get(&signable_message) { + if let Some(updated_attestation) = + aggregate_attestation(existing_attestation, signature, validator_index) + { + self.store.insert(signable_message, updated_attestation); + Ok(ProcessOutcome::Aggregated) + } else { + Ok(ProcessOutcome::AggregationNotRequired) + } + } else { + let mut aggregate_signature = AggregateSignature::new(); + aggregate_signature.add(signature); + let mut aggregation_bitfield = Bitfield::new(); + aggregation_bitfield.set(validator_index, true); + let new_attestation = Attestation { + data: attestation_data.clone(), + aggregation_bitfield, + custody_bitfield: Bitfield::new(), + aggregate_signature, + }; + self.store.insert(signable_message, new_attestation); + Ok(ProcessOutcome::NewAttestationCreated) + } + } + + /// Returns all known attestations which are: + /// + /// a) valid for the given state + /// b) not already in `state.latest_attestations`. + pub fn get_attestations_for_state( + &self, + state: &BeaconState, + spec: &ChainSpec, + ) -> Vec { + let mut known_attestation_data: HashSet = HashSet::new(); + + state.latest_attestations.iter().for_each(|attestation| { + known_attestation_data.insert(attestation.data.clone()); + }); + + self.store + .values() + .filter_map(|attestation| { + if state.validate_attestation(attestation, spec).is_ok() + && !known_attestation_data.contains(&attestation.data) + { + Some(attestation.clone()) + } else { + None + } + }) + .collect() + } +} + +fn aggregate_attestation( + existing_attestation: &Attestation, + signature: &Signature, + validator_index: usize, +) -> Option { + let already_signed = existing_attestation + .aggregation_bitfield + .get(validator_index) + .unwrap_or(false); + + if already_signed { + None + } else { + let mut aggregation_bitfield = existing_attestation.aggregation_bitfield.clone(); + aggregation_bitfield.set(validator_index, true); + let mut aggregate_signature = existing_attestation.aggregate_signature.clone(); + aggregate_signature.add(&signature); + + Some(Attestation { + aggregation_bitfield, + aggregate_signature, + ..existing_attestation.clone() + }) + } +} diff --git a/beacon_node/beacon_chain/src/lib.rs b/beacon_node/beacon_chain/src/lib.rs index bdb041b130..c4da9099f1 100644 --- a/beacon_node/beacon_chain/src/lib.rs +++ b/beacon_node/beacon_chain/src/lib.rs @@ -1,3 +1,4 @@ +mod attestation_aggregation; mod attestation_production; mod attestation_targets; mod block_graph; @@ -13,6 +14,7 @@ mod state_transition; use self::attestation_targets::AttestationTargets; use self::block_graph::BlockGraph; +use attestation_aggregation::AttestationAggregator; use db::{ stores::{BeaconBlockStore, BeaconStateStore}, ClientDB, DBError, @@ -73,6 +75,7 @@ pub struct BeaconChain { pub state_store: Arc>, pub slot_clock: U, pub block_graph: BlockGraph, + pub attestation_aggregator: RwLock, canonical_head: RwLock, finalized_head: RwLock, justified_head: RwLock, @@ -124,6 +127,7 @@ where genesis_state.clone(), state_root.clone(), )); + let attestation_aggregator = RwLock::new(AttestationAggregator::new()); let latest_attestation_targets = RwLock::new(AttestationTargets::new()); @@ -132,6 +136,7 @@ where state_store, slot_clock, block_graph, + attestation_aggregator, justified_head, finalized_head, canonical_head, diff --git a/eth2/types/src/attestation_data/mod.rs b/eth2/types/src/attestation_data/mod.rs index f22b4007f1..28504127ee 100644 --- a/eth2/types/src/attestation_data/mod.rs +++ b/eth2/types/src/attestation_data/mod.rs @@ -3,6 +3,7 @@ use crate::test_utils::TestRandom; use rand::RngCore; use serde_derive::Serialize; use ssz::{hash, Decodable, DecodeError, Encodable, SszStream, TreeHash}; +use std::hash::Hash; mod signing; @@ -17,7 +18,7 @@ pub const SSZ_ATTESTION_DATA_LENGTH: usize = { 32 // justified_block_root }; -#[derive(Debug, Clone, PartialEq, Default, Serialize)] +#[derive(Debug, Clone, PartialEq, Default, Serialize, Hash)] pub struct AttestationData { pub slot: u64, pub shard: u64, @@ -29,6 +30,8 @@ pub struct AttestationData { pub justified_block_root: Hash256, } +impl Eq for AttestationData {} + impl AttestationData { pub fn zero() -> Self { Self {