diff --git a/eth2/types/src/slot_epoch.rs b/eth2/types/src/slot_epoch.rs index eb5a8dced7..ff4fd5b9b3 100644 --- a/eth2/types/src/slot_epoch.rs +++ b/eth2/types/src/slot_epoch.rs @@ -72,7 +72,7 @@ impl Epoch { pub fn slot_iter(&self, epoch_length: u64) -> SlotIter { SlotIter { - current: self.start_slot(epoch_length), + current_iteration: 0, epoch: self, epoch_length, } @@ -80,7 +80,7 @@ impl Epoch { } pub struct SlotIter<'a> { - current: Slot, + current_iteration: u64, epoch: &'a Epoch, epoch_length: u64, } @@ -89,12 +89,13 @@ impl<'a> Iterator for SlotIter<'a> { type Item = Slot; fn next(&mut self) -> Option { - if self.current == self.epoch.end_slot(self.epoch_length) { + if self.current_iteration >= self.epoch_length { None } else { - let previous = self.current; - self.current += 1; - Some(previous) + let start_slot = self.epoch.start_slot(self.epoch_length); + let previous = self.current_iteration; + self.current_iteration += 1; + Some(start_slot + previous) } } } @@ -115,4 +116,22 @@ mod epoch_tests { use ssz::ssz_encode; all_tests!(Epoch); + + #[test] + fn slot_iter() { + let epoch_length = 8; + + let epoch = Epoch::new(0); + + let mut slots = vec![]; + for slot in epoch.slot_iter(epoch_length) { + slots.push(slot); + } + + assert_eq!(slots.len(), epoch_length as usize); + + for i in 0..epoch_length { + assert_eq!(Slot::from(i), slots[i as usize]) + } + } }