diff --git a/eth2/utils/ssz/src/cached_tree_hash.rs b/eth2/utils/ssz/src/cached_tree_hash.rs index 757bfa9f70..e72ff1ffde 100644 --- a/eth2/utils/ssz/src/cached_tree_hash.rs +++ b/eth2/utils/ssz/src/cached_tree_hash.rs @@ -6,43 +6,29 @@ const BYTES_PER_CHUNK: usize = 32; const HASHSIZE: usize = 32; const MERKLE_HASH_CHUNCK: usize = 2 * BYTES_PER_CHUNK; -pub struct TreeHashCache<'a> { - chunk_offset: usize, - cache: &'a mut [u8], - chunk_modified: &'a mut [bool], - hash_count: &'a mut usize, +pub struct TreeHashCache { + cache: Vec, + chunk_modified: Vec, } -impl<'a> TreeHashCache<'a> { - pub fn build_changes_vec(bytes: &[u8]) -> Vec { - vec![false; bytes.len() / BYTES_PER_CHUNK] +impl Into> for TreeHashCache { + fn into(self) -> Vec { + self.cache } +} - pub fn from_mut_slice( - bytes: &'a mut [u8], - changes: &'a mut [bool], - hash_count: &'a mut usize, - ) -> Option { +impl TreeHashCache { + pub fn from_bytes(bytes: Vec) -> Option { if bytes.len() % BYTES_PER_CHUNK > 0 { return None; } Some(Self { - chunk_offset: 0, + chunk_modified: vec![false; bytes.len() / BYTES_PER_CHUNK], cache: bytes, - chunk_modified: changes, - hash_count, }) } - pub fn increment(&mut self) { - self.chunk_offset += 1 - } - - pub fn modify_current_chunk(&mut self, to: &[u8]) -> Option<()> { - self.modify_chunk(self.chunk_offset, to) - } - pub fn modify_chunk(&mut self, chunk: usize, to: &[u8]) -> Option<()> { let start = chunk * BYTES_PER_CHUNK; let end = start + BYTES_PER_CHUNK; @@ -72,28 +58,6 @@ impl<'a> TreeHashCache<'a> { Some(hash(&self.cache.get(start..end)?)) } - - pub fn just_the_leaves(&mut self, leaves: usize) -> Option { - let nodes = num_nodes(leaves); - let internal = nodes - leaves; - - let leaves_start = (self.chunk_offset + internal) * BYTES_PER_CHUNK; - let leaves_end = leaves_start + leaves * BYTES_PER_CHUNK; - - let modified_start = self.chunk_offset + internal; - let modified_end = modified_start + leaves; - - Some(TreeHashCache { - chunk_offset: 0, - cache: self.cache.get_mut(leaves_start..leaves_end)?, - chunk_modified: self.chunk_modified.get_mut(modified_start..modified_end)?, - hash_count: self.hash_count, - }) - } - - pub fn into_slice(self) -> &'a [u8] { - self.cache - } } fn children(parent: usize) -> (usize, usize) { @@ -107,7 +71,16 @@ fn num_nodes(num_leaves: usize) -> usize { pub trait CachedTreeHash { fn build_cache_bytes(&self) -> Vec; - fn cached_hash_tree_root(&self, other: &Self, cache: &mut TreeHashCache) -> Option<()>; + fn num_bytes(&self) -> usize; + + fn max_num_leaves(&self) -> usize; + + fn cached_hash_tree_root( + &self, + other: &Self, + cache: &mut TreeHashCache, + chunk: usize, + ) -> Option; } impl CachedTreeHash for u64 { @@ -115,15 +88,25 @@ impl CachedTreeHash for u64 { merkleize(&int_to_bytes32(*self)) } - fn cached_hash_tree_root(&self, other: &Self, cache: &mut TreeHashCache) -> Option<()> { + fn num_bytes(&self) -> usize { + 8 + } + + fn max_num_leaves(&self) -> usize { + 1 + } + + fn cached_hash_tree_root( + &self, + other: &Self, + cache: &mut TreeHashCache, + chunk: usize, + ) -> Option { if self != other { - *cache.hash_count += 1; - cache.modify_current_chunk(&merkleize(&int_to_bytes32(*self)))?; + cache.modify_chunk(chunk, &merkleize(&int_to_bytes32(*self)))?; } - cache.increment(); - - Some(()) + Some(chunk + 1) } } @@ -147,28 +130,52 @@ impl CachedTreeHash for Inner { merkleize(&leaves) } - fn cached_hash_tree_root(&self, other: &Self, cache: &mut TreeHashCache) -> Option<()> { - let num_leaves = 4; + fn max_num_leaves(&self) -> usize { + let mut leaves = 0; + leaves += self.a.max_num_leaves(); + leaves += self.b.max_num_leaves(); + leaves += self.c.max_num_leaves(); + leaves += self.d.max_num_leaves(); + leaves + } + fn num_bytes(&self) -> usize { + let mut bytes = 0; + bytes += self.a.num_bytes(); + bytes += self.b.num_bytes(); + bytes += self.c.num_bytes(); + bytes += self.d.num_bytes(); + bytes + } + + fn cached_hash_tree_root( + &self, + other: &Self, + cache: &mut TreeHashCache, + chunk: usize, + ) -> Option { + let num_leaves = self.max_num_leaves(); + let num_nodes = num_nodes(num_leaves); + let num_internal_nodes = num_nodes - num_leaves; + + // Skip past the internal nodes and update any changed leaf nodes. { - let mut leaf_cache = cache.just_the_leaves(num_leaves)?; - self.a.cached_hash_tree_root(&other.a, &mut leaf_cache)?; - self.b.cached_hash_tree_root(&other.b, &mut leaf_cache)?; - self.c.cached_hash_tree_root(&other.c, &mut leaf_cache)?; - self.d.cached_hash_tree_root(&other.d, &mut leaf_cache)?; + let chunk = chunk + num_internal_nodes; + let chunk = self.a.cached_hash_tree_root(&other.a, cache, chunk)?; + let chunk = self.b.cached_hash_tree_root(&other.b, cache, chunk)?; + let chunk = self.c.cached_hash_tree_root(&other.c, cache, chunk)?; + let _chunk = self.d.cached_hash_tree_root(&other.d, cache, chunk)?; } - let nodes = num_nodes(num_leaves); - let internal_chunks = nodes - num_leaves; - - for chunk in (0..internal_chunks).into_iter().rev() { + // Iterate backwards through the internal nodes, rehashing any node where it's children + // have changed. + for chunk in (0..num_internal_nodes).into_iter().rev() { if cache.children_modified(chunk)? { - *cache.hash_count += 1; cache.modify_chunk(chunk, &cache.hash_children(chunk)?)?; } } - Some(()) + Some(chunk + num_nodes) } } @@ -243,18 +250,15 @@ mod tests { _ => panic!("bad index"), }; - let mut changes = TreeHashCache::build_changes_vec(&cache); - let mut hash_count = 0; - let mut cache_struct = - TreeHashCache::from_mut_slice(&mut cache, &mut changes, &mut hash_count).unwrap(); + let mut cache_struct = TreeHashCache::from_bytes(cache.clone()).unwrap(); changed_inner - .cached_hash_tree_root(&inner, &mut cache_struct) + .cached_hash_tree_root(&inner, &mut cache_struct, 0) .unwrap(); - assert_eq!(*cache_struct.hash_count, 3); + // assert_eq!(*cache_struct.hash_count, 3); - let new_cache = cache_struct.into_slice(); + let new_cache: Vec = cache_struct.into(); let data1 = int_to_bytes32(1); let data2 = int_to_bytes32(2);