Implement committee cache diffs

This commit is contained in:
Michael Sproul
2022-03-15 17:08:14 +11:00
parent 1a261e1d3b
commit ff649f0b26
8 changed files with 210 additions and 16 deletions

View File

@@ -136,6 +136,13 @@ pub enum Error {
},
#[cfg(feature = "milhouse")]
MilhouseError(milhouse::Error),
CommitteeCacheDiffInvalidEpoch {
prev_current_epoch: Epoch,
current_epoch: Epoch,
},
CommitteeCacheDiffUninitialized {
expected_epoch: Epoch,
},
}
/// Control whether an epoch-indexed field can be indexed at the next epoch or not.
@@ -1488,7 +1495,7 @@ impl<T: EthSpec> BeaconState<T> {
Ok(())
}
fn committee_cache_index(relative_epoch: RelativeEpoch) -> usize {
pub(crate) fn committee_cache_index(relative_epoch: RelativeEpoch) -> usize {
match relative_epoch {
RelativeEpoch::Previous => 0,
RelativeEpoch::Current => 1,

View File

@@ -1,9 +1,11 @@
use crate::{
BeaconBlockHeader, BeaconState, BeaconStateError as Error, BitVector, Checkpoint, Eth1Data,
EthSpec, ExecutionPayloadHeader, Fork, Hash256, ParticipationFlags, PendingAttestation, Slot,
SyncCommittee, Validator,
beacon_state::{CommitteeCache, CACHED_EPOCHS},
BeaconBlockHeader, BeaconState, BeaconStateError as Error, BitVector, Checkpoint, Epoch,
Eth1Data, EthSpec, ExecutionPayloadHeader, Fork, Hash256, ParticipationFlags,
PendingAttestation, Slot, SyncCommittee, Validator,
};
use milhouse::{CloneDiff, Diff, ListDiff, ResetListDiff, VectorDiff};
use safe_arith::SafeArith;
use ssz::{Decode, Encode};
use ssz_derive::{Decode, Encode};
use std::sync::Arc;
@@ -79,6 +81,19 @@ pub struct BeaconStateDiff<T: EthSpec> {
// Execution
latest_execution_payload_header: Maybe<CloneDiff<ExecutionPayloadHeader<T>>>,
// Committee caches
committee_caches: CommitteeCachesDiff,
}
/// Zero to three committee caches which update a `BeaconState`'s stored committee caches.
///
/// For most diffs which are taken relative to the previous epoch boundary state this diff
/// will contain a single committee cache.
#[derive(Debug, PartialEq, Encode, Decode)]
pub struct CommitteeCachesDiff {
current_epoch: Epoch,
caches: Vec<Arc<CommitteeCache>>,
}
fn optional_field_diff<
@@ -108,6 +123,100 @@ fn apply_optional_diff<X, D: Diff<Target = X, Error = milhouse::Error> + Encode
Ok(())
}
fn compute_committee_cache_dist(
current_epoch: Epoch,
prev_current_epoch: Epoch,
) -> Result<usize, Error> {
current_epoch
.safe_sub(prev_current_epoch)
.as_ref()
.map(Epoch::as_usize)
.map_err(|_| Error::CommitteeCacheDiffInvalidEpoch {
prev_current_epoch,
current_epoch,
})
}
/// Check that an array of committee caches is fully initialized with respect to `current_epoch`.
fn check_committee_caches(
caches: &[Arc<CommitteeCache>; CACHED_EPOCHS],
current_epoch: Epoch,
) -> Result<(), Error> {
for (i, cache) in caches.iter().enumerate() {
const CURRENT_EPOCH_OFFSET: u64 = 1;
let expected_epoch = Epoch::new(
current_epoch
.safe_add(i as u64)?
.as_u64()
.saturating_sub(CURRENT_EPOCH_OFFSET),
);
if !cache.is_initialized_at(expected_epoch) {
return Err(Error::CommitteeCacheDiffUninitialized { expected_epoch }).unwrap();
}
}
Ok(())
}
impl Diff for CommitteeCachesDiff {
// Diffs are applied wrt to the current epoch and the `state.committee_caches` array.
type Target = (Epoch, [Arc<CommitteeCache>; CACHED_EPOCHS]);
type Error = Error;
fn compute_diff(orig: &Self::Target, other: &Self::Target) -> Result<Self, Error> {
let (prev_current_epoch, prev_caches) = orig;
let (current_epoch, caches) = other;
// Sanity check the inputs to ensure we can compute a sensible diff.
check_committee_caches(&prev_caches, *prev_current_epoch)?;
check_committee_caches(&caches, *current_epoch)?;
let dist = compute_committee_cache_dist(*current_epoch, *prev_current_epoch)?;
// The distance determines the number of caches that are unique to the new cache array.
// If the epoch distance is 0 then there are no new caches, if it's 1 then only the last
// element of the cache is new, and so on up to the maximum of `CACHED_EPOCHS` at which
// point the entire array is new.
let new_caches = (CACHED_EPOCHS.saturating_sub(dist)..CACHED_EPOCHS)
.map(|i| {
caches
.get(i)
.cloned()
.ok_or(Error::CommitteeCachesOutOfBounds(i))
})
.collect::<Result<Vec<_>, _>>()?;
assert_eq!(new_caches.len(), std::cmp::min(CACHED_EPOCHS, dist));
Ok(CommitteeCachesDiff {
current_epoch: *current_epoch,
caches: new_caches,
})
}
fn apply_diff(self, target: &mut Self::Target) -> Result<(), Error> {
let (prev_current_epoch, caches) = target;
let dist = compute_committee_cache_dist(self.current_epoch, *prev_current_epoch)?;
let capped_dist = std::cmp::min(CACHED_EPOCHS, dist);
// Rotate caches for the epoch advance. This moves the caches that are still relevant into
// position. The irrelevant caches will be overwritten in the next step.
caches.rotate_left(capped_dist);
let base = CACHED_EPOCHS.saturating_sub(capped_dist);
for (i, cache) in self.caches.into_iter().enumerate() {
let cache_index = base.safe_add(i)?;
*caches
.get_mut(cache_index)
.ok_or(Error::CommitteeCachesOutOfBounds(cache_index))? = cache;
}
*prev_current_epoch = self.current_epoch;
// Sanity check the diff application.
check_committee_caches(caches, self.current_epoch)
}
}
impl<T: EthSpec> Diff for BeaconStateDiff<T> {
type Target = BeaconState<T>;
type Error = Error;
@@ -115,6 +224,18 @@ impl<T: EthSpec> Diff for BeaconStateDiff<T> {
// FIXME(sproul): proc macro
fn compute_diff(orig: &Self::Target, other: &Self::Target) -> Result<Self, Error> {
// FIXME(sproul): consider cross-variant diffs
// Compute committee caches diff.
let prev_current_epoch = orig.current_epoch();
let current_epoch = other.current_epoch();
let orig_committee_caches = orig.committee_caches().clone();
let new_committee_caches = other.committee_caches().clone();
let committee_caches = CommitteeCachesDiff::compute_diff(
&(prev_current_epoch, orig_committee_caches),
&(current_epoch, new_committee_caches),
)?;
Ok(BeaconStateDiff {
genesis_time: <_>::compute_diff(&orig.genesis_time(), &other.genesis_time())?,
genesis_validators_root: <_>::compute_diff(
@@ -192,10 +313,13 @@ impl<T: EthSpec> Diff for BeaconStateDiff<T> {
other,
BeaconState::latest_execution_payload_header,
)?,
committee_caches,
})
}
fn apply_diff(self, target: &mut BeaconState<T>) -> Result<(), Error> {
let prev_current_epoch = target.current_epoch();
self.genesis_time.apply_diff(target.genesis_time_mut())?;
self.genesis_validators_root
.apply_diff(target.genesis_validators_root_mut())?;
@@ -250,6 +374,12 @@ impl<T: EthSpec> Diff for BeaconStateDiff<T> {
self.latest_execution_payload_header,
target.latest_execution_payload_header_mut(),
)?;
// Apply committee caches diff.
let mut committee_caches = (prev_current_epoch, target.committee_caches().clone());
self.committee_caches.apply_diff(&mut committee_caches)?;
*target.committee_caches_mut() = committee_caches.1;
Ok(())
}
}