Skip to content

Instantly share code, notes, and snippets.

@lmammino
Created May 1, 2024 17:51
Show Gist options
  • Save lmammino/5ff73a8f36deadda7287a2e0ff6b287c to your computer and use it in GitHub Desktop.
Save lmammino/5ff73a8f36deadda7287a2e0ff6b287c to your computer and use it in GitHub Desktop.
Poor man's async Rust HTTP load-testing tool
[package]
name = "send-requests"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
chrono = "0.4.38"
clap = { version = "4.5.4", features = ["derive"] }
jsonwebtoken = "9.3.0"
rand = "0.8.5"
reqwest = { version = "0.12.4", default-features = false, features = [
"rustls-tls",
"http2",
] }
serde = "1.0.199"
serde_json = "1.0.116"
tokio = { version = "1.37.0", features = ["full"] }
use clap::Parser;
use reqwest::Url;
use stats::Stats;
use std::sync::Arc;
use tokengen::TokenGen;
use tokio::sync::RwLock;
mod stats;
mod tokengen;
/// Command-line options for the load test.
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
    /// Total number of requests to send before stopping.
    #[arg(short, long, default_value = "10000")]
    total_requests: usize,
    /// Maximum number of requests started in each one-second window.
    #[arg(short, long, default_value = "100")]
    requests_per_sec: usize,
    /// URL that every GET request is sent to.
    #[arg(short = 'u', long)]
    target_url: String,
}
#[tokio::main]
async fn main() {
let args = Args::parse();
let target_url = Arc::new(Url::parse(&args.target_url).expect("Invalid target URL"));
let client = Arc::new(reqwest::Client::new());
let token_gen = TokenGen::default();
let stats = Arc::new(RwLock::new(Stats::default()));
while stats.read().await.sent_requests() < args.total_requests {
let loop_start = tokio::time::Instant::now();
let curr_requests_sent = stats.read().await.sent_requests();
let num_requests_in_batch =
(args.total_requests - curr_requests_sent).min(args.requests_per_sec);
for _ in 0..num_requests_in_batch {
tokio::spawn({
let token = token_gen.next_token();
let client = client.clone();
let target_url = target_url.clone();
let stats = stats.clone();
async move {
let request = client
.get(target_url.to_string())
.header("Authorization", format!("Bearer {}", token))
.send();
stats.write().await.inc_requests_sent();
let response = request.await.unwrap();
let status_code = response.status().as_u16();
response.bytes().await.unwrap(); // consumes the response body
stats.write().await.inc_status_code(status_code);
stats.write().await.inc_completed_requests();
}
});
}
println!("{}", stats.read().await);
// waits the remainder of 1second since the last loop iteration before starting a new one
let loop_end = tokio::time::Instant::now();
let loop_duration = loop_end - loop_start;
let sleep_duration = (1000 - loop_duration.as_millis()).max(0) as u64;
tokio::time::sleep(tokio::time::Duration::from_millis(sleep_duration)).await;
}
println!("{}", stats.read().await);
}
use std::{collections::HashMap, fmt::Display};
/// Running counters for a load test.
///
/// Tracks requests sent, requests completed, and completed responses grouped
/// by status-code class.
#[derive(Debug)]
pub struct Stats {
    requests_sent: usize,
    completed_requests: usize,
    /// Responses per class: "1xx" through "5xx" are always present.
    /// "other" is added the first time a code outside 100..=599 is recorded.
    status_codes_count: HashMap<&'static str, usize>,
}

impl Default for Stats {
    fn default() -> Self {
        let status_codes_count = ["1xx", "2xx", "3xx", "4xx", "5xx"]
            .into_iter()
            .map(|class| (class, 0))
            .collect();
        Stats {
            requests_sent: 0,
            completed_requests: 0,
            status_codes_count,
        }
    }
}

impl Display for Stats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // HashMap iteration order is unspecified. Sort the classes so that
        // successive progress reports list them in the same order.
        let sorted_counts: std::collections::BTreeMap<_, _> =
            self.status_codes_count.iter().collect();
        write!(
            f,
            "Requests sent: {}\nCompleted requests: {}\nStatus codes count: {:?}",
            self.requests_sent, self.completed_requests, sorted_counts
        )
    }
}

impl Stats {
    /// Number of requests recorded as sent so far.
    pub fn sent_requests(&self) -> usize {
        self.requests_sent
    }

    /// Records that one more request has been sent.
    pub fn inc_requests_sent(&mut self) {
        self.requests_sent += 1;
    }

    /// Records that one more request has fully completed.
    pub fn inc_completed_requests(&mut self) {
        self.completed_requests += 1;
    }

    /// Increments the counter for the class of `status_code` ("1xx".."5xx").
    ///
    /// Codes outside 100..=599 are counted under "other" rather than panicking.
    /// They are legal on the wire, and such a response should not abort the request task.
    pub fn inc_status_code(&mut self, status_code: u16) {
        let class = match status_code {
            100..=199 => "1xx",
            200..=299 => "2xx",
            300..=399 => "3xx",
            400..=499 => "4xx",
            500..=599 => "5xx",
            _ => "other",
        };
        *self.status_codes_count.entry(class).or_insert(0) += 1;
    }
}
use chrono::{Duration, Utc};
use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
use rand::prelude::*;
use serde_json::json;
use std::sync::atomic;
pub struct TokenGen {
keys: Vec<EncodingKey>,
generated_keys: atomic::AtomicUsize,
}
impl Default for TokenGen {
fn default() -> Self {
let keys = [
EncodingKey::from_rsa_pem(include_bytes!("../../keys/key0/private.pem")).unwrap(),
EncodingKey::from_rsa_pem(include_bytes!("../../keys/key1/private.pem")).unwrap(),
EncodingKey::from_rsa_pem(include_bytes!("../../keys/key2/private.pem")).unwrap(),
EncodingKey::from_rsa_pem(include_bytes!("../../keys/key3/private.pem")).unwrap(),
];
TokenGen {
keys: keys.to_vec(),
generated_keys: atomic::AtomicUsize::new(0),
}
}
}
impl TokenGen {
pub fn next_token(&self) -> String {
let i: usize = self.generated_keys.fetch_add(1, atomic::Ordering::SeqCst);
let kid = format!("key{}", i % 4);
let key = &self.keys[i % self.keys.len()];
let token_header: Header =
serde_json::from_value(json!({"alg": Algorithm::RS512, "kid": kid})).unwrap();
let exp = (Utc::now() + Duration::try_hours(1).unwrap()).timestamp();
let jti: u64 = random();
let token_claims = json!({
"jti": jti, // adds a random token id to make it unlikely to generate the same token and avoid caching
"aud": "oidc-authorizer-benchmark",
"iss": "oidc-authorizer-benchmark",
"sub": "oidc-benchmark-test-user",
"exp": exp
});
encode(&token_header, &token_claims, key).unwrap()
}
}
impl Iterator for TokenGen {
type Item = String;
fn next(&mut self) -> Option<Self::Item> {
Some(self.next_token())
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment