diff --git a/eth2/types/src/beacon_state.rs b/eth2/types/src/beacon_state.rs index 3d94a8e3de..2d4f463b75 100644 --- a/eth2/types/src/beacon_state.rs +++ b/eth2/types/src/beacon_state.rs @@ -12,7 +12,7 @@ use rand::RngCore; use serde_derive::Serialize; use ssz::{hash, Decodable, DecodeError, Encodable, SszStream, TreeHash}; use std::collections::HashMap; -use swap_or_not_shuffle::get_permutated_index; +use swap_or_not_shuffle::get_permutated_list; pub use builder::BeaconStateBuilder; @@ -420,17 +420,15 @@ impl BeaconState { committees_per_epoch ); - let mut shuffled_active_validator_indices = vec![0; active_validator_indices.len()]; - for (i, _) in active_validator_indices.iter().enumerate() { - let shuffled_i = get_permutated_index( - i, - active_validator_indices.len(), - &seed[..], - spec.shuffle_round_count, - ) - .ok_or_else(|| Error::UnableToShuffle)?; - shuffled_active_validator_indices[i] = active_validator_indices[shuffled_i] - } + let active_validator_indices: Vec = + active_validator_indices.iter().cloned().collect(); + + let shuffled_active_validator_indices = get_permutated_list( + &active_validator_indices, + &seed[..], + spec.shuffle_round_count, + ) + .ok_or_else(|| Error::UnableToShuffle)?; Ok(shuffled_active_validator_indices .honey_badger_split(committees_per_epoch as usize) diff --git a/eth2/utils/swap_or_not_shuffle/benches/benches.rs b/eth2/utils/swap_or_not_shuffle/benches/benches.rs index 1d5b5476cd..b1311b41eb 100644 --- a/eth2/utils/swap_or_not_shuffle/benches/benches.rs +++ b/eth2/utils/swap_or_not_shuffle/benches/benches.rs @@ -1,6 +1,6 @@ use criterion::Criterion; use criterion::{black_box, criterion_group, criterion_main, Benchmark}; -use swap_or_not_shuffle::get_permutated_index; +use swap_or_not_shuffle::{get_permutated_index, get_permutated_list}; const SHUFFLE_ROUND_COUNT: u8 = 90; @@ -48,6 +48,16 @@ fn shuffles(c: &mut Criterion) { .sample_size(10), ); + c.bench( + "_fast_ whole list shuffle", + Benchmark::new("512 elements", move |b| { + let seed = vec![42; 32]; + let list: Vec = (0..512).collect(); + b.iter(|| black_box(get_permutated_list(&list, &seed, SHUFFLE_ROUND_COUNT))) + }) + .sample_size(10), + ); + c.bench( "whole list shuffle", Benchmark::new("16384 elements", move |b| { @@ -56,6 +66,16 @@ fn shuffles(c: &mut Criterion) { }) .sample_size(10), ); + + c.bench( + "_fast_ whole list shuffle", + Benchmark::new("16384 elements", move |b| { + let seed = vec![42; 32]; + let list: Vec = (0..16384).collect(); + b.iter(|| black_box(get_permutated_list(&list, &seed, SHUFFLE_ROUND_COUNT))) + }) + .sample_size(10), + ); } criterion_group!(benches, shuffles,); diff --git a/eth2/utils/swap_or_not_shuffle/src/lib.rs b/eth2/utils/swap_or_not_shuffle/src/lib.rs index 753265f3e7..b9140b269c 100644 --- a/eth2/utils/swap_or_not_shuffle/src/lib.rs +++ b/eth2/utils/swap_or_not_shuffle/src/lib.rs @@ -4,6 +4,36 @@ use int_to_bytes::{int_to_bytes1, int_to_bytes4}; use std::cmp::max; use std::io::Cursor; +pub fn get_permutated_list( + list: &[usize], + seed: &[u8], + shuffle_round_count: u8, +) -> Option> { + let list_size = list.len(); + + if list_size == 0 || list_size > usize::max_value() / 2 || list_size > 2_usize.pow(24) { + return None; + } + + let mut pivots = Vec::with_capacity(shuffle_round_count as usize); + for round in 0..shuffle_round_count { + pivots.push(bytes_to_int64(&hash_with_round(seed, round)[..]) as usize % list_size); + } + + let mut output = Vec::with_capacity(list_size); + + for i in 0..list_size { + let mut index = i; + for round in 0..shuffle_round_count { + let pivot = pivots[round as usize]; + index = do_round(seed, index, pivot, round, list_size)?; + } + output.push(list[index]) + } + + Some(output) +} + /// Return `p(index)` in a pseudorandom permutation `p` of `0...list_size-1` with ``seed`` as entropy. /// /// Utilizes 'swap or not' shuffling found in @@ -32,16 +62,20 @@ pub fn get_permutated_index( let mut index = index; for round in 0..shuffle_round_count { let pivot = bytes_to_int64(&hash_with_round(seed, round)[..]) as usize % list_size; - let flip = (pivot + list_size - index) % list_size; - let position = max(index, flip); - let source = hash_with_round_and_position(seed, round, position)?; - let byte = source[(position % 256) / 8]; - let bit = (byte >> (position % 8)) % 2; - index = if bit == 1 { flip } else { index } + index = do_round(seed, index, pivot, round, list_size)?; } Some(index) } +fn do_round(seed: &[u8], index: usize, pivot: usize, round: u8, list_size: usize) -> Option { + let flip = (pivot + list_size - index) % list_size; + let position = max(index, flip); + let source = hash_with_round_and_position(seed, round, position)?; + let byte = source[(position % 256) / 8]; + let bit = (byte >> (position % 8)) % 2; + Some(if bit == 1 { flip } else { index }) +} + fn hash_with_round_and_position(seed: &[u8], round: u8, position: usize) -> Option> { let mut seed = seed.to_vec(); seed.append(&mut int_to_bytes1(round));