Last active
December 17, 2023 02:11
-
-
Save GlenDC/7d032afe41ffe03946ca55c6ba68be30 to your computer and use it in GitHub Desktop.
Simple tower-based HTTP router for hyper 0.14
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[package] | |
name = "router" | |
version = "0.1.0" | |
edition = "2021" | |
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | |
[dependencies] | |
hyper = { version = "0.14.27", features = ["full"] } | |
tower = { version = "0.4.12", features = ["full"] } | |
tower-http = { version = "0.4.4", features = ["full"] } | |
tokio = { version = "1", features = ["full"] } | |
http-body = "0.4.6" | |
serde_json = "1.0" | |
serde = { version = "1.0", features = ["derive"] } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use hyper::{ | |
body::Body, | |
header::{HeaderValue, CONTENT_TYPE}, | |
Method, Request, Response, Server, StatusCode, | |
}; | |
use std::{collections::HashMap, net::SocketAddr}; | |
use std::{ | |
convert::Infallible, | |
future::Future, | |
pin::Pin, | |
task::{Context, Poll}, | |
}; | |
use tower::{make::Shared, service_fn, util::BoxCloneService, Service, ServiceBuilder, ServiceExt}; | |
use tower_http::{ | |
normalize_path::NormalizePathLayer, validate_request::ValidateRequestHeaderLayer, | |
}; | |
/**************************** | |
* Type Aliases | |
***************************/ | |
pub type WebRequest = Request<Body>; | |
pub type WebResponse = Response<Body>; | |
/**************************** | |
* IntoWebResponse | |
***************************/ | |
pub trait IntoWebResponse { | |
fn into_web_response(self) -> WebResponse; | |
} | |
impl IntoWebResponse for WebResponse { | |
fn into_web_response(self) -> WebResponse { | |
self | |
} | |
} | |
impl IntoWebResponse for Infallible { | |
fn into_web_response(self) -> WebResponse { | |
panic!("BUG"); | |
} | |
} | |
impl IntoWebResponse for StatusCode { | |
fn into_web_response(self) -> WebResponse { | |
Response::builder() | |
.status(self) | |
.body(Body::empty()) | |
.expect("the StatusCode web response to be build") | |
} | |
} | |
impl IntoWebResponse for &'static str { | |
fn into_web_response(self) -> WebResponse { | |
Response::builder() | |
.status(StatusCode::OK) | |
.header(CONTENT_TYPE, HeaderValue::from_static("text/plain")) | |
.body(Body::from(self)) | |
.expect("the &'static str web response to be build") | |
} | |
} | |
impl IntoWebResponse for String { | |
fn into_web_response(self) -> WebResponse { | |
Response::builder() | |
.status(StatusCode::OK) | |
.header(CONTENT_TYPE, HeaderValue::from_static("text/plain")) | |
.body(Body::from(self)) | |
.expect("the &'static str web response to be build") | |
} | |
} | |
impl IntoWebResponse for Box<dyn std::error::Error> { | |
fn into_web_response(self) -> WebResponse { | |
Response::builder() | |
.status(StatusCode::INTERNAL_SERVER_ERROR) | |
.header(CONTENT_TYPE, HeaderValue::from_static("text/plain")) | |
.body(Body::from(self.to_string())) | |
.expect("the Boxed error web response to be build") | |
} | |
} | |
/**************************** | |
* Router | |
***************************/ | |
type RouterKey = (Method, &'static str); | |
type RouterService = BoxCloneService<WebRequest, WebResponse, WebResponse>; | |
#[derive(Debug, Default, Clone)] | |
pub struct Router { | |
endpoints: HashMap<RouterKey, RouterService>, | |
} | |
impl Router { | |
pub fn on<R, E>( | |
&mut self, | |
method: Method, | |
endpoint: &'static str, | |
svc: BoxCloneService<WebRequest, R, E>, | |
) where | |
R: IntoWebResponse + 'static, | |
E: IntoWebResponse + 'static, | |
{ | |
let svc = BoxCloneService::new( | |
svc.map_response(IntoWebResponse::into_web_response) | |
.map_err(IntoWebResponse::into_web_response), | |
); | |
self.endpoints.insert((method, endpoint), svc); | |
} | |
} | |
impl Service<WebRequest> for Router { | |
type Response = WebResponse; | |
type Error = Infallible; | |
type Future = | |
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; | |
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { | |
Poll::Ready(Ok(())) | |
} | |
fn call(&mut self, req: WebRequest) -> Self::Future { | |
match self | |
.endpoints | |
.get(&(req.method().clone(), req.uri().path())) | |
{ | |
Some(svc) => { | |
let mut svc = svc.clone(); | |
let fut = async move { | |
let ready_svc = match svc.ready().await { | |
Ok(svc) => svc, | |
Err(_) => return Ok(StatusCode::TOO_MANY_REQUESTS.into_web_response()), | |
}; | |
match ready_svc.call(req).await { | |
Ok(res) => Ok(res), | |
Err(e) => Ok(e), | |
} | |
}; | |
Box::pin(fut) | |
} | |
None => Box::pin(async { Ok(StatusCode::NOT_FOUND.into_web_response()) }), | |
} | |
} | |
} | |
/**************************** | |
* Endpoints | |
***************************/ | |
async fn svc_hello(_req: WebRequest) -> Result<&'static str, Infallible> { | |
Ok("Hello, World!") | |
} | |
#[derive(Debug, serde::Deserialize)] | |
struct Person { | |
name: String, | |
} | |
async fn svc_json(req: WebRequest) -> Result<String, StatusCode> { | |
let full_body = hyper::body::to_bytes(req.into_body()) | |
.await | |
.map_err(|_| StatusCode::BAD_REQUEST)?; | |
serde_json::from_slice::<Person>(&full_body) | |
.map(|person| format!("Hello, {}!", person.name)) | |
.map_err(|_| StatusCode::BAD_REQUEST) | |
} | |
/**************************** | |
* App | |
***************************/ | |
#[tokio::main] | |
async fn main() -> Result<(), Box<dyn std::error::Error>> { | |
let mut router = Router::default(); | |
router.on(Method::GET, "/hello", service_fn(svc_hello).boxed_clone()); | |
router.on( | |
Method::POST, | |
"/foo/bar/json", | |
ServiceBuilder::new() | |
.boxed_clone() | |
.layer(ValidateRequestHeaderLayer::bearer("passwordlol")) | |
.map_response(IntoWebResponse::into_web_response) | |
.service_fn(svc_json), | |
); | |
let hyper_service = ServiceBuilder::new() | |
.layer(NormalizePathLayer::trim_trailing_slash()) | |
.service(router); | |
// Construct our SocketAddr to listen on... | |
let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); | |
// Shared is a MakeService that produces services by cloning an inner service... | |
let make_service = Shared::new(hyper_service); | |
// Then bind and serve... | |
let server = Server::bind(&addr).serve(make_service); | |
// And run forever... | |
if let Err(e) = server.await { | |
eprintln!("server error: {}", e); | |
} | |
Ok(()) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment