mirror of
https://github.com/sigp/lighthouse.git
synced 2026-07-01 11:54:40 +00:00
Bump warp and begin axum migration (#9001)
- Bump `warp` to 0.4. This unifies `warp` and `axum` onto the same `http`, `hyper`, `h2`, `rustls`, etc versions. - Create `axum_utils` which contain common functions and types - Begins migration of all HTTP API servers from warp to axum Co-Authored-By: Mac L <mjladson@pm.me>
This commit is contained in:
17
common/axum_utils/Cargo.toml
Normal file
17
common/axum_utils/Cargo.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "axum_utils"
|
||||
version = "0.1.0"
|
||||
authors = ["Sigma Prime <contact@sigmaprime.io>"]
|
||||
edition = { workspace = true }
|
||||
|
||||
[dependencies]
|
||||
axum = { workspace = true }
|
||||
axum-server = { version = "0.7", features = ["tls-rustls-no-provider"] }
|
||||
http = "1"
|
||||
serde = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tower-http = { version = "0.6", features = ["cors", "set-header"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tower = { workspace = true }
|
||||
515
common/axum_utils/src/cors.rs
Normal file
515
common/axum_utils/src/cors.rs
Normal file
@@ -0,0 +1,515 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
use tower_http::cors::{AllowOrigin, CorsLayer};
|
||||
|
||||
/// Errors that can occur during CORS configuration
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum CorsError {
|
||||
#[error("Invalid CORS origin '{origin}': {reason}")]
|
||||
InvalidOrigin { origin: String, reason: String },
|
||||
|
||||
#[error("CORS origins string cannot be empty")]
|
||||
EmptyOriginsString,
|
||||
}
|
||||
|
||||
/// A validated CORS origin
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum Origin {
|
||||
/// Allow any origin (*).
|
||||
Any,
|
||||
/// A specific origin URL.
|
||||
Exact(http::HeaderValue),
|
||||
}
|
||||
|
||||
impl FromStr for Origin {
|
||||
type Err = CorsError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let trimmed = s.trim();
|
||||
|
||||
if trimmed == "*" {
|
||||
return Ok(Origin::Any);
|
||||
}
|
||||
|
||||
validate_origin(trimmed)?;
|
||||
|
||||
let header_value =
|
||||
http::HeaderValue::from_str(trimmed).map_err(|e| CorsError::InvalidOrigin {
|
||||
origin: trimmed.to_string(),
|
||||
reason: format!("invalid header value: {}", e),
|
||||
})?;
|
||||
|
||||
Ok(Origin::Exact(header_value))
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate a CORS origin string
|
||||
fn validate_origin(s: &str) -> Result<(), CorsError> {
|
||||
let make_error = |reason: &str| CorsError::InvalidOrigin {
|
||||
origin: s.to_string(),
|
||||
reason: reason.to_string(),
|
||||
};
|
||||
|
||||
if !s.contains("://") {
|
||||
return Err(make_error("missing scheme (http:// or https://)"));
|
||||
}
|
||||
|
||||
let (scheme, rest) = s
|
||||
.split_once("://")
|
||||
.ok_or_else(|| make_error("failed to parse scheme"))?;
|
||||
|
||||
if !matches!(scheme, "http" | "https") {
|
||||
return Err(make_error(&format!(
|
||||
"invalid scheme '{}' (only http and https are allowed)",
|
||||
scheme
|
||||
)));
|
||||
}
|
||||
|
||||
if rest.is_empty() {
|
||||
return Err(make_error("missing host"));
|
||||
}
|
||||
|
||||
let host = rest
|
||||
.split(':')
|
||||
.next()
|
||||
.ok_or_else(|| make_error("failed to extract host"))?;
|
||||
|
||||
if host.is_empty() {
|
||||
return Err(make_error("empty host"));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Configuration for CORS.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CorsConfig {
|
||||
/// Comma-separated list of allowed origins, or "*" for any origin.
|
||||
pub allowed_origins: String,
|
||||
}
|
||||
|
||||
impl CorsConfig {
|
||||
/// Create a new CORS config from an origins string.
|
||||
pub fn new(allowed_origins: impl Into<String>) -> Self {
|
||||
Self {
|
||||
allowed_origins: allowed_origins.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse and validate the origins string, returning a list of validated origins.
|
||||
pub fn parse_origins(&self) -> Result<Vec<Origin>, CorsError> {
|
||||
let trimmed = self.allowed_origins.trim();
|
||||
|
||||
if trimmed.is_empty() {
|
||||
return Err(CorsError::EmptyOriginsString);
|
||||
}
|
||||
|
||||
let origins: Vec<Origin> = trimmed
|
||||
.split(',')
|
||||
.map(|s| s.trim().parse::<Origin>())
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
Ok(origins)
|
||||
}
|
||||
|
||||
/// Convert this config into a Tower-compatible CorsLayer.
|
||||
pub fn into_layer(self) -> Result<CorsLayer, CorsError> {
|
||||
let origins = self.parse_origins()?;
|
||||
|
||||
// If any origin is the wildcard `*`, allow all origins, even when the
|
||||
// list also contains explicit origins.
|
||||
if origins.iter().any(|o| matches!(o, Origin::Any)) {
|
||||
return Ok(CorsLayer::new().allow_origin(tower_http::cors::Any));
|
||||
}
|
||||
|
||||
let header_values: Vec<http::HeaderValue> = origins
|
||||
.into_iter()
|
||||
.filter_map(|o| match o {
|
||||
Origin::Exact(hv) => Some(hv),
|
||||
Origin::Any => None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(CorsLayer::new().allow_origin(AllowOrigin::list(header_values)))
|
||||
}
|
||||
}
|
||||
|
||||
fn format_default_origin(ip: IpAddr, port: u16) -> String {
|
||||
match ip {
|
||||
IpAddr::V4(addr) => format!("http://{}:{}", addr, port),
|
||||
IpAddr::V6(addr) => format!("http://[{}]:{}", addr, port),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a CORS layer from an optional origins string and a default fallback
|
||||
///
|
||||
/// This is the main function for CLI usage:
|
||||
/// - If `allow_origin` is `Some`, parse it as comma-separated origins
|
||||
/// - If `allow_origin` is `None`, use the default IP and port
|
||||
///
|
||||
/// Callers can chain additional configuration like `.allow_methods()` and
|
||||
/// `.allow_headers()` on the returned `CorsLayer`.
|
||||
pub fn build_cors_layer(
|
||||
allow_origin: Option<&str>,
|
||||
default_ip: IpAddr,
|
||||
default_port: u16,
|
||||
) -> Result<CorsLayer, CorsError> {
|
||||
let origins = match allow_origin {
|
||||
Some(s) if !s.trim().is_empty() => s.to_string(),
|
||||
_ => format_default_origin(default_ip, default_port),
|
||||
};
|
||||
|
||||
CorsConfig {
|
||||
allowed_origins: origins,
|
||||
}
|
||||
.into_layer()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use axum::{Router, routing::get};
|
||||
use http::{Request, StatusCode};
|
||||
use tower::ServiceExt;
|
||||
|
||||
fn parse_origin(s: &str) -> Result<(), CorsError> {
|
||||
s.parse::<Origin>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid_origins() {
|
||||
parse_origin("*").unwrap();
|
||||
parse_origin("http://127.0.0.1").unwrap();
|
||||
parse_origin("http://localhost").unwrap();
|
||||
parse_origin("http://127.0.0.1:8000").unwrap();
|
||||
parse_origin("http://localhost:8000").unwrap();
|
||||
parse_origin("http://[::1]").unwrap();
|
||||
parse_origin("http://[::1]:8000").unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_origins() {
|
||||
parse_origin(".*").unwrap_err();
|
||||
parse_origin("127.0.0.1").unwrap_err();
|
||||
parse_origin("localhost").unwrap_err();
|
||||
parse_origin("[::1]").unwrap_err();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn origin_variants() {
|
||||
assert_eq!("*".parse::<Origin>().unwrap(), Origin::Any);
|
||||
|
||||
match "http://localhost:3000".parse::<Origin>().unwrap() {
|
||||
Origin::Exact(_) => {}
|
||||
Origin::Any => panic!("Expected Exact variant, got Any"),
|
||||
}
|
||||
|
||||
match "https://example.com".parse::<Origin>().unwrap() {
|
||||
Origin::Exact(_) => {}
|
||||
Origin::Any => panic!("Expected Exact variant, got Any"),
|
||||
}
|
||||
}
|
||||
|
||||
struct HttpConfig {
|
||||
allow_origin: Option<String>,
|
||||
listen_addr: IpAddr,
|
||||
listen_port: u16,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_config() {
|
||||
let config = HttpConfig {
|
||||
allow_origin: None,
|
||||
listen_addr: "127.0.0.1".parse().unwrap(),
|
||||
listen_port: 5052,
|
||||
};
|
||||
|
||||
let layer = build_cors_layer(
|
||||
config.allow_origin.as_deref(),
|
||||
config.listen_addr,
|
||||
config.listen_port,
|
||||
);
|
||||
assert!(layer.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wildcard_origin() {
|
||||
// lighthouse bn --http-allow-origin "*"
|
||||
let config = HttpConfig {
|
||||
allow_origin: Some("*".to_string()),
|
||||
listen_addr: "127.0.0.1".parse().unwrap(),
|
||||
listen_port: 5052,
|
||||
};
|
||||
|
||||
let layer = build_cors_layer(
|
||||
config.allow_origin.as_deref(),
|
||||
config.listen_addr,
|
||||
config.listen_port,
|
||||
);
|
||||
assert!(layer.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn single_origin() {
|
||||
// lighthouse bn --http-allow-origin "http://localhost:3000"
|
||||
let config = HttpConfig {
|
||||
allow_origin: Some("http://localhost:3000".to_string()),
|
||||
listen_addr: "127.0.0.1".parse().unwrap(),
|
||||
listen_port: 5052,
|
||||
};
|
||||
|
||||
let layer = build_cors_layer(
|
||||
config.allow_origin.as_deref(),
|
||||
config.listen_addr,
|
||||
config.listen_port,
|
||||
);
|
||||
assert!(layer.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiple_origins() {
|
||||
// lighthouse bn --http-allow-origin "http://localhost:3000,https://example.com"
|
||||
let config = HttpConfig {
|
||||
allow_origin: Some("http://localhost:3000,https://example.com".to_string()),
|
||||
listen_addr: "127.0.0.1".parse().unwrap(),
|
||||
listen_port: 5052,
|
||||
};
|
||||
|
||||
let layer = build_cors_layer(
|
||||
config.allow_origin.as_deref(),
|
||||
config.listen_addr,
|
||||
config.listen_port,
|
||||
);
|
||||
assert!(layer.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ipv6_listen_address() {
|
||||
let config = HttpConfig {
|
||||
allow_origin: None,
|
||||
listen_addr: "::1".parse().unwrap(),
|
||||
listen_port: 5052,
|
||||
};
|
||||
|
||||
let layer = build_cors_layer(
|
||||
config.allow_origin.as_deref(),
|
||||
config.listen_addr,
|
||||
config.listen_port,
|
||||
);
|
||||
assert!(layer.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_origin_missing_scheme() {
|
||||
let config = HttpConfig {
|
||||
allow_origin: Some("localhost:3000".to_string()),
|
||||
listen_addr: "127.0.0.1".parse().unwrap(),
|
||||
listen_port: 5052,
|
||||
};
|
||||
|
||||
let layer = build_cors_layer(
|
||||
config.allow_origin.as_deref(),
|
||||
config.listen_addr,
|
||||
config.listen_port,
|
||||
);
|
||||
assert!(layer.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_origin_in_list() {
|
||||
let config = HttpConfig {
|
||||
allow_origin: Some("http://localhost:3000,invalid,https://example.com".to_string()),
|
||||
listen_addr: "127.0.0.1".parse().unwrap(),
|
||||
listen_port: 5052,
|
||||
};
|
||||
|
||||
let layer = build_cors_layer(
|
||||
config.allow_origin.as_deref(),
|
||||
config.listen_addr,
|
||||
config.listen_port,
|
||||
);
|
||||
assert!(layer.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_cors_layer() {
|
||||
let config = HttpConfig {
|
||||
allow_origin: Some("http://localhost:3000".to_string()),
|
||||
listen_addr: "127.0.0.1".parse().unwrap(),
|
||||
listen_port: 5052,
|
||||
};
|
||||
|
||||
let cors_layer = build_cors_layer(
|
||||
config.allow_origin.as_deref(),
|
||||
config.listen_addr,
|
||||
config.listen_port,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
async fn handler() -> &'static str {
|
||||
"test"
|
||||
}
|
||||
|
||||
let app = Router::new().route("/", get(handler)).layer(cors_layer);
|
||||
|
||||
// Preflight request
|
||||
let request = Request::builder()
|
||||
.method("OPTIONS")
|
||||
.uri("/")
|
||||
.header("Origin", "http://localhost:3000")
|
||||
.header("Access-Control-Request-Method", "GET")
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = app.oneshot(request).await.unwrap();
|
||||
|
||||
// Verify CORS header matches origin
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
assert_eq!(
|
||||
response
|
||||
.headers()
|
||||
.get("access-control-allow-origin")
|
||||
.unwrap(),
|
||||
"http://localhost:3000"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn wildcard_overrides_exact() {
|
||||
// Mix specific origin with wildcard
|
||||
let config = HttpConfig {
|
||||
allow_origin: Some("http://localhost:3000,*".to_string()),
|
||||
listen_addr: "127.0.0.1".parse().unwrap(),
|
||||
listen_port: 5052,
|
||||
};
|
||||
|
||||
let cors_layer = build_cors_layer(
|
||||
config.allow_origin.as_deref(),
|
||||
config.listen_addr,
|
||||
config.listen_port,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
async fn handler() -> &'static str {
|
||||
"test"
|
||||
}
|
||||
|
||||
let app = Router::new().route("/", get(handler)).layer(cors_layer);
|
||||
|
||||
let request = Request::builder()
|
||||
.method("OPTIONS")
|
||||
.uri("/")
|
||||
.header("Origin", "https://completely-different-origin.com")
|
||||
.header("Access-Control-Request-Method", "GET")
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = app.oneshot(request).await.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
assert_eq!(
|
||||
response
|
||||
.headers()
|
||||
.get("access-control-allow-origin")
|
||||
.unwrap(),
|
||||
"*"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_allowed_methods() {
|
||||
use axum::{Router, routing::get};
|
||||
use http::{Request, StatusCode};
|
||||
use tower::ServiceExt;
|
||||
|
||||
let config = HttpConfig {
|
||||
allow_origin: Some("http://localhost:3000".to_string()),
|
||||
listen_addr: "127.0.0.1".parse().unwrap(),
|
||||
listen_port: 5052,
|
||||
};
|
||||
|
||||
let cors_layer = build_cors_layer(
|
||||
config.allow_origin.as_deref(),
|
||||
config.listen_addr,
|
||||
config.listen_port,
|
||||
)
|
||||
.unwrap()
|
||||
.allow_methods([http::Method::GET, http::Method::POST])
|
||||
.allow_headers([http::header::CONTENT_TYPE]);
|
||||
|
||||
async fn handler() -> &'static str {
|
||||
"test"
|
||||
}
|
||||
|
||||
let app = Router::new().route("/", get(handler)).layer(cors_layer);
|
||||
|
||||
let request = Request::builder()
|
||||
.method("OPTIONS")
|
||||
.uri("/")
|
||||
.header("Origin", "http://localhost:3000")
|
||||
.header("Access-Control-Request-Method", "GET")
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = app.clone().oneshot(request).await.unwrap();
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let allowed_methods = response
|
||||
.headers()
|
||||
.get("access-control-allow-methods")
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap();
|
||||
assert!(allowed_methods.contains("GET"));
|
||||
assert!(allowed_methods.contains("POST"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn verify_allowed_headers() {
|
||||
let config = HttpConfig {
|
||||
allow_origin: Some("http://localhost:3000".to_string()),
|
||||
listen_addr: "127.0.0.1".parse().unwrap(),
|
||||
listen_port: 5052,
|
||||
};
|
||||
|
||||
let cors_layer = build_cors_layer(
|
||||
config.allow_origin.as_deref(),
|
||||
config.listen_addr,
|
||||
config.listen_port,
|
||||
)
|
||||
.unwrap()
|
||||
.allow_methods([http::Method::GET])
|
||||
.allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION]);
|
||||
|
||||
async fn handler() -> &'static str {
|
||||
"test"
|
||||
}
|
||||
|
||||
let app = Router::new().route("/", get(handler)).layer(cors_layer);
|
||||
|
||||
// Preflight request with Content-Type header
|
||||
let request = Request::builder()
|
||||
.method("OPTIONS")
|
||||
.uri("/")
|
||||
.header("Origin", "http://localhost:3000")
|
||||
.header("Access-Control-Request-Method", "GET")
|
||||
.header("Access-Control-Request-Headers", "content-type")
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = app.oneshot(request).await.unwrap();
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let allowed_headers = response
|
||||
.headers()
|
||||
.get("access-control-allow-headers")
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap()
|
||||
.to_lowercase();
|
||||
assert!(allowed_headers.contains("content-type"));
|
||||
assert!(allowed_headers.contains("authorization"));
|
||||
}
|
||||
}
|
||||
6
common/axum_utils/src/lib.rs
Normal file
6
common/axum_utils/src/lib.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
pub mod cors;
|
||||
pub mod middleware;
|
||||
pub mod server;
|
||||
pub mod tls;
|
||||
|
||||
pub use server::Server;
|
||||
9
common/axum_utils/src/middleware.rs
Normal file
9
common/axum_utils/src/middleware.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
use axum::http::header;
|
||||
use tower_http::set_header::SetResponseHeaderLayer;
|
||||
|
||||
/// Returns a layer that adds the "Server" header to all responses.
|
||||
pub fn add_server_header(
|
||||
value: header::HeaderValue,
|
||||
) -> SetResponseHeaderLayer<header::HeaderValue> {
|
||||
SetResponseHeaderLayer::overriding(header::SERVER, value)
|
||||
}
|
||||
42
common/axum_utils/src/server/builder.rs
Normal file
42
common/axum_utils/src/server/builder.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
use crate::{
|
||||
server::{Server, error::BuilderError},
|
||||
tls::TlsConfig,
|
||||
};
|
||||
use axum::Router;
|
||||
use axum_server::tls_rustls::RustlsConfig;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
pub struct ServerBuilder {
|
||||
router: Router,
|
||||
address: SocketAddr,
|
||||
tls_config: Option<TlsConfig>,
|
||||
}
|
||||
|
||||
impl ServerBuilder {
|
||||
pub fn new(router: Router, address: SocketAddr) -> Self {
|
||||
Self {
|
||||
router,
|
||||
address,
|
||||
tls_config: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_tls(mut self, config: TlsConfig) -> Self {
|
||||
self.tls_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn build(self) -> Result<Server, BuilderError> {
|
||||
let rustls_config = if let Some(tls) = self.tls_config {
|
||||
Some(RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Server {
|
||||
router: self.router,
|
||||
address: self.address,
|
||||
rustls_config,
|
||||
})
|
||||
}
|
||||
}
|
||||
11
common/axum_utils/src/server/error.rs
Normal file
11
common/axum_utils/src/server/error.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum BuilderError {
|
||||
#[error("TLS configuration failed: {0}")]
|
||||
TlsConfigFailed(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ServerError {
|
||||
#[error("Server failed: {0}")]
|
||||
ServerFailed(#[from] std::io::Error),
|
||||
}
|
||||
116
common/axum_utils/src/server/mod.rs
Normal file
116
common/axum_utils/src/server/mod.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
mod builder;
|
||||
mod error;
|
||||
|
||||
pub use builder::ServerBuilder;
|
||||
pub use error::{BuilderError, ServerError};
|
||||
|
||||
use axum::Router;
|
||||
use axum_server::tls_rustls::RustlsConfig;
|
||||
use std::net::SocketAddr;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Default timeout for graceful shutdown. After this duration, the server will
|
||||
/// stop waiting for existing connections and shut down immediately.
|
||||
const DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
pub struct Server {
|
||||
pub(crate) router: Router,
|
||||
pub(crate) address: SocketAddr,
|
||||
pub(crate) rustls_config: Option<RustlsConfig>,
|
||||
}
|
||||
|
||||
impl Server {
|
||||
/// Initialize a new server builder.
|
||||
pub fn builder(router: Router, address: SocketAddr) -> ServerBuilder {
|
||||
ServerBuilder::new(router, address)
|
||||
}
|
||||
|
||||
/// Get information about the server configuration.
|
||||
///
|
||||
/// Note that the address is only the configured address, not the actual address
|
||||
/// the server is listening on (such as when using port 0).
|
||||
pub fn info(&self) -> ServerInfo {
|
||||
ServerInfo {
|
||||
address: self.address,
|
||||
protocol: if self.rustls_config.is_some() {
|
||||
Protocol::Https
|
||||
} else {
|
||||
Protocol::Http
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Serve the application until the shutdown signal is received.
|
||||
/// Returns the actual address the server is listening on.
|
||||
pub async fn serve_with_shutdown<F>(
|
||||
self,
|
||||
shutdown_signal: F,
|
||||
) -> Result<(SocketAddr, impl Future<Output = Result<(), ServerError>>), ServerError>
|
||||
where
|
||||
F: std::future::Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
let tokio_listener = tokio::net::TcpListener::bind(self.address).await?;
|
||||
|
||||
let actual_addr = tokio_listener.local_addr()?;
|
||||
|
||||
let listener = tokio_listener.into_std()?;
|
||||
|
||||
let handle = axum_server::Handle::new();
|
||||
|
||||
let server_future = async move {
|
||||
// Spawn a task that triggers graceful shutdown when the signal fires.
|
||||
// If the server exits before the signal, this task will linger until the
|
||||
// signal resolves, which is harmless.
|
||||
let shutdown_handle = tokio::spawn({
|
||||
let handle = handle.clone();
|
||||
async move {
|
||||
shutdown_signal.await;
|
||||
handle.graceful_shutdown(Some(DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT));
|
||||
}
|
||||
});
|
||||
|
||||
let result = match self.rustls_config {
|
||||
Some(config) => {
|
||||
axum_server::from_tcp_rustls(listener, config)
|
||||
.handle(handle)
|
||||
.serve(self.router.into_make_service())
|
||||
.await
|
||||
}
|
||||
None => {
|
||||
axum_server::from_tcp(listener)
|
||||
.handle(handle)
|
||||
.serve(self.router.into_make_service())
|
||||
.await
|
||||
}
|
||||
};
|
||||
|
||||
// Abort the shutdown listener if it's still running (server exited first).
|
||||
shutdown_handle.abort();
|
||||
|
||||
result.map_err(ServerError::ServerFailed)
|
||||
};
|
||||
|
||||
Ok((actual_addr, server_future))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ServerInfo {
|
||||
pub address: SocketAddr,
|
||||
pub protocol: Protocol,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Protocol {
|
||||
Http,
|
||||
Https,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Protocol {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Protocol::Http => write!(f, "http"),
|
||||
Protocol::Https => write!(f, "https"),
|
||||
}
|
||||
}
|
||||
}
|
||||
9
common/axum_utils/src/tls.rs
Normal file
9
common/axum_utils/src/tls.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Configuration used when serving the HTTP server over TLS.
|
||||
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TlsConfig {
|
||||
pub cert: PathBuf,
|
||||
pub key: PathBuf,
|
||||
}
|
||||
@@ -405,7 +405,7 @@ impl TaskExecutor {
|
||||
|
||||
/// Returns a future that completes when `async-channel::Sender` is dropped or () is sent,
|
||||
/// which translates to the exit signal being triggered.
|
||||
pub fn exit(&self) -> impl Future<Output = ()> + 'static {
|
||||
pub fn exit(&self) -> impl Future<Output = ()> + use<> + 'static {
|
||||
let exit = self.exit.clone();
|
||||
async move {
|
||||
let _ = exit.recv().await;
|
||||
|
||||
@@ -8,7 +8,7 @@ edition = { workspace = true }
|
||||
[dependencies]
|
||||
bytes = { workspace = true }
|
||||
eth2 = { workspace = true }
|
||||
headers = "0.3.2"
|
||||
headers = "0.4"
|
||||
reqwest = { workspace = true }
|
||||
safe_arith = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use reqwest::StatusCode;
|
||||
use warp::Rejection;
|
||||
|
||||
/// Convert from a "new" `http::StatusCode` to a `warp` compatible one.
|
||||
pub fn convert(code: StatusCode) -> Result<warp::http::StatusCode, Rejection> {
|
||||
code.as_u16().try_into().map_err(|e| {
|
||||
crate::reject::custom_server_error(format!("bad status code {code:?} - {e:?}"))
|
||||
})
|
||||
/// Convert a `reqwest::StatusCode` to a `warp::http::StatusCode`.
|
||||
///
|
||||
/// In warp 0.4, both `reqwest` (0.12) and `warp` use the `http` v1 crate,
|
||||
/// so `reqwest::StatusCode` and `warp::http::StatusCode` are the same type.
|
||||
/// This function is retained for API compatibility but is now a no-op.
|
||||
pub fn convert(code: StatusCode) -> Result<warp::http::StatusCode, warp::Rejection> {
|
||||
Ok(code)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user