Bump warp and begin axum migration (#9001)

- Bump `warp` to 0.4. This unifies `warp` and `axum` onto the same `http`, `hyper`, `h2`, `rustls`, etc versions.
- Create `axum_utils` which contain common functions and types
- Begins migration of all HTTP API servers from warp to axum


Co-Authored-By: Mac L <mjladson@pm.me>
This commit is contained in:
Mac L
2026-06-25 18:19:29 +04:00
committed by GitHub
parent a4c4cccf04
commit 8c2a909061
41 changed files with 1333 additions and 543 deletions

View File

@@ -0,0 +1,17 @@
[package]
name = "axum_utils"
version = "0.1.0"
authors = ["Sigma Prime <contact@sigmaprime.io>"]
edition = { workspace = true }
[dependencies]
axum = { workspace = true }
axum-server = { version = "0.7", features = ["tls-rustls-no-provider"] }
http = "1"
serde = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tower-http = { version = "0.6", features = ["cors", "set-header"] }
[dev-dependencies]
tower = { workspace = true }

View File

@@ -0,0 +1,515 @@
use serde::{Deserialize, Serialize};
use std::net::IpAddr;
use std::str::FromStr;
use tower_http::cors::{AllowOrigin, CorsLayer};
/// Errors that can occur during CORS configuration
#[derive(Debug, thiserror::Error)]
pub enum CorsError {
#[error("Invalid CORS origin '{origin}': {reason}")]
InvalidOrigin { origin: String, reason: String },
#[error("CORS origins string cannot be empty")]
EmptyOriginsString,
}
/// A validated CORS origin
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Origin {
/// Allow any origin (*).
Any,
/// A specific origin URL.
Exact(http::HeaderValue),
}
impl FromStr for Origin {
type Err = CorsError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let trimmed = s.trim();
if trimmed == "*" {
return Ok(Origin::Any);
}
validate_origin(trimmed)?;
let header_value =
http::HeaderValue::from_str(trimmed).map_err(|e| CorsError::InvalidOrigin {
origin: trimmed.to_string(),
reason: format!("invalid header value: {}", e),
})?;
Ok(Origin::Exact(header_value))
}
}
/// Validate a CORS origin string
fn validate_origin(s: &str) -> Result<(), CorsError> {
let make_error = |reason: &str| CorsError::InvalidOrigin {
origin: s.to_string(),
reason: reason.to_string(),
};
if !s.contains("://") {
return Err(make_error("missing scheme (http:// or https://)"));
}
let (scheme, rest) = s
.split_once("://")
.ok_or_else(|| make_error("failed to parse scheme"))?;
if !matches!(scheme, "http" | "https") {
return Err(make_error(&format!(
"invalid scheme '{}' (only http and https are allowed)",
scheme
)));
}
if rest.is_empty() {
return Err(make_error("missing host"));
}
let host = rest
.split(':')
.next()
.ok_or_else(|| make_error("failed to extract host"))?;
if host.is_empty() {
return Err(make_error("empty host"));
}
Ok(())
}
/// Configuration for CORS.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorsConfig {
/// Comma-separated list of allowed origins, or "*" for any origin.
pub allowed_origins: String,
}
impl CorsConfig {
/// Create a new CORS config from an origins string.
pub fn new(allowed_origins: impl Into<String>) -> Self {
Self {
allowed_origins: allowed_origins.into(),
}
}
/// Parse and validate the origins string, returning a list of validated origins.
pub fn parse_origins(&self) -> Result<Vec<Origin>, CorsError> {
let trimmed = self.allowed_origins.trim();
if trimmed.is_empty() {
return Err(CorsError::EmptyOriginsString);
}
let origins: Vec<Origin> = trimmed
.split(',')
.map(|s| s.trim().parse::<Origin>())
.collect::<Result<Vec<_>, _>>()?;
Ok(origins)
}
/// Convert this config into a Tower-compatible CorsLayer.
pub fn into_layer(self) -> Result<CorsLayer, CorsError> {
let origins = self.parse_origins()?;
// If any origin is the wildcard `*`, allow all origins, even when the
// list also contains explicit origins.
if origins.iter().any(|o| matches!(o, Origin::Any)) {
return Ok(CorsLayer::new().allow_origin(tower_http::cors::Any));
}
let header_values: Vec<http::HeaderValue> = origins
.into_iter()
.filter_map(|o| match o {
Origin::Exact(hv) => Some(hv),
Origin::Any => None,
})
.collect();
Ok(CorsLayer::new().allow_origin(AllowOrigin::list(header_values)))
}
}
fn format_default_origin(ip: IpAddr, port: u16) -> String {
match ip {
IpAddr::V4(addr) => format!("http://{}:{}", addr, port),
IpAddr::V6(addr) => format!("http://[{}]:{}", addr, port),
}
}
/// Build a CORS layer from an optional origins string and a default fallback
///
/// This is the main function for CLI usage:
/// - If `allow_origin` is `Some`, parse it as comma-separated origins
/// - If `allow_origin` is `None`, use the default IP and port
///
/// Callers can chain additional configuration like `.allow_methods()` and
/// `.allow_headers()` on the returned `CorsLayer`.
pub fn build_cors_layer(
allow_origin: Option<&str>,
default_ip: IpAddr,
default_port: u16,
) -> Result<CorsLayer, CorsError> {
let origins = match allow_origin {
Some(s) if !s.trim().is_empty() => s.to_string(),
_ => format_default_origin(default_ip, default_port),
};
CorsConfig {
allowed_origins: origins,
}
.into_layer()
}
#[cfg(test)]
mod tests {
use super::*;
use axum::{Router, routing::get};
use http::{Request, StatusCode};
use tower::ServiceExt;
fn parse_origin(s: &str) -> Result<(), CorsError> {
s.parse::<Origin>()?;
Ok(())
}
#[test]
fn valid_origins() {
parse_origin("*").unwrap();
parse_origin("http://127.0.0.1").unwrap();
parse_origin("http://localhost").unwrap();
parse_origin("http://127.0.0.1:8000").unwrap();
parse_origin("http://localhost:8000").unwrap();
parse_origin("http://[::1]").unwrap();
parse_origin("http://[::1]:8000").unwrap();
}
#[test]
fn invalid_origins() {
parse_origin(".*").unwrap_err();
parse_origin("127.0.0.1").unwrap_err();
parse_origin("localhost").unwrap_err();
parse_origin("[::1]").unwrap_err();
}
#[test]
fn origin_variants() {
assert_eq!("*".parse::<Origin>().unwrap(), Origin::Any);
match "http://localhost:3000".parse::<Origin>().unwrap() {
Origin::Exact(_) => {}
Origin::Any => panic!("Expected Exact variant, got Any"),
}
match "https://example.com".parse::<Origin>().unwrap() {
Origin::Exact(_) => {}
Origin::Any => panic!("Expected Exact variant, got Any"),
}
}
struct HttpConfig {
allow_origin: Option<String>,
listen_addr: IpAddr,
listen_port: u16,
}
#[test]
fn default_config() {
let config = HttpConfig {
allow_origin: None,
listen_addr: "127.0.0.1".parse().unwrap(),
listen_port: 5052,
};
let layer = build_cors_layer(
config.allow_origin.as_deref(),
config.listen_addr,
config.listen_port,
);
assert!(layer.is_ok());
}
#[test]
fn wildcard_origin() {
// lighthouse bn --http-allow-origin "*"
let config = HttpConfig {
allow_origin: Some("*".to_string()),
listen_addr: "127.0.0.1".parse().unwrap(),
listen_port: 5052,
};
let layer = build_cors_layer(
config.allow_origin.as_deref(),
config.listen_addr,
config.listen_port,
);
assert!(layer.is_ok());
}
#[test]
fn single_origin() {
// lighthouse bn --http-allow-origin "http://localhost:3000"
let config = HttpConfig {
allow_origin: Some("http://localhost:3000".to_string()),
listen_addr: "127.0.0.1".parse().unwrap(),
listen_port: 5052,
};
let layer = build_cors_layer(
config.allow_origin.as_deref(),
config.listen_addr,
config.listen_port,
);
assert!(layer.is_ok());
}
#[test]
fn multiple_origins() {
// lighthouse bn --http-allow-origin "http://localhost:3000,https://example.com"
let config = HttpConfig {
allow_origin: Some("http://localhost:3000,https://example.com".to_string()),
listen_addr: "127.0.0.1".parse().unwrap(),
listen_port: 5052,
};
let layer = build_cors_layer(
config.allow_origin.as_deref(),
config.listen_addr,
config.listen_port,
);
assert!(layer.is_ok());
}
#[test]
fn ipv6_listen_address() {
let config = HttpConfig {
allow_origin: None,
listen_addr: "::1".parse().unwrap(),
listen_port: 5052,
};
let layer = build_cors_layer(
config.allow_origin.as_deref(),
config.listen_addr,
config.listen_port,
);
assert!(layer.is_ok());
}
#[test]
fn invalid_origin_missing_scheme() {
let config = HttpConfig {
allow_origin: Some("localhost:3000".to_string()),
listen_addr: "127.0.0.1".parse().unwrap(),
listen_port: 5052,
};
let layer = build_cors_layer(
config.allow_origin.as_deref(),
config.listen_addr,
config.listen_port,
);
assert!(layer.is_err());
}
#[test]
fn invalid_origin_in_list() {
let config = HttpConfig {
allow_origin: Some("http://localhost:3000,invalid,https://example.com".to_string()),
listen_addr: "127.0.0.1".parse().unwrap(),
listen_port: 5052,
};
let layer = build_cors_layer(
config.allow_origin.as_deref(),
config.listen_addr,
config.listen_port,
);
assert!(layer.is_err());
}
#[tokio::test]
async fn verify_cors_layer() {
let config = HttpConfig {
allow_origin: Some("http://localhost:3000".to_string()),
listen_addr: "127.0.0.1".parse().unwrap(),
listen_port: 5052,
};
let cors_layer = build_cors_layer(
config.allow_origin.as_deref(),
config.listen_addr,
config.listen_port,
)
.unwrap();
async fn handler() -> &'static str {
"test"
}
let app = Router::new().route("/", get(handler)).layer(cors_layer);
// Preflight request
let request = Request::builder()
.method("OPTIONS")
.uri("/")
.header("Origin", "http://localhost:3000")
.header("Access-Control-Request-Method", "GET")
.body(axum::body::Body::empty())
.unwrap();
let response = app.oneshot(request).await.unwrap();
// Verify CORS header matches origin
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(
response
.headers()
.get("access-control-allow-origin")
.unwrap(),
"http://localhost:3000"
);
}
#[tokio::test]
async fn wildcard_overrides_exact() {
// Mix specific origin with wildcard
let config = HttpConfig {
allow_origin: Some("http://localhost:3000,*".to_string()),
listen_addr: "127.0.0.1".parse().unwrap(),
listen_port: 5052,
};
let cors_layer = build_cors_layer(
config.allow_origin.as_deref(),
config.listen_addr,
config.listen_port,
)
.unwrap();
async fn handler() -> &'static str {
"test"
}
let app = Router::new().route("/", get(handler)).layer(cors_layer);
let request = Request::builder()
.method("OPTIONS")
.uri("/")
.header("Origin", "https://completely-different-origin.com")
.header("Access-Control-Request-Method", "GET")
.body(axum::body::Body::empty())
.unwrap();
let response = app.oneshot(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(
response
.headers()
.get("access-control-allow-origin")
.unwrap(),
"*"
);
}
#[tokio::test]
async fn verify_allowed_methods() {
use axum::{Router, routing::get};
use http::{Request, StatusCode};
use tower::ServiceExt;
let config = HttpConfig {
allow_origin: Some("http://localhost:3000".to_string()),
listen_addr: "127.0.0.1".parse().unwrap(),
listen_port: 5052,
};
let cors_layer = build_cors_layer(
config.allow_origin.as_deref(),
config.listen_addr,
config.listen_port,
)
.unwrap()
.allow_methods([http::Method::GET, http::Method::POST])
.allow_headers([http::header::CONTENT_TYPE]);
async fn handler() -> &'static str {
"test"
}
let app = Router::new().route("/", get(handler)).layer(cors_layer);
let request = Request::builder()
.method("OPTIONS")
.uri("/")
.header("Origin", "http://localhost:3000")
.header("Access-Control-Request-Method", "GET")
.body(axum::body::Body::empty())
.unwrap();
let response = app.clone().oneshot(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let allowed_methods = response
.headers()
.get("access-control-allow-methods")
.unwrap()
.to_str()
.unwrap();
assert!(allowed_methods.contains("GET"));
assert!(allowed_methods.contains("POST"));
}
#[tokio::test]
async fn verify_allowed_headers() {
let config = HttpConfig {
allow_origin: Some("http://localhost:3000".to_string()),
listen_addr: "127.0.0.1".parse().unwrap(),
listen_port: 5052,
};
let cors_layer = build_cors_layer(
config.allow_origin.as_deref(),
config.listen_addr,
config.listen_port,
)
.unwrap()
.allow_methods([http::Method::GET])
.allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION]);
async fn handler() -> &'static str {
"test"
}
let app = Router::new().route("/", get(handler)).layer(cors_layer);
// Preflight request with Content-Type header
let request = Request::builder()
.method("OPTIONS")
.uri("/")
.header("Origin", "http://localhost:3000")
.header("Access-Control-Request-Method", "GET")
.header("Access-Control-Request-Headers", "content-type")
.body(axum::body::Body::empty())
.unwrap();
let response = app.oneshot(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let allowed_headers = response
.headers()
.get("access-control-allow-headers")
.unwrap()
.to_str()
.unwrap()
.to_lowercase();
assert!(allowed_headers.contains("content-type"));
assert!(allowed_headers.contains("authorization"));
}
}

View File

@@ -0,0 +1,6 @@
pub mod cors;
pub mod middleware;
pub mod server;
pub mod tls;
pub use server::Server;

View File

@@ -0,0 +1,9 @@
use axum::http::header;
use tower_http::set_header::SetResponseHeaderLayer;
/// Returns a layer that adds the "Server" header to all responses.
pub fn add_server_header(
value: header::HeaderValue,
) -> SetResponseHeaderLayer<header::HeaderValue> {
SetResponseHeaderLayer::overriding(header::SERVER, value)
}

View File

@@ -0,0 +1,42 @@
use crate::{
server::{Server, error::BuilderError},
tls::TlsConfig,
};
use axum::Router;
use axum_server::tls_rustls::RustlsConfig;
use std::net::SocketAddr;
pub struct ServerBuilder {
router: Router,
address: SocketAddr,
tls_config: Option<TlsConfig>,
}
impl ServerBuilder {
pub fn new(router: Router, address: SocketAddr) -> Self {
Self {
router,
address,
tls_config: None,
}
}
pub fn with_tls(mut self, config: TlsConfig) -> Self {
self.tls_config = Some(config);
self
}
pub async fn build(self) -> Result<Server, BuilderError> {
let rustls_config = if let Some(tls) = self.tls_config {
Some(RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?)
} else {
None
};
Ok(Server {
router: self.router,
address: self.address,
rustls_config,
})
}
}

View File

@@ -0,0 +1,11 @@
#[derive(Debug, thiserror::Error)]
pub enum BuilderError {
#[error("TLS configuration failed: {0}")]
TlsConfigFailed(#[from] std::io::Error),
}
#[derive(Debug, thiserror::Error)]
pub enum ServerError {
#[error("Server failed: {0}")]
ServerFailed(#[from] std::io::Error),
}

View File

@@ -0,0 +1,116 @@
mod builder;
mod error;
pub use builder::ServerBuilder;
pub use error::{BuilderError, ServerError};
use axum::Router;
use axum_server::tls_rustls::RustlsConfig;
use std::net::SocketAddr;
use std::time::Duration;
/// Default timeout for graceful shutdown. After this duration, the server will
/// stop waiting for existing connections and shut down immediately.
const DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
pub struct Server {
pub(crate) router: Router,
pub(crate) address: SocketAddr,
pub(crate) rustls_config: Option<RustlsConfig>,
}
impl Server {
/// Initialize a new server builder.
pub fn builder(router: Router, address: SocketAddr) -> ServerBuilder {
ServerBuilder::new(router, address)
}
/// Get information about the server configuration.
///
/// Note that the address is only the configured address, not the actual address
/// the server is listening on (such as when using port 0).
pub fn info(&self) -> ServerInfo {
ServerInfo {
address: self.address,
protocol: if self.rustls_config.is_some() {
Protocol::Https
} else {
Protocol::Http
},
}
}
/// Serve the application until the shutdown signal is received.
/// Returns the actual address the server is listening on.
pub async fn serve_with_shutdown<F>(
self,
shutdown_signal: F,
) -> Result<(SocketAddr, impl Future<Output = Result<(), ServerError>>), ServerError>
where
F: std::future::Future<Output = ()> + Send + 'static,
{
let tokio_listener = tokio::net::TcpListener::bind(self.address).await?;
let actual_addr = tokio_listener.local_addr()?;
let listener = tokio_listener.into_std()?;
let handle = axum_server::Handle::new();
let server_future = async move {
// Spawn a task that triggers graceful shutdown when the signal fires.
// If the server exits before the signal, this task will linger until the
// signal resolves, which is harmless.
let shutdown_handle = tokio::spawn({
let handle = handle.clone();
async move {
shutdown_signal.await;
handle.graceful_shutdown(Some(DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT));
}
});
let result = match self.rustls_config {
Some(config) => {
axum_server::from_tcp_rustls(listener, config)
.handle(handle)
.serve(self.router.into_make_service())
.await
}
None => {
axum_server::from_tcp(listener)
.handle(handle)
.serve(self.router.into_make_service())
.await
}
};
// Abort the shutdown listener if it's still running (server exited first).
shutdown_handle.abort();
result.map_err(ServerError::ServerFailed)
};
Ok((actual_addr, server_future))
}
}
#[derive(Debug, Clone, Copy)]
pub struct ServerInfo {
pub address: SocketAddr,
pub protocol: Protocol,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Protocol {
Http,
Https,
}
impl std::fmt::Display for Protocol {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Protocol::Http => write!(f, "http"),
Protocol::Https => write!(f, "https"),
}
}
}

View File

@@ -0,0 +1,9 @@
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Configuration used when serving the HTTP server over TLS.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub struct TlsConfig {
pub cert: PathBuf,
pub key: PathBuf,
}

View File

@@ -405,7 +405,7 @@ impl TaskExecutor {
/// Returns a future that completes when `async-channel::Sender` is dropped or () is sent,
/// which translates to the exit signal being triggered.
pub fn exit(&self) -> impl Future<Output = ()> + 'static {
pub fn exit(&self) -> impl Future<Output = ()> + use<> + 'static {
let exit = self.exit.clone();
async move {
let _ = exit.recv().await;

View File

@@ -8,7 +8,7 @@ edition = { workspace = true }
[dependencies]
bytes = { workspace = true }
eth2 = { workspace = true }
headers = "0.3.2"
headers = "0.4"
reqwest = { workspace = true }
safe_arith = { workspace = true }
serde = { workspace = true }

View File

@@ -1,9 +1,10 @@
use reqwest::StatusCode;
use warp::Rejection;
/// Convert from a "new" `http::StatusCode` to a `warp` compatible one.
pub fn convert(code: StatusCode) -> Result<warp::http::StatusCode, Rejection> {
code.as_u16().try_into().map_err(|e| {
crate::reject::custom_server_error(format!("bad status code {code:?} - {e:?}"))
})
/// Convert a `reqwest::StatusCode` to a `warp::http::StatusCode`.
///
/// In warp 0.4, both `reqwest` (0.12) and `warp` use the `http` v1 crate,
/// so `reqwest::StatusCode` and `warp::http::StatusCode` are the same type.
/// This function is retained for API compatibility but is now a no-op.
pub fn convert(code: StatusCode) -> Result<warp::http::StatusCode, warp::Rejection> {
Ok(code)
}