diff --git a/validator_manager/src/move_validators.rs b/validator_manager/src/move_validators.rs index 592f6e0dd9..ab1db3c2cd 100644 --- a/validator_manager/src/move_validators.rs +++ b/validator_manager/src/move_validators.rs @@ -561,6 +561,7 @@ mod test { dest_import_builder: Option, duplicates: usize, dir: TempDir, + move_back_again: bool, } impl TestBuilder { @@ -571,9 +572,15 @@ mod test { dest_import_builder: None, duplicates: 0, dir, + move_back_again: false, } } + fn move_back_again(mut self) -> Self { + self.move_back_again = true; + self + } + async fn with_src_validators(mut self, count: u32, first_index: u32) -> Self { let builder = ImportTestBuilder::new() .await @@ -597,18 +604,15 @@ mod test { self } - async fn run_test(self, gen_validators_enum: F) -> TestResult + async fn move_validators( + &self, + gen_validators_enum: F, + src_vc: &ApiTester, + dest_vc: &ApiTester, + ) -> Result<(), String> where F: Fn(&[PublicKeyBytes]) -> Validators, { - let src_vc = if let Some(import_builder) = self.src_import_builder { - let import_test_result = import_builder.run_test().await; - assert!(import_test_result.result.is_ok()); - import_test_result.vc - } else { - ApiTester::new().await - }; - let src_vc_token_path = self.dir.path().join(SRC_VC_TOKEN_FILE_NAME); fs::write(&src_vc_token_path, &src_vc.api_token).unwrap(); let (src_vc_client, src_vc_initial_keystores) = @@ -622,13 +626,6 @@ mod test { .collect(); let validators = gen_validators_enum(&src_vc_initial_pubkeys); - let dest_vc = if let Some(import_builder) = self.dest_import_builder { - let import_test_result = import_builder.run_test().await; - assert!(import_test_result.result.is_ok()); - import_test_result.vc - } else { - ApiTester::new().await - }; let dest_vc_token_path = self.dir.path().join(DEST_VC_TOKEN_FILE_NAME); fs::write(&dest_vc_token_path, &dest_vc.api_token).unwrap(); @@ -729,6 +726,39 @@ mod test { } } + result + } + + async fn run_test(mut self, gen_validators_enum: F) -> TestResult + where + F: Fn(&[PublicKeyBytes]) -> Validators + Copy, + { + let src_vc = if let Some(import_builder) = self.src_import_builder.take() { + let import_test_result = import_builder.run_test().await; + assert!(import_test_result.result.is_ok()); + import_test_result.vc + } else { + ApiTester::new().await + }; + + let dest_vc = if let Some(import_builder) = self.dest_import_builder.take() { + let import_test_result = import_builder.run_test().await; + assert!(import_test_result.result.is_ok()); + import_test_result.vc + } else { + ApiTester::new().await + }; + + let result = self + .move_validators(gen_validators_enum, &src_vc, &dest_vc) + .await; + + if self.move_back_again { + self.move_validators(gen_validators_enum, &dest_vc, &src_vc) + .await + .unwrap(); + } + TestResult { result } } } @@ -900,4 +930,16 @@ mod test { .await .assert_err(); } + + #[tokio::test] + async fn two_validator_move_all_and_back_again() { + TestBuilder::new() + .await + .with_src_validators(2, 0) + .await + .move_back_again() + .run_test(|_| Validators::All) + .await + .assert_ok(); + } }