Finished first round of fighting RPC types

This commit is contained in:
Age Manning
2020-05-01 20:05:03 +10:00
parent 08838fca23
commit 9e6ae448a6
5 changed files with 93 additions and 61 deletions

View File

@@ -19,9 +19,9 @@ pub trait OutboundCodec<TItem>: Encoder<TItem> + Decoder {
/* Global Inbound Codec */ /* Global Inbound Codec */
// This deals with Decoding RPC Requests from other peers and encoding our responses // This deals with Decoding RPC Requests from other peers and encoding our responses
pub struct BaseInboundCodec<TCodec, TSpec, TItem> pub struct BaseInboundCodec<TCodec, TSpec>
where where
TCodec: Encoder<TItem> + Decoder, TCodec: Encoder<RPCErrorResponse<TSpec>> + Decoder,
TSpec: EthSpec, TSpec: EthSpec,
{ {
/// Inner codec for handling various encodings /// Inner codec for handling various encodings
@@ -29,9 +29,9 @@ where
phantom: PhantomData<TSpec>, phantom: PhantomData<TSpec>,
} }
impl<TCodec, TSpec, TItem> BaseInboundCodec<TCodec, TSpec, TItem> impl<TCodec, TSpec> BaseInboundCodec<TCodec, TSpec>
where where
TCodec: Encoder<TItem> + Decoder, TCodec: Encoder<RPCErrorResponse<TSpec>> + Decoder,
TSpec: EthSpec, TSpec: EthSpec,
{ {
pub fn new(codec: TCodec) -> Self { pub fn new(codec: TCodec) -> Self {
@@ -44,9 +44,9 @@ where
/* Global Outbound Codec */ /* Global Outbound Codec */
// This deals with Decoding RPC Responses from other peers and encoding our requests // This deals with Decoding RPC Responses from other peers and encoding our requests
pub struct BaseOutboundCodec<TOutboundCodec, TSpec, TItem> pub struct BaseOutboundCodec<TOutboundCodec, TSpec>
where where
TOutboundCodec: OutboundCodec<TItem>, TOutboundCodec: OutboundCodec<RPCRequest<TSpec>>,
TSpec: EthSpec, TSpec: EthSpec,
{ {
/// Inner codec for handling various encodings. /// Inner codec for handling various encodings.
@@ -56,10 +56,10 @@ where
phantom: PhantomData<TSpec>, phantom: PhantomData<TSpec>,
} }
impl<TOutboundCodec, TSpec, TItem> BaseOutboundCodec<TOutboundCodec, TSpec, TItem> impl<TOutboundCodec, TSpec> BaseOutboundCodec<TOutboundCodec, TSpec>
where where
TSpec: EthSpec, TSpec: EthSpec,
TOutboundCodec: OutboundCodec<TItem>, TOutboundCodec: OutboundCodec<RPCRequest<TSpec>>,
{ {
pub fn new(codec: TOutboundCodec) -> Self { pub fn new(codec: TOutboundCodec) -> Self {
BaseOutboundCodec { BaseOutboundCodec {
@@ -75,8 +75,7 @@ where
/* Base Inbound Codec */ /* Base Inbound Codec */
// This Encodes RPC Responses sent to external peers // This Encodes RPC Responses sent to external peers
impl<TCodec, TSpec> Encoder<RPCErrorResponse<TSpec>> impl<TCodec, TSpec> Encoder<RPCErrorResponse<TSpec>> for BaseInboundCodec<TCodec, TSpec>
for BaseInboundCodec<TCodec, TSpec, RPCErrorResponse<TSpec>>
where where
TSpec: EthSpec, TSpec: EthSpec,
TCodec: Decoder + Encoder<RPCErrorResponse<TSpec>>, TCodec: Decoder + Encoder<RPCErrorResponse<TSpec>>,
@@ -100,7 +99,7 @@ where
// This Decodes RPC Requests from external peers // This Decodes RPC Requests from external peers
// TODO: check if the Item parameter is correct // TODO: check if the Item parameter is correct
impl<TCodec, TSpec> Decoder for BaseInboundCodec<TCodec, TSpec, RPCErrorResponse<TSpec>> impl<TCodec, TSpec> Decoder for BaseInboundCodec<TCodec, TSpec>
where where
TSpec: EthSpec, TSpec: EthSpec,
// TODO: check if the Item parameter is correct // TODO: check if the Item parameter is correct
@@ -117,8 +116,7 @@ where
/* Base Outbound Codec */ /* Base Outbound Codec */
// This Encodes RPC Requests sent to external peers // This Encodes RPC Requests sent to external peers
impl<TCodec, TSpec> Encoder<RPCRequest<TSpec>> impl<TCodec, TSpec> Encoder<RPCRequest<TSpec>> for BaseOutboundCodec<TCodec, TSpec>
for BaseOutboundCodec<TCodec, TSpec, RPCRequest<TSpec>>
where where
TSpec: EthSpec, TSpec: EthSpec,
TCodec: OutboundCodec<RPCRequest<TSpec>> + Encoder<RPCRequest<TSpec>>, TCodec: OutboundCodec<RPCRequest<TSpec>> + Encoder<RPCRequest<TSpec>>,
@@ -131,7 +129,7 @@ where
} }
// This decodes RPC Responses received from external peers // This decodes RPC Responses received from external peers
impl<TCodec, TSpec> Decoder for BaseOutboundCodec<TCodec, TSpec, RPCRequest<TSpec>> impl<TCodec, TSpec> Decoder for BaseOutboundCodec<TCodec, TSpec>
where where
TSpec: EthSpec, TSpec: EthSpec,
TCodec: OutboundCodec<RPCRequest<TSpec>, ErrorType = ErrorMessage> TCodec: OutboundCodec<RPCRequest<TSpec>, ErrorType = ErrorMessage>

View File

@@ -12,17 +12,17 @@ use tokio_util::codec::{Decoder, Encoder};
use types::EthSpec; use types::EthSpec;
// Known types of codecs // Known types of codecs
pub enum InboundCodec<TSpec: EthSpec, TItem> { pub enum InboundCodec<TSpec: EthSpec> {
SSZSnappy(BaseInboundCodec<SSZSnappyInboundCodec<TSpec>, TSpec, TItem>), SSZSnappy(BaseInboundCodec<SSZSnappyInboundCodec<TSpec>, TSpec>),
SSZ(BaseInboundCodec<SSZInboundCodec<TSpec>, TSpec, TItem>), SSZ(BaseInboundCodec<SSZInboundCodec<TSpec>, TSpec>),
} }
pub enum OutboundCodec<TSpec: EthSpec, TItem> { pub enum OutboundCodec<TSpec: EthSpec> {
SSZSnappy(BaseOutboundCodec<SSZSnappyOutboundCodec<TSpec>, TSpec, TItem>), SSZSnappy(BaseOutboundCodec<SSZSnappyOutboundCodec<TSpec>, TSpec>),
SSZ(BaseOutboundCodec<SSZOutboundCodec<TSpec>, TSpec, TItem>), SSZ(BaseOutboundCodec<SSZOutboundCodec<TSpec>, TSpec>),
} }
impl<T: EthSpec> Encoder<RPCErrorResponse<T>> for InboundCodec<T, RPCErrorResponse<T>> { impl<T: EthSpec> Encoder<RPCErrorResponse<T>> for InboundCodec<T> {
type Error = RPCError; type Error = RPCError;
fn encode(&mut self, item: RPCErrorResponse<T>, dst: &mut BytesMut) -> Result<(), Self::Error> { fn encode(&mut self, item: RPCErrorResponse<T>, dst: &mut BytesMut) -> Result<(), Self::Error> {
@@ -33,7 +33,7 @@ impl<T: EthSpec> Encoder<RPCErrorResponse<T>> for InboundCodec<T, RPCErrorRespon
} }
} }
impl<TSpec: EthSpec> Decoder for InboundCodec<TSpec, RPCErrorResponse<TSpec>> { impl<TSpec: EthSpec> Decoder for InboundCodec<TSpec> {
type Item = RPCRequest<TSpec>; type Item = RPCRequest<TSpec>;
type Error = RPCError; type Error = RPCError;
@@ -45,7 +45,7 @@ impl<TSpec: EthSpec> Decoder for InboundCodec<TSpec, RPCErrorResponse<TSpec>> {
} }
} }
impl<TSpec: EthSpec> Encoder<RPCRequest<TSpec>> for OutboundCodec<TSpec, RPCRequest<TSpec>> { impl<TSpec: EthSpec> Encoder<RPCRequest<TSpec>> for OutboundCodec<TSpec> {
type Error = RPCError; type Error = RPCError;
fn encode(&mut self, item: RPCRequest<TSpec>, dst: &mut BytesMut) -> Result<(), Self::Error> { fn encode(&mut self, item: RPCRequest<TSpec>, dst: &mut BytesMut) -> Result<(), Self::Error> {
@@ -56,7 +56,7 @@ impl<TSpec: EthSpec> Encoder<RPCRequest<TSpec>> for OutboundCodec<TSpec, RPCRequ
} }
} }
impl<T: EthSpec> Decoder for OutboundCodec<T, RPCRequest<T>> { impl<T: EthSpec> Decoder for OutboundCodec<T> {
type Item = RPCErrorResponse<T>; type Item = RPCErrorResponse<T>;
type Error = RPCError; type Error = RPCError;

View File

@@ -19,7 +19,7 @@ pub struct SSZInboundCodec<TSpec: EthSpec> {
phantom: PhantomData<TSpec>, phantom: PhantomData<TSpec>,
} }
impl<T: EthSpec> SSZInboundCodec<T> { impl<TSpec: EthSpec> SSZInboundCodec<TSpec> {
pub fn new(protocol: ProtocolId, max_packet_size: usize) -> Self { pub fn new(protocol: ProtocolId, max_packet_size: usize) -> Self {
let mut uvi_codec = UviBytes::default(); let mut uvi_codec = UviBytes::default();
uvi_codec.set_max_len(max_packet_size); uvi_codec.set_max_len(max_packet_size);
@@ -39,7 +39,11 @@ impl<T: EthSpec> SSZInboundCodec<T> {
impl<TSpec: EthSpec> Encoder<RPCErrorResponse<TSpec>> for SSZInboundCodec<TSpec> { impl<TSpec: EthSpec> Encoder<RPCErrorResponse<TSpec>> for SSZInboundCodec<TSpec> {
type Error = RPCError; type Error = RPCError;
fn encode(&mut self, item: RPCErrorResponse<TSpec>, dst: &mut BytesMut) -> Result<(), Self::Error> { fn encode(
&mut self,
item: RPCErrorResponse<TSpec>,
dst: &mut BytesMut,
) -> Result<(), Self::Error> {
let bytes = match item { let bytes = match item {
RPCErrorResponse::Success(resp) => match resp { RPCErrorResponse::Success(resp) => match resp {
RPCResponse::Status(res) => res.as_ssz_bytes(), RPCResponse::Status(res) => res.as_ssz_bytes(),

View File

@@ -256,7 +256,7 @@ where
fn inject_fully_negotiated_inbound( fn inject_fully_negotiated_inbound(
&mut self, &mut self,
substream: <RPCProtocol<TSpec> as InboundUpgrade<NegotiatedSubstream>>::Output, substream: <Self::InboundProtocol as InboundUpgrade<NegotiatedSubstream>>::Output,
) { ) {
// update the keep alive timeout if there are no more remaining outbound streams // update the keep alive timeout if there are no more remaining outbound streams
if let KeepAlive::Until(_) = self.keep_alive { if let KeepAlive::Until(_) = self.keep_alive {
@@ -288,7 +288,7 @@ where
fn inject_fully_negotiated_outbound( fn inject_fully_negotiated_outbound(
&mut self, &mut self,
out: <RPCRequest<TSpec> as OutboundUpgrade<NegotiatedSubstream>>::Output, out: <Self::OutboundProtocol as OutboundUpgrade<NegotiatedSubstream>>::Output,
rpc_event: Self::OutboundOpenInfo, rpc_event: Self::OutboundOpenInfo,
) { ) {
self.dial_negotiated -= 1; self.dial_negotiated -= 1;
@@ -415,7 +415,7 @@ where
&mut self, &mut self,
request: Self::OutboundOpenInfo, request: Self::OutboundOpenInfo,
error: ProtocolsHandlerUpgrErr< error: ProtocolsHandlerUpgrErr<
<Self::OutboundProtocol as OutboundUpgrade<libp2p::swarm::NegotiatedSubstream>>::Error, <Self::OutboundProtocol as OutboundUpgrade<NegotiatedSubstream>>::Error,
>, >,
) { ) {
if let ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(RPCError::IoError(_))) = error { if let ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(RPCError::IoError(_))) = error {

View File

@@ -11,13 +11,14 @@ use crate::rpc::{
methods::ResponseTermination, methods::ResponseTermination,
}; };
use futures::future::*; use futures::future::*;
use futures::prelude::*;
use futures::prelude::{AsyncRead, AsyncWrite};
use libp2p::core::{upgrade, InboundUpgrade, OutboundUpgrade, ProtocolName, UpgradeInfo}; use libp2p::core::{upgrade, InboundUpgrade, OutboundUpgrade, ProtocolName, UpgradeInfo};
use std::io; use std::io;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::Timeout;
use tokio_io_timeout::TimeoutStream; use tokio_io_timeout::TimeoutStream;
use tokio_util::codec::Framed; use tokio_util::codec::Framed;
use types::EthSpec; use types::EthSpec;
@@ -170,13 +171,14 @@ impl ProtocolName for ProtocolId {
pub type InboundOutput<TSocket, TSpec> = (RPCRequest<TSpec>, InboundFramed<TSocket, TSpec>); pub type InboundOutput<TSocket, TSpec> = (RPCRequest<TSpec>, InboundFramed<TSocket, TSpec>);
pub type InboundFramed<TSocket, TSpec> = pub type InboundFramed<TSocket, TSpec> =
Framed<Timeout<TSocket>, InboundCodec<TSpec, RPCErrorResponse<TSpec>>>; Framed<TimeoutStream<TokioNegotiatedStream<TSocket>>, InboundCodec<TSpec>>;
type FnAndThen<TSocket, TSpec> = fn( type FnAndThen<TSocket, TSpec> = fn(
(Option<RPCRequest<TSpec>>, InboundFramed<TSocket, TSpec>), (
Option<Result<RPCRequest<TSpec>, RPCError>>,
InboundFramed<TSocket, TSpec>,
),
) -> Ready<Result<InboundOutput<TSocket, TSpec>, RPCError>>; ) -> Ready<Result<InboundOutput<TSocket, TSpec>, RPCError>>;
// TODO: Error doesn't take a generic parameter in new tokio type FnMapErr = fn(tokio::time::Elapsed) -> RPCError;
// Need to check implications
type FnMapErr = fn(tokio::time::Error) -> RPCError;
impl<TSocket, TSpec> InboundUpgrade<TSocket> for RPCProtocol<TSpec> impl<TSocket, TSpec> InboundUpgrade<TSocket> for RPCProtocol<TSpec>
where where
@@ -189,6 +191,7 @@ where
fn upgrade_inbound(self, socket: TSocket, protocol: ProtocolId) -> Self::Future { fn upgrade_inbound(self, socket: TSocket, protocol: ProtocolId) -> Self::Future {
let protocol_name = protocol.message_name.clone(); let protocol_name = protocol.message_name.clone();
let socket = TokioNegotiatedStream(socket);
let codec = match protocol.encoding { let codec = match protocol.encoding {
Encoding::SSZSnappy => { Encoding::SSZSnappy => {
let ssz_snappy_codec = let ssz_snappy_codec =
@@ -206,27 +209,24 @@ where
let socket = Framed::new(timed_socket, codec); let socket = Framed::new(timed_socket, codec);
// MetaData requests should be empty, return the stream // MetaData requests should be empty, return the stream
match protocol_name { Box::pin(match protocol_name {
Protocol::MetaData => futures::future::Either::A(futures::future::ok(( Protocol::MetaData => {
RPCRequest::MetaData(PhantomData), future::Either::Left(future::ok((RPCRequest::MetaData(PhantomData), socket)))
socket, }
))),
_ => futures::future::Either::B( _ => future::Either::Right(
socket tokio::time::timeout(Duration::from_secs(REQUEST_TIMEOUT), socket.into_future())
.into_future() .map_err(RPCError::from as FnMapErr)
.timeout(Duration::from_secs(REQUEST_TIMEOUT))
.map_err(RPCError::from as FnMapErr<TSocket, TSpec>)
.and_then({ .and_then({
|(req, stream)| match req { |(req, stream)| match req {
Some(request) => futures::future::ok((request, stream)), Some(Ok(request)) => future::ok((request, stream)),
None => futures::future::err(RPCError::Custom( Some(Err(_)) | None => {
"Stream terminated early".into(), err(RPCError::Custom("Stream terminated early".into()))
)), }
} }
} as FnAndThen<TSocket, TSpec>), } as FnAndThen<TSocket, TSpec>),
), ),
} })
} }
} }
@@ -335,11 +335,12 @@ impl<TSpec: EthSpec> RPCRequest<TSpec> {
/* Outbound upgrades */ /* Outbound upgrades */
pub type OutboundFramed<TSocket, TSpec> = Framed<TSocket, OutboundCodec<TSpec, RPCRequest<TSpec>>>; pub type OutboundFramed<TSocket, TSpec> =
Framed<TokioNegotiatedStream<TSocket>, OutboundCodec<TSpec>>;
impl<TSocket, TSpec> OutboundUpgrade<TSocket> for RPCRequest<TSpec> impl<TSocket, TSpec> OutboundUpgrade<TSocket> for RPCRequest<TSpec>
where where
TSpec: EthSpec, TSpec: EthSpec + Send + 'static,
TSocket: AsyncRead + AsyncWrite + Unpin + Send + 'static, TSocket: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{ {
type Output = OutboundFramed<TSocket, TSpec>; type Output = OutboundFramed<TSocket, TSpec>;
@@ -347,6 +348,7 @@ where
type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>; type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
fn upgrade_outbound(self, socket: TSocket, protocol: Self::Info) -> Self::Future { fn upgrade_outbound(self, socket: TSocket, protocol: Self::Info) -> Self::Future {
let socket = TokioNegotiatedStream(socket);
let codec = match protocol.encoding { let codec = match protocol.encoding {
Encoding::SSZSnappy => { Encoding::SSZSnappy => {
let ssz_snappy_codec = let ssz_snappy_codec =
@@ -359,7 +361,9 @@ where
OutboundCodec::SSZ(ssz_codec) OutboundCodec::SSZ(ssz_codec)
} }
}; };
Box::pin(Framed::new(socket, codec).send(self))
let socket = Framed::new(socket, codec);
Box::pin(future::join(socket.send(self), future::ok(socket)).map(|(_, socket)| socket))
} }
} }
@@ -397,13 +401,9 @@ impl From<ssz::DecodeError> for RPCError {
RPCError::SSZDecodeError(err) RPCError::SSZDecodeError(err)
} }
} }
impl From<tokio::time::Error> for RPCError { impl From<tokio::time::Elapsed> for RPCError {
fn from(err: tokio::time::Error) -> Self { fn from(err: tokio::time::Elapsed) -> Self {
if err.is_elapsed() { RPCError::StreamTimeout
RPCError::StreamTimeout
} else {
RPCError::Custom("Stream timer failed".into())
}
} }
} }
@@ -468,3 +468,33 @@ impl<TSpec: EthSpec> std::fmt::Display for RPCRequest<TSpec> {
} }
} }
} }
/// Converts a futures AsyncRead + AsyncWrite object to a tokio::AsyncRead + tokio::AsyncWrite
/// object.
struct TokioNegotiatedStream<T: AsyncRead + AsyncWrite + Unpin>(T);
impl<T: AsyncRead + AsyncWrite + Unpin> tokio::io::AsyncRead for TokioNegotiatedStream<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<Result<usize, io::Error>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}
impl<T: AsyncRead + AsyncWrite + Unpin> tokio::io::AsyncWrite for TokioNegotiatedStream<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
Pin::new(&mut self.0).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
Pin::new(&mut self.0).poll_close(cx)
}
}