diff --git a/eth2/utils/ssz/src/cached_tree_hash.rs b/eth2/utils/ssz/src/cached_tree_hash.rs index 6535e5cdaa..caafaa2cfa 100644 --- a/eth2/utils/ssz/src/cached_tree_hash.rs +++ b/eth2/utils/ssz/src/cached_tree_hash.rs @@ -8,15 +8,17 @@ const HASHSIZE: usize = 32; const MERKLE_HASH_CHUNCK: usize = 2 * BYTES_PER_CHUNK; pub trait CachedTreeHash { + type Item: CachedTreeHash; + fn build_cache_bytes(&self) -> Vec; + /// Return the number of bytes when this element is encoded as raw SSZ _without_ length + /// prefixes. fn num_bytes(&self) -> usize; - fn max_num_leaves(&self) -> usize; - fn cached_hash_tree_root( &self, - other: &Self, + other: &Self::Item, cache: &mut TreeHashCache, chunk: usize, ) -> Option; @@ -45,6 +47,18 @@ impl TreeHashCache { }) } + pub fn maybe_update_chunk(&mut self, chunk: usize, to: &[u8]) -> Option<()> { + let start = chunk * BYTES_PER_CHUNK; + let end = start + BYTES_PER_CHUNK; + + if !self.chunk_equals(chunk, to)? { + self.cache.get_mut(start..end)?.copy_from_slice(to); + self.chunk_modified[chunk] = true; + } + + Some(()) + } + pub fn modify_chunk(&mut self, chunk: usize, to: &[u8]) -> Option<()> { let start = chunk * BYTES_PER_CHUNK; let end = start + BYTES_PER_CHUNK; @@ -56,6 +70,13 @@ impl TreeHashCache { Some(()) } + pub fn chunk_equals(&mut self, chunk: usize, other: &[u8]) -> Option { + let start = chunk * BYTES_PER_CHUNK; + let end = start + BYTES_PER_CHUNK; + + Some(self.cache.get(start..end)? == other) + } + pub fn changed(&self, chunk: usize) -> Option { self.chunk_modified.get(chunk).cloned() } @@ -119,7 +140,7 @@ pub fn merkleize(values: Vec) -> Vec { } pub fn sanitise_bytes(mut bytes: Vec) -> Vec { - let present_leaves = num_leaves(bytes.len()); + let present_leaves = num_unsanitized_leaves(bytes.len()); let required_leaves = present_leaves.next_power_of_two(); if (present_leaves != required_leaves) | last_leaf_needs_padding(bytes.len()) { @@ -133,8 +154,15 @@ fn last_leaf_needs_padding(num_bytes: usize) -> bool { num_bytes % HASHSIZE != 0 } -fn num_leaves(num_bytes: usize) -> usize { - num_bytes / HASHSIZE +/// Rounds up +fn num_unsanitized_leaves(num_bytes: usize) -> usize { + (num_bytes + HASHSIZE - 1) / HASHSIZE +} + +/// Rounds up +fn num_sanitized_leaves(num_bytes: usize) -> usize { + let leaves = (num_bytes + HASHSIZE - 1) / HASHSIZE; + leaves.next_power_of_two() } fn num_bytes(num_leaves: usize) -> usize { diff --git a/eth2/utils/ssz/src/cached_tree_hash/impls.rs b/eth2/utils/ssz/src/cached_tree_hash/impls.rs index b6b0d463aa..b27d28c4b7 100644 --- a/eth2/utils/ssz/src/cached_tree_hash/impls.rs +++ b/eth2/utils/ssz/src/cached_tree_hash/impls.rs @@ -1,7 +1,9 @@ use super::*; -use crate::ssz_encode; +use crate::{ssz_encode, Encodable}; impl CachedTreeHash for u64 { + type Item = Self; + fn build_cache_bytes(&self) -> Vec { merkleize(ssz_encode(self)) } @@ -10,10 +12,6 @@ impl CachedTreeHash for u64 { 8 } - fn max_num_leaves(&self) -> usize { - 1 - } - fn cached_hash_tree_root( &self, other: &Self, @@ -28,3 +26,65 @@ impl CachedTreeHash for u64 { Some(chunk + 1) } } + +impl CachedTreeHash for Vec +where + T: CachedTreeHash + Encodable, +{ + type Item = Self; + + fn build_cache_bytes(&self) -> Vec { + let num_packed_bytes = self.num_bytes(); + let num_leaves = num_sanitized_leaves(num_packed_bytes); + + let mut packed = Vec::with_capacity(num_leaves * HASHSIZE); + + for item in self { + packed.append(&mut ssz_encode(item)); + } + + let packed = sanitise_bytes(packed); + + merkleize(packed) + } + + fn num_bytes(&self) -> usize { + self.iter().fold(0, |acc, item| acc + item.num_bytes()) + } + + fn cached_hash_tree_root( + &self, + other: &Self::Item, + cache: &mut TreeHashCache, + chunk: usize, + ) -> Option { + let num_packed_bytes = self.num_bytes(); + let num_leaves = num_sanitized_leaves(num_packed_bytes); + + if num_leaves != num_sanitized_leaves(other.num_bytes()) { + panic!("Need to handle a change in leaf count"); + } + + let mut packed = Vec::with_capacity(num_leaves * HASHSIZE); + + // TODO: try and avoid fully encoding the whole list + for item in self { + packed.append(&mut ssz_encode(item)); + } + + let packed = sanitise_bytes(packed); + + let num_nodes = num_nodes(num_leaves); + let num_internal_nodes = num_nodes - num_leaves; + + { + let mut chunk = chunk + num_internal_nodes; + for new_chunk_bytes in packed.chunks(HASHSIZE) { + cache.maybe_update_chunk(chunk, new_chunk_bytes)?; + chunk += 1; + } + } + + Some(chunk + num_nodes) + } +} diff --git a/eth2/utils/ssz/src/cached_tree_hash/tests.rs b/eth2/utils/ssz/src/cached_tree_hash/tests.rs index 79665f89de..f4a4b1d463 100644 --- a/eth2/utils/ssz/src/cached_tree_hash/tests.rs +++ b/eth2/utils/ssz/src/cached_tree_hash/tests.rs @@ -1,5 +1,5 @@ use super::*; -use int_to_bytes::int_to_bytes32; +use int_to_bytes::{int_to_bytes32, int_to_bytes8}; #[derive(Clone)] pub struct Inner { @@ -10,6 +10,8 @@ pub struct Inner { } impl CachedTreeHash for Inner { + type Item = Self; + fn build_cache_bytes(&self) -> Vec { let mut leaves = vec![]; @@ -21,15 +23,6 @@ impl CachedTreeHash for Inner { merkleize(leaves) } - fn max_num_leaves(&self) -> usize { - let mut leaves = 0; - leaves += self.a.max_num_leaves(); - leaves += self.b.max_num_leaves(); - leaves += self.c.max_num_leaves(); - leaves += self.d.max_num_leaves(); - leaves - } - fn num_bytes(&self) -> usize { let mut bytes = 0; bytes += self.a.num_bytes(); @@ -45,7 +38,12 @@ impl CachedTreeHash for Inner { cache: &mut TreeHashCache, chunk: usize, ) -> Option { - let num_leaves = self.max_num_leaves(); + let mut num_leaves: usize = 0; + num_leaves += num_unsanitized_leaves(self.a.num_bytes()); + num_leaves += num_unsanitized_leaves(self.b.num_bytes()); + num_leaves += num_unsanitized_leaves(self.c.num_bytes()); + num_leaves += num_unsanitized_leaves(self.d.num_bytes()); + let num_nodes = num_nodes(num_leaves); let num_internal_nodes = num_nodes - num_leaves; @@ -78,6 +76,26 @@ fn join(many: Vec>) -> Vec { all } +#[test] +fn vec_of_u64() { + let data = join(vec![ + int_to_bytes8(1), + int_to_bytes8(2), + int_to_bytes8(3), + int_to_bytes8(4), + int_to_bytes8(5), + vec![0; 32 - 8], // padding + ]); + + let expected = merkleize(data); + + let my_vec = vec![1, 2, 3, 4, 5]; + + let cache = my_vec.build_cache_bytes(); + + assert_eq!(expected, cache); +} + #[test] fn merkleize_odd() { let data = join(vec![