diff --git a/ssz/src/decode.rs b/ssz/src/decode.rs index 0c6a3db347..15c053bcf0 100644 --- a/ssz/src/decode.rs +++ b/ssz/src/decode.rs @@ -36,11 +36,11 @@ pub fn decode_ssz_list(ssz_bytes: &[u8], index: usize) { if index + LENGTH_BYTES > ssz_bytes.len() { - return Err(DecodeError::OutOfBounds); + return Err(DecodeError::TooShort); }; // get the length - let mut serialized_length = match decode_length(ssz_bytes, LENGTH_BYTES) { + let serialized_length = match decode_length(ssz_bytes, index, LENGTH_BYTES) { Err(v) => return Err(v), Ok(v) => v, }; @@ -48,7 +48,7 @@ pub fn decode_ssz_list(ssz_bytes: &[u8], index: usize) let final_len: usize = index + LENGTH_BYTES + serialized_length; if final_len > ssz_bytes.len() { - return Err(DecodeError::OutOfBounds); + return Err(DecodeError::TooShort); }; let mut tmp_index = index + LENGTH_BYTES; @@ -71,15 +71,15 @@ pub fn decode_ssz_list(ssz_bytes: &[u8], index: usize) /// Given some number of bytes, interpret the first four /// bytes as a 32-bit big-endian integer and return the /// result. -fn decode_length(bytes: &[u8], length_bytes: usize) +fn decode_length(bytes: &[u8], index: usize, length_bytes: usize) -> Result { if bytes.len() < length_bytes { return Err(DecodeError::TooShort); }; let mut len: usize = 0; - for i in 0..length_bytes { - let offset = (length_bytes - i - 1) * 8; + for i in index..index+length_bytes { + let offset = (index+length_bytes - i - 1) * 8; len = ((bytes[i] as usize) << offset) | len; }; Ok(len) @@ -94,21 +94,25 @@ mod tests { fn test_ssz_decode_length() { let decoded = decode_length( &vec![0, 0, 0, 1], + 0, LENGTH_BYTES); assert_eq!(decoded.unwrap(), 1); let decoded = decode_length( &vec![0, 0, 1, 0], + 0, LENGTH_BYTES); assert_eq!(decoded.unwrap(), 256); let decoded = decode_length( &vec![0, 0, 1, 255], + 0, LENGTH_BYTES); assert_eq!(decoded.unwrap(), 511); let decoded = decode_length( &vec![255, 255, 255, 255], + 0, LENGTH_BYTES); assert_eq!(decoded.unwrap(), 4294967295); } @@ -125,8 +129,82 @@ mod tests { for i in params { let decoded = decode_length( &encode_length(i, LENGTH_BYTES), + 0, LENGTH_BYTES).unwrap(); assert_eq!(i, decoded); } } + + #[test] + fn test_decode_ssz_list() { + // u16 + let v: Vec = vec![10, 10, 10, 10]; + let decoded: (Vec, usize) = decode_ssz_list( + &vec![0, 0, 0, 8, 0, 10, 0, 10, 0, 10, 0, 10], + 0 + ).unwrap(); + + assert_eq!(decoded.0, v); + assert_eq!(decoded.1, 12); + + // u32 + let v: Vec = vec![10, 10, 10, 10]; + let decoded: (Vec, usize) = decode_ssz_list( + &vec![ + 0, 0, 0, 16, + 0, 0, 0, 10, 0, 0, 0, 10, 0, 0, 0, 10, 0, 0, 0, 10 + ], + 0 + ).unwrap(); + assert_eq!(decoded.0, v); + assert_eq!(decoded.1, 20); + + + // u64 + let v: Vec = vec![10,10,10,10]; + let decoded: (Vec, usize) = decode_ssz_list( + &vec![0, 0, 0, 32, + 0, 0, 0, 0, 0, 0, 0, 10, + 0, 0, 0, 0, 0, 0, 0, 10, + 0, 0, 0, 0, 0, 0, 0, 10, + 0, 0, 0, 0, 0, 0, 0, 10, + ], + 0 + ).unwrap(); + assert_eq!(decoded.0, v); + assert_eq!(decoded.1, 36); + + // Check that it can accept index + let v: Vec = vec![15,15,15,15]; + let decoded: (Vec, usize) = decode_ssz_list( + &vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 0, 0, 0, 32, + 0, 0, 0, 0, 0, 0, 0, 15, + 0, 0, 0, 0, 0, 0, 0, 15, + 0, 0, 0, 0, 0, 0, 0, 15, + 0, 0, 0, 0, 0, 0, 0, 15, + ], + 10 + ).unwrap(); + assert_eq!(decoded.0, v); + assert_eq!(decoded.1, 46); + + // Check that length > bytes throws error + let decoded: Result<(Vec, usize), DecodeError> = decode_ssz_list( + &vec![0, 0, 0, 32, + 0, 0, 0, 0, 0, 0, 0, 15, + ], + 0 + ); + assert_eq!(decoded, Err(DecodeError::TooShort)); + + // Check that incorrect index throws error + let decoded: Result<(Vec, usize), DecodeError> = decode_ssz_list( + &vec![ + 0, 0, 0, 0, 0, 0, 0, 15, + ], + 16 + ); + assert_eq!(decoded, Err(DecodeError::TooShort)); + } }