Finished first round of fighting RPC types

This commit is contained in:
Age Manning
2020-05-01 20:05:03 +10:00
parent 08838fca23
commit 9e6ae448a6
5 changed files with 93 additions and 61 deletions

View File

@@ -11,13 +11,14 @@ use crate::rpc::{
methods::ResponseTermination,
};
use futures::future::*;
use futures::prelude::*;
use futures::prelude::{AsyncRead, AsyncWrite};
use libp2p::core::{upgrade, InboundUpgrade, OutboundUpgrade, ProtocolName, UpgradeInfo};
use std::io;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::Timeout;
use tokio_io_timeout::TimeoutStream;
use tokio_util::codec::Framed;
use types::EthSpec;
@@ -170,13 +171,14 @@ impl ProtocolName for ProtocolId {
pub type InboundOutput<TSocket, TSpec> = (RPCRequest<TSpec>, InboundFramed<TSocket, TSpec>);
pub type InboundFramed<TSocket, TSpec> =
Framed<Timeout<TSocket>, InboundCodec<TSpec, RPCErrorResponse<TSpec>>>;
Framed<TimeoutStream<TokioNegotiatedStream<TSocket>>, InboundCodec<TSpec>>;
type FnAndThen<TSocket, TSpec> = fn(
(Option<RPCRequest<TSpec>>, InboundFramed<TSocket, TSpec>),
(
Option<Result<RPCRequest<TSpec>, RPCError>>,
InboundFramed<TSocket, TSpec>,
),
) -> Ready<Result<InboundOutput<TSocket, TSpec>, RPCError>>;
// TODO: Error doesn't take a generic parameter in new tokio
// Need to check implications
type FnMapErr = fn(tokio::time::Error) -> RPCError;
type FnMapErr = fn(tokio::time::Elapsed) -> RPCError;
impl<TSocket, TSpec> InboundUpgrade<TSocket> for RPCProtocol<TSpec>
where
@@ -189,6 +191,7 @@ where
fn upgrade_inbound(self, socket: TSocket, protocol: ProtocolId) -> Self::Future {
let protocol_name = protocol.message_name.clone();
let socket = TokioNegotiatedStream(socket);
let codec = match protocol.encoding {
Encoding::SSZSnappy => {
let ssz_snappy_codec =
@@ -206,27 +209,24 @@ where
let socket = Framed::new(timed_socket, codec);
// MetaData requests should be empty, return the stream
match protocol_name {
Protocol::MetaData => futures::future::Either::A(futures::future::ok((
RPCRequest::MetaData(PhantomData),
socket,
))),
Box::pin(match protocol_name {
Protocol::MetaData => {
future::Either::Left(future::ok((RPCRequest::MetaData(PhantomData), socket)))
}
_ => futures::future::Either::B(
socket
.into_future()
.timeout(Duration::from_secs(REQUEST_TIMEOUT))
.map_err(RPCError::from as FnMapErr<TSocket, TSpec>)
_ => future::Either::Right(
tokio::time::timeout(Duration::from_secs(REQUEST_TIMEOUT), socket.into_future())
.map_err(RPCError::from as FnMapErr)
.and_then({
|(req, stream)| match req {
Some(request) => futures::future::ok((request, stream)),
None => futures::future::err(RPCError::Custom(
"Stream terminated early".into(),
)),
Some(Ok(request)) => future::ok((request, stream)),
Some(Err(_)) | None => {
err(RPCError::Custom("Stream terminated early".into()))
}
}
} as FnAndThen<TSocket, TSpec>),
),
}
})
}
}
@@ -335,11 +335,12 @@ impl<TSpec: EthSpec> RPCRequest<TSpec> {
/* Outbound upgrades */
pub type OutboundFramed<TSocket, TSpec> = Framed<TSocket, OutboundCodec<TSpec, RPCRequest<TSpec>>>;
pub type OutboundFramed<TSocket, TSpec> =
Framed<TokioNegotiatedStream<TSocket>, OutboundCodec<TSpec>>;
impl<TSocket, TSpec> OutboundUpgrade<TSocket> for RPCRequest<TSpec>
where
TSpec: EthSpec,
TSpec: EthSpec + Send + 'static,
TSocket: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Output = OutboundFramed<TSocket, TSpec>;
@@ -347,6 +348,7 @@ where
type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
fn upgrade_outbound(self, socket: TSocket, protocol: Self::Info) -> Self::Future {
let socket = TokioNegotiatedStream(socket);
let codec = match protocol.encoding {
Encoding::SSZSnappy => {
let ssz_snappy_codec =
@@ -359,7 +361,9 @@ where
OutboundCodec::SSZ(ssz_codec)
}
};
Box::pin(Framed::new(socket, codec).send(self))
let socket = Framed::new(socket, codec);
Box::pin(future::join(socket.send(self), future::ok(socket)).map(|(_, socket)| socket))
}
}
@@ -397,13 +401,9 @@ impl From<ssz::DecodeError> for RPCError {
RPCError::SSZDecodeError(err)
}
}
impl From<tokio::time::Error> for RPCError {
fn from(err: tokio::time::Error) -> Self {
if err.is_elapsed() {
RPCError::StreamTimeout
} else {
RPCError::Custom("Stream timer failed".into())
}
impl From<tokio::time::Elapsed> for RPCError {
fn from(err: tokio::time::Elapsed) -> Self {
RPCError::StreamTimeout
}
}
@@ -468,3 +468,33 @@ impl<TSpec: EthSpec> std::fmt::Display for RPCRequest<TSpec> {
}
}
}
/// Converts a futures AsyncRead + AsyncWrite object to a tokio::AsyncRead + tokio::AsyncWrite
/// object.
struct TokioNegotiatedStream<T: AsyncRead + AsyncWrite + Unpin>(T);
impl<T: AsyncRead + AsyncWrite + Unpin> tokio::io::AsyncRead for TokioNegotiatedStream<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<Result<usize, io::Error>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}
impl<T: AsyncRead + AsyncWrite + Unpin> tokio::io::AsyncWrite for TokioNegotiatedStream<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
Pin::new(&mut self.0).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
Pin::new(&mut self.0).poll_close(cx)
}
}