Created
December 19, 2022 19:14
-
-
Save matsadler/e3bfe27e0af90b7a3bd7cdbe2b68fe50 to your computer and use it in GitHub Desktop.
A rate limiter based on the Generic Cell Rate Algorithm.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//! Rate limiter.
//!
//! This module contains a rate limiter based on the [Generic Cell Rate
//! Algorithm][gcra].
//!
//! [gcra]: https://en.wikipedia.org/wiki/Generic_cell_rate_algorithm
use std::{ | |
cmp::max, | |
collections::{HashMap, VecDeque}, | |
fmt::Debug, | |
hash::Hash, | |
sync::{Arc, Mutex}, | |
}; | |
use chrono::{DateTime, Duration, Utc}; | |
use log::{info, trace}; | |
/// Wraps a HashMap that can store only time, with a start and end. | |
struct Window<T> { | |
/// inclusive window start time | |
start: DateTime<Utc>, | |
/// exclusive window end time | |
end: DateTime<Utc>, | |
data: HashMap<T, DateTime<Utc>>, | |
} | |
impl<T: Hash + Eq> Window<T> { | |
/// Create a new Window. | |
fn new(start: DateTime<Utc>, interval: Duration) -> Window<T> { | |
Window { | |
start, | |
end: start + interval, | |
data: HashMap::new(), | |
} | |
} | |
} | |
/// A key-to-timestamp map that expires old values. | |
/// | |
/// Implements the efficient expiry of old values by storing data in a number | |
/// of windows, simply dropping the oldest window when a new one must be added. | |
/// | |
/// It is assumed time values will be in the future, but not greater than | |
/// `interval * num_windows` in the future. | |
/// | |
/// This comes at the cost of possibly having to check multiple windows during | |
/// lookups. | |
struct Timestamps<T> { | |
/// The size of each window. | |
interval: Duration, | |
/// The maximum number of windows. | |
num_windows: usize, | |
/// The windows themselves, in a ring buffer. | |
windows: VecDeque<Window<T>>, | |
} | |
impl<T: Hash + Eq + Debug> Timestamps<T> { | |
/// Create a new Timestamps. | |
/// | |
/// # Examples | |
/// | |
/// ``` | |
/// let timestamps = Timestamps::new(Duration::seconds(1), 5); | |
/// let soon = Utc::now() + Duration::seconds(1); | |
/// timestamps.insert("foo", soon); | |
/// ``` | |
fn new(interval: Duration, num_windows: usize) -> Timestamps<T> { | |
Timestamps { | |
interval, | |
num_windows, | |
windows: VecDeque::new(), | |
} | |
} | |
/// Returns a reference to the time value for the given key. | |
/// | |
/// # Examples | |
/// | |
/// ``` | |
/// let mut timestamps = Timestamps::new(Duration::seconds(1), 5); | |
/// let soon = Utc::now() + Duration::seconds(1); | |
/// timestamps.insert("foo", soon.clone()); | |
/// | |
/// assert_eq!(timestamps.get("foo"), Some(&soon)); | |
/// ``` | |
fn get<Q>(&self, key: &Q) -> Option<&DateTime<Utc>> | |
where | |
T: std::borrow::Borrow<Q>, | |
Q: Hash + Eq + ?Sized, | |
{ | |
self.windows.iter().find_map(|w| w.data.get(key)) | |
} | |
/// Insert a time value for the given key. | |
/// | |
/// The time value may not be inerted if it is in the past. | |
/// | |
/// Inserting a new time value may expire old values. | |
/// | |
/// Inserting a new time value more than `interval * num_windows` in the | |
/// future may expire values early. | |
/// | |
/// # Examples | |
/// | |
/// ``` | |
/// let timestamps = Timestamps::new(Duration::seconds(1), 5); | |
/// let soon = Utc::now() + Duration::seconds(1); | |
/// timestamps.insert("foo", soon); | |
/// ``` | |
fn insert(&mut self, key: T, value: DateTime<Utc>) { | |
let max = self.windows.front().map(|f| f.end).unwrap_or_else(Utc::now); | |
if value < max { | |
let mut window = self | |
.windows | |
.iter_mut() | |
.find(|w| w.start <= value && value < w.end); | |
if let Some(w) = window.as_mut() { | |
trace!( | |
"inserting {:?} {} into window {}..{}", | |
key, | |
value, | |
w.start, | |
w.end | |
); | |
w.data.insert(key, value); | |
} | |
} else { | |
let mut window = Window::new(max, self.interval); | |
trace!( | |
"inserting {:?} {} into new window {}..{}", | |
key, | |
value, | |
window.start, | |
window.end | |
); | |
window.data.insert(key, value); | |
self.windows.push_front(window); | |
if self.windows.len() > self.num_windows { | |
let w = self.windows.pop_back().unwrap(); | |
trace!("expiring window {}..{}", w.start, w.end); | |
} | |
} | |
} | |
} | |
/// A GCRA based rate limiter. | |
/// | |
/// Limiter can be queried for whether or not a request should be allowed. A | |
/// successful query is counted against the limit, an unsuccessful one is not. | |
/// | |
/// Limiter can be cloned to share the same limit between threads. | |
/// | |
/// Internally the Limiter stores only the next allowable request time per key, | |
/// and expires old values as new ones are added, so should be reasonably | |
/// memory efficient. | |
/// | |
/// # Examples | |
/// | |
/// ``` | |
/// let limiter = Limiter::new(Duration::seconds(1), 1); | |
/// | |
/// assert!(!limiter.is_limited("[email protected]")); | |
/// assert!(limiter.is_limited("[email protected]")); | |
/// ``` | |
#[derive(Clone)] | |
pub struct Limiter<T> { | |
interval: Duration, | |
burst: usize, | |
timestamps: Arc<Mutex<Timestamps<T>>>, | |
} | |
impl<T: Hash + Eq + Debug> Limiter<T> { | |
/// Create a new Limiter. | |
/// | |
/// The Limiter will allow one request every `interval`, with a burst | |
/// capacity of `burst`, e.g. an interval of 1 second and a burst of 5 | |
/// requests, then 1 request a second after that. The rate of requests | |
/// would have to drop below 1 per-second to 'refill' the burst capacity. | |
/// Alternately it would allow for 10 requests in 5 seconds, and then 1 | |
/// request a second after that. | |
pub fn new(interval: std::time::Duration, burst: usize) -> Limiter<T> { | |
// panics if interval is greater than ~292 billion years | |
let interval = Duration::from_std(interval).unwrap(); | |
Limiter { | |
interval, | |
burst, | |
timestamps: Arc::new(Mutex::new(Timestamps::new(interval, burst))), | |
} | |
} | |
/// Returns `true` if the limit has been reached and the request should be | |
/// denied. Denied requests are not counted against the limit. Returns | |
/// `false` when a request should be allowed and has been counted. | |
pub fn is_limited(&self, key: T) -> bool { | |
let now = Utc::now(); | |
let mut timestamps = self.timestamps.lock().unwrap(); | |
let theoretical_arrival = timestamps.get(&key).map(|ts| max(&now, ts)).unwrap_or(&now); | |
let next_arrival_time = *theoretical_arrival + self.interval; | |
if now < next_arrival_time - self.interval * self.burst as i32 { | |
info!("{:?} ratelimited", key); | |
true | |
} else { | |
timestamps.insert(key, next_arrival_time); | |
false | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment