diff --git a/beacon_node/lighthouse_network/src/rpc/handler.rs b/beacon_node/lighthouse_network/src/rpc/handler.rs index 720895bbe7..9861119ac1 100644 --- a/beacon_node/lighthouse_network/src/rpc/handler.rs +++ b/beacon_node/lighthouse_network/src/rpc/handler.rs @@ -13,7 +13,8 @@ use futures::prelude::*; use libp2p::PeerId; use libp2p::swarm::handler::{ ConnectionEvent, ConnectionHandler, ConnectionHandlerEvent, DialUpgradeError, - FullyNegotiatedInbound, FullyNegotiatedOutbound, StreamUpgradeError, SubstreamProtocol, + FullyNegotiatedInbound, FullyNegotiatedOutbound, ListenUpgradeError, StreamUpgradeError, + SubstreamProtocol, }; use libp2p::swarm::{ConnectionId, Stream}; use logging::crit; @@ -888,6 +889,16 @@ where ConnectionEvent::DialUpgradeError(DialUpgradeError { info, error }) => { self.on_dial_upgrade_error(info, error) } + ConnectionEvent::ListenUpgradeError(ListenUpgradeError { + error: (proto, error), + .. + }) => { + self.events_out.push(HandlerEvent::Err(HandlerErr::Inbound { + id: self.current_inbound_substream_id, + proto, + error, + })); + } _ => { // NOTE: ConnectionEvent is a non exhaustive enum so updates should be based on // release notes more than compiler feedback @@ -924,7 +935,7 @@ where request.count() )), })); - return self.shutdown(None); + return; } } RequestType::BlobsByRange(request) => { @@ -940,7 +951,7 @@ where max_allowed, max_requested_blobs )), })); - return self.shutdown(None); + return; } } _ => {} diff --git a/beacon_node/lighthouse_network/src/rpc/protocol.rs b/beacon_node/lighthouse_network/src/rpc/protocol.rs index f0ac9d00f9..34d8efccd1 100644 --- a/beacon_node/lighthouse_network/src/rpc/protocol.rs +++ b/beacon_node/lighthouse_network/src/rpc/protocol.rs @@ -675,7 +675,7 @@ where E: EthSpec, { type Output = InboundOutput; - type Error = RPCError; + type Error = (Protocol, RPCError); type Future = BoxFuture<'static, Result>; fn upgrade_inbound(self, socket: TSocket, protocol: ProtocolId) -> Self::Future { @@ -717,10 +717,12 @@ where ) .await { - Err(e) => Err(RPCError::from(e)), + Err(e) => Err((versioned_protocol.protocol(), RPCError::from(e))), Ok((Some(Ok(request)), stream)) => Ok((request, stream)), - Ok((Some(Err(e)), _)) => Err(e), - Ok((None, _)) => Err(RPCError::IncompleteStream), + Ok((Some(Err(e)), _)) => Err((versioned_protocol.protocol(), e)), + Ok((None, _)) => { + Err((versioned_protocol.protocol(), RPCError::IncompleteStream)) + } } } } diff --git a/beacon_node/lighthouse_network/tests/rpc_tests.rs b/beacon_node/lighthouse_network/tests/rpc_tests.rs index 53939687d3..debe30b34f 100644 --- a/beacon_node/lighthouse_network/tests/rpc_tests.rs +++ b/beacon_node/lighthouse_network/tests/rpc_tests.rs @@ -5,8 +5,12 @@ use crate::common::spec_with_all_forks_enabled; use crate::common::{Protocol, build_tracing_subscriber}; use bls::Signature; use fixed_bytes::FixedBytesExtended; +use libp2p::PeerId; use lighthouse_network::rpc::{RequestType, methods::*}; -use lighthouse_network::service::api_types::AppRequestId; +use lighthouse_network::service::api_types::{ + AppRequestId, BlobsByRangeRequestId, BlocksByRangeRequestId, ComponentsByRangeRequestId, + DataColumnsByRangeRequestId, DataColumnsByRangeRequester, RangeRequestId, SyncRequestId, +}; use lighthouse_network::{NetworkEvent, ReportSource, Response}; use ssz::Encode; use ssz_types::{RuntimeVariableList, VariableList}; @@ -1783,3 +1787,157 @@ fn test_active_requests() { } }) } + +// Test that when a node receives an invalid BlocksByRange request exceeding the maximum count, +// it bans the sender. +#[test] +fn test_request_too_large_blocks_by_range() { + let spec = Arc::new(spec_with_all_forks_enabled()); + + test_request_too_large( + AppRequestId::Sync(SyncRequestId::BlocksByRange(BlocksByRangeRequestId { + id: 1, + parent_request_id: ComponentsByRangeRequestId { + id: 1, + requester: RangeRequestId::RangeSync { + chain_id: 1, + batch_id: Epoch::new(1), + }, + }, + })), + RequestType::BlocksByRange(OldBlocksByRangeRequest::new( + 0, + spec.max_request_blocks(ForkName::Base) as u64 + 1, // exceeds the max request defined in the spec. + 1, + )), + ); +} + +// Test that when a node receives an invalid BlobsByRange request exceeding the maximum count, +// it bans the sender. +#[test] +fn test_request_too_large_blobs_by_range() { + let spec = Arc::new(spec_with_all_forks_enabled()); + + let max_request_blobs_count = spec.max_request_blob_sidecars(ForkName::Base) as u64 + / spec.max_blobs_per_block_within_fork(ForkName::Base); + test_request_too_large( + AppRequestId::Sync(SyncRequestId::BlobsByRange(BlobsByRangeRequestId { + id: 1, + parent_request_id: ComponentsByRangeRequestId { + id: 1, + requester: RangeRequestId::RangeSync { + chain_id: 1, + batch_id: Epoch::new(1), + }, + }, + })), + RequestType::BlobsByRange(BlobsByRangeRequest { + start_slot: 0, + count: max_request_blobs_count + 1, // exceeds the max request defined in the spec. + }), + ); +} + +// Test that when a node receives an invalid DataColumnsByRange request exceeding the columns count, +// it bans the sender. +#[test] +fn test_request_too_large_data_columns_by_range() { + test_request_too_large( + AppRequestId::Sync(SyncRequestId::DataColumnsByRange( + DataColumnsByRangeRequestId { + id: 1, + parent_request_id: DataColumnsByRangeRequester::ComponentsByRange( + ComponentsByRangeRequestId { + id: 1, + requester: RangeRequestId::RangeSync { + chain_id: 1, + batch_id: Epoch::new(1), + }, + }, + ), + peer: PeerId::random(), + }, + )), + RequestType::DataColumnsByRange(DataColumnsByRangeRequest { + start_slot: 0, + count: 0, + // exceeds the max request defined in the spec. + columns: vec![0; E::number_of_columns() + 1], + }), + ); +} + +fn test_request_too_large(app_request_id: AppRequestId, request: RequestType) { + // Set up the logging. + let log_level = "debug"; + let enable_logging = true; + let _subscriber = build_tracing_subscriber(log_level, enable_logging); + let rt = Arc::new(Runtime::new().unwrap()); + let spec = Arc::new(spec_with_all_forks_enabled()); + + rt.block_on(async { + let (mut sender, mut receiver) = common::build_node_pair( + Arc::downgrade(&rt), + ForkName::Base, + spec, + Protocol::Tcp, + false, + None, + ) + .await; + + // Build the sender future + let sender_future = async { + loop { + match sender.next_event().await { + NetworkEvent::PeerConnectedOutgoing(peer_id) => { + debug!(?request, %peer_id, "Sending RPC request"); + sender + .send_request(peer_id, app_request_id, request.clone()) + .unwrap(); + } + NetworkEvent::ResponseReceived { + app_request_id, + response, + .. + } => { + debug!(?app_request_id, ?response, "Received response"); + } + NetworkEvent::RPCFailed { error, .. } => { + // This variant should be unreachable, as the receiver doesn't respond with an error when a request exceeds the limit. + debug!(?error, "RPC failed"); + unreachable!(); + } + NetworkEvent::PeerDisconnected(peer_id) => { + // The receiver should disconnect as a result of the invalid request. + debug!(%peer_id, "Peer disconnected"); + // End the test. + return; + } + _ => {} + } + } + } + .instrument(info_span!("Sender")); + + // Build the receiver future + let receiver_future = async { + loop { + if let NetworkEvent::RequestReceived { .. } = receiver.next_event().await { + // This event should be unreachable, as the handler drops the invalid request. + unreachable!(); + } + } + } + .instrument(info_span!("Receiver")); + + tokio::select! { + _ = sender_future => {} + _ = receiver_future => {} + _ = sleep(Duration::from_secs(30)) => { + panic!("Future timed out"); + } + } + }); +}