Use checked arithmetic in types and state proc (#1009)

This commit is contained in:
Michael Sproul
2020-04-20 12:35:11 +10:00
committed by GitHub
parent 50ef0d7fbf
commit 32074f0d09
49 changed files with 525 additions and 169 deletions

View File

@@ -19,9 +19,7 @@ impl AggregatePublicKey {
pub fn from_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
let pubkey = RawAggregatePublicKey::from_bytes(&bytes).map_err(|_| {
DecodeError::BytesInvalid(
format!("Invalid AggregatePublicKey bytes: {:?}", bytes).to_string(),
)
DecodeError::BytesInvalid(format!("Invalid AggregatePublicKey bytes: {:?}", bytes))
})?;
Ok(AggregatePublicKey(pubkey))

View File

@@ -164,8 +164,8 @@ impl TreeHashCache {
fn lift_dirty(dirty_indices: &[usize]) -> SmallVec8<usize> {
let mut new_dirty = SmallVec8::with_capacity(dirty_indices.len());
for i in 0..dirty_indices.len() {
new_dirty.push(dirty_indices[i] / 2)
for index in dirty_indices {
new_dirty.push(index / 2)
}
new_dirty.dedup();

View File

@@ -89,6 +89,7 @@ impl<T: Encode + Decode> CacheArena<T> {
/// To reiterate, the given `range` should be relative to the given `alloc_id`, not
/// `self.backing`. E.g., if the allocation has an offset of `20` and the range is `0..1`, then
/// the splice will translate to `self.backing[20..21]`.
#[allow(clippy::comparison_chain)]
fn splice_forgetful<I: IntoIterator<Item = T>>(
&mut self,
alloc_id: usize,

View File

@@ -8,6 +8,7 @@ edition = "2018"
ethereum-types = "0.8.0"
eth2_hashing = "0.1.0"
lazy_static = "1.4.0"
safe_arith = { path = "../safe_arith" }
[dev-dependencies]
quickcheck = "0.9.0"

View File

@@ -1,6 +1,7 @@
use eth2_hashing::{hash, hash32_concat, ZERO_HASHES};
use ethereum_types::H256;
use lazy_static::lazy_static;
use safe_arith::ArithError;
const MAX_TREE_DEPTH: usize = 32;
const EMPTY_SLICE: &[H256] = &[];
@@ -38,6 +39,8 @@ pub enum MerkleTreeError {
Invalid,
// Incorrect Depth provided
DepthTooSmall,
// Overflow occurred
ArithError,
}
impl MerkleTree {
@@ -232,6 +235,12 @@ fn merkle_root_from_branch(leaf: H256, branch: &[H256], depth: usize, index: usi
H256::from_slice(&merkle_root)
}
impl From<ArithError> for MerkleTreeError {
fn from(_: ArithError) -> Self {
MerkleTreeError::ArithError
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -0,0 +1,9 @@
[package]
name = "safe_arith"
version = "0.1.0"
authors = ["Michael Sproul <michael@sigmaprime.io>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]

View File

@@ -0,0 +1,161 @@
//! Library for safe arithmetic on integers, avoiding overflow and division by zero.
/// Error representing the failure of an arithmetic operation.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ArithError {
Overflow,
DivisionByZero,
}
type Result<T> = std::result::Result<T, ArithError>;
macro_rules! assign_method {
($name:ident, $op:ident, $doc_op:expr) => {
assign_method!($name, $op, Self, $doc_op);
};
($name:ident, $op:ident, $rhs_ty:ty, $doc_op:expr) => {
#[doc = "Safe variant of `"]
#[doc = $doc_op]
#[doc = "`."]
fn $name(&mut self, other: $rhs_ty) -> Result<()> {
*self = self.$op(other)?;
Ok(())
}
};
}
/// Trait providing safe arithmetic operations for built-in types.
pub trait SafeArith: Sized + Copy {
const ZERO: Self;
const ONE: Self;
/// Safe variant of `+` that guards against overflow.
fn safe_add(&self, other: Self) -> Result<Self>;
/// Safe variant of `-` that guards against overflow.
fn safe_sub(&self, other: Self) -> Result<Self>;
/// Safe variant of `*` that guards against overflow.
fn safe_mul(&self, other: Self) -> Result<Self>;
/// Safe variant of `/` that guards against division by 0.
fn safe_div(&self, other: Self) -> Result<Self>;
/// Safe variant of `%` that guards against division by 0.
fn safe_rem(&self, other: Self) -> Result<Self>;
/// Safe variant of `<<` that guards against overflow.
fn safe_shl(&self, other: u32) -> Result<Self>;
/// Safe variant of `>>` that guards against overflow.
fn safe_shr(&self, other: u32) -> Result<Self>;
assign_method!(safe_add_assign, safe_add, "+=");
assign_method!(safe_sub_assign, safe_sub, "-=");
assign_method!(safe_mul_assign, safe_mul, "*=");
assign_method!(safe_div_assign, safe_div, "/=");
assign_method!(safe_rem_assign, safe_rem, "%=");
assign_method!(safe_shl_assign, safe_shl, u32, "<<=");
assign_method!(safe_shr_assign, safe_shr, u32, ">>=");
/// Mutate `self` by adding 1, erroring on overflow.
fn increment(&mut self) -> Result<()> {
self.safe_add_assign(Self::ONE)
}
}
macro_rules! impl_safe_arith {
($typ:ty) => {
impl SafeArith for $typ {
const ZERO: Self = 0;
const ONE: Self = 1;
fn safe_add(&self, other: Self) -> Result<Self> {
self.checked_add(other).ok_or(ArithError::Overflow)
}
fn safe_sub(&self, other: Self) -> Result<Self> {
self.checked_sub(other).ok_or(ArithError::Overflow)
}
fn safe_mul(&self, other: Self) -> Result<Self> {
self.checked_mul(other).ok_or(ArithError::Overflow)
}
fn safe_div(&self, other: Self) -> Result<Self> {
self.checked_div(other).ok_or(ArithError::DivisionByZero)
}
fn safe_rem(&self, other: Self) -> Result<Self> {
self.checked_rem(other).ok_or(ArithError::DivisionByZero)
}
fn safe_shl(&self, other: u32) -> Result<Self> {
self.checked_shl(other).ok_or(ArithError::Overflow)
}
fn safe_shr(&self, other: u32) -> Result<Self> {
self.checked_shr(other).ok_or(ArithError::Overflow)
}
}
};
}
impl_safe_arith!(u8);
impl_safe_arith!(u16);
impl_safe_arith!(u32);
impl_safe_arith!(u64);
impl_safe_arith!(usize);
impl_safe_arith!(i8);
impl_safe_arith!(i16);
impl_safe_arith!(i32);
impl_safe_arith!(i64);
impl_safe_arith!(isize);
#[cfg(test)]
mod test {
use super::*;
#[test]
fn basic() {
let x = 10u32;
let y = 11;
assert_eq!(x.safe_add(y), Ok(x + y));
assert_eq!(y.safe_sub(x), Ok(y - x));
assert_eq!(x.safe_mul(y), Ok(x * y));
assert_eq!(x.safe_div(y), Ok(x / y));
assert_eq!(x.safe_rem(y), Ok(x % y));
assert_eq!(x.safe_shl(1), Ok(x << 1));
assert_eq!(x.safe_shr(1), Ok(x >> 1));
}
#[test]
fn mutate() {
let mut x = 0u8;
x.increment().unwrap();
x.increment().unwrap();
assert_eq!(x, 2);
x.safe_sub_assign(1).unwrap();
assert_eq!(x, 1);
x.safe_shl_assign(1).unwrap();
assert_eq!(x, 2);
x.safe_mul_assign(3).unwrap();
assert_eq!(x, 6);
x.safe_div_assign(4).unwrap();
assert_eq!(x, 1);
x.safe_shr_assign(1).unwrap();
assert_eq!(x, 0);
}
#[test]
fn errors() {
assert!(u32::max_value().safe_add(1).is_err());
assert!(u32::min_value().safe_sub(1).is_err());
assert!(u32::max_value().safe_mul(2).is_err());
assert!(u32::max_value().safe_div(0).is_err());
assert!(u32::max_value().safe_rem(0).is_err());
assert!(u32::max_value().safe_shl(32).is_err());
assert!(u32::max_value().safe_shr(32).is_err());
}
}

View File

@@ -1,4 +1,3 @@
use hex;
use hex::ToHex;
use serde::de::{self, Visitor};
use std::fmt;

View File

@@ -86,6 +86,7 @@ pub fn ssz_encode_derive(input: TokenStream) -> TokenStream {
let field_types_f = field_types_a.clone();
let output = quote! {
#[allow(clippy::integer_arithmetic)]
impl #impl_generics ssz::Encode for #name #ty_generics #where_clause {
fn is_ssz_fixed_len() -> bool {
#(
@@ -221,6 +222,7 @@ pub fn ssz_decode_derive(input: TokenStream) -> TokenStream {
}
let output = quote! {
#[allow(clippy::integer_arithmetic)]
impl #impl_generics ssz::Decode for #name #ty_generics #where_clause {
fn is_ssz_fixed_len() -> bool {
#(

View File

@@ -274,22 +274,20 @@ impl MerkleHasher {
loop {
if let Some(root) = self.root {
break Ok(root);
} else if let Some(node) = self.half_nodes.last() {
let right_child = node.id * 2 + 1;
self.process_right_node(right_child, self.zero_hash(right_child));
} else if self.next_leaf == 1 {
// The next_leaf can only be 1 if the tree has a depth of one. If have been no
// leaves supplied, assume a root of zero.
break Ok(Hash256::zero());
} else {
if let Some(node) = self.half_nodes.last() {
let right_child = node.id * 2 + 1;
self.process_right_node(right_child, self.zero_hash(right_child));
} else if self.next_leaf == 1 {
// The next_leaf can only be 1 if the tree has a depth of one. If have been no
// leaves supplied, assume a root of zero.
break Ok(Hash256::zero());
} else {
// The only scenario where there are (a) no half nodes and (b) a tree of depth
// two or more is where no leaves have been supplied at all.
//
// Once we supply this first zero-hash leaf then all future operations will be
// triggered via the `process_right_node` branch.
self.process_left_node(self.next_leaf, self.zero_hash(self.next_leaf))
}
// The only scenario where there are (a) no half nodes and (b) a tree of depth
// two or more is where no leaves have been supplied at all.
//
// Once we supply this first zero-hash leaf then all future operations will be
// triggered via the `process_right_node` branch.
self.process_left_node(self.next_leaf, self.zero_hash(self.next_leaf))
}
}
}