Writing a basic rate limiter in Rust
2022-09-02I am looking for a job in Rust. Please (contact)[mailto://mail.sandeshbhusal@gmail.com]
Introduction:
Rate limiting algorithms are quintessential in limiting bad actors from DOSing your resources. In this post, I am implementing a very basic rate limiter based on Token bucket algorithm.
The algorithm:
Token bucket works in the following way:
- Start the rate limiter
- For every client request, check if the client has enough "tokens"
- For every request sent from the client, decrement the client's token count.
- Increment the token content for the client in every 1/R seconds where R = request rate = number of request / window interval.
For #4, an example can be 2/5 where 2 requests are allowed every 5 seconds for a particular client.
Implementing the algorithm:
- Let's begin by creating a project.
cargo new --lib ratelimiter
Inside src/lib.rs
we will import everything we need for now.
use std::{collections::HashMap, hash::Hash, time::Instant};
- Then we create a scaffold for the client data. For every client, we need to record two things:
- How many tokens the client has left
- When was the last time the client made a successful request.
For #2.2, we can use Rust's std::time::Instant
since it will let us compute the elapsed time in different granularity as well.
Let's create the client data struct:
struct client_info {
last_sent: Instant,
available_tokens: f64
}
The available_tokens
field is floating point since we will need to do floating-point calculations to increment the tokens.
Let's create a RateLimiter struct that will be exported by this library:
pub struct RateLimiter<T>
where
T: Eq + Hash
{
client_data: HashMap<T, client_info>,
max_allowed: f64,
window_duration: f64,
runnable: fn()
}
To describe:
client_data
is a hashmap that holds all clients' information for client_info. Each client is identified by a unique identifier, which is of the typeT
and needs to implementEq
andHash
traits.max_allowed
is the number of maximum requests allowed in thewindow_duration
.runnable
is a runnable function which is of a very simple signaturefn()
for now.
Let's create a runnable function so that we can test our rate limiter later.
fn runnable() {
println!("I ran!");
}
Now let's go about implementing our client_info
. Since Instant
does not implement Default
we will need to create a trivial new
function for this struct.
impl client_info {
fn new() -> Self {
return Self {
last_request: Instant::now(),
available_tokens: 0.0
}
}
}
Next is the implementation of our RateLimiter
struct.
impl<T> RateLimiter <T>
where
T: Eq + Hash
{
fn new(allowed: f64, time_interval: f64, function_to_run: fn()) -> Self {
Self {
client_data: HashMap::new(),
max_allowd: allowed,
window_duration: time_interval,
runnable: function_to_run,
}
}
}
The new
function returns the RateLimiter
object defaults. The meat of the RateLimiter is in its run
function which we will implement now. The run
function takes in a client ID and runs the runnable
function if the rate limit requirements are passed.
fn run(&mut self, client_id: T) -> Result<(), String> {
// default client entry if the client has not made requests previously.
let mut default_client_entry = client_info::new();
default_client_entry.available_tokens = self.max_allowed;
// extract the entry from client_data hashmap
let entry = self
.client_data
.entry(client_id)
.or_insert(default_client_entry);
// Calculate the tokens available to this client.
entry.available_tokens = entry.available_tokens
+ (entry.last_request.elapsed().as_secs()) as f64
* (self.max_allowed / self.window_duration);
// Cap the token. If client has gathered 10 tokens, e.g and the max allowed requests
// in 10 seconds is just 3 then we cap the available tokens to 3 as well.
if entry.available_tokens >= self.max_allowed {
entry.available_tokens = self.max_allowed;
}
// The client has tokens available. The request can be passed.
if entry.available_tokens >= 1.0 {
// can make a request.
(self.runnable)();
entry.available_tokens -= 1.0;
entry.last_request = Instant::now();
} else {
// The request could not be completed.
Err(String::from("Max rate limit reached!"));
}
Ok(())
}
In order to prevent us from having to change the hashmap every second, we will just calculate the number of tokens that could've been gathered since the last time the function was run successfully. If the token count exceeds the max available request count, then the tokens are capped. Finally, if the available tokens exceed 1 then the request is allowed to pass through. Otherwise we return a string error.
Testing:
Let's write a very simple test for this code. Each client is represented by a client_id which is a usize
.
#[cfg(test)]
mod test_ratelimiter {
use crate::{function_to_execute, RateLimiter, RateLimiterResult};
#[test]
fn test_limits() {
let mut ratelimiter = RateLimiter::<usize>::new(3.0, 10.0, runnable);
for i in 0..10usize {
match ratelimiter.run(1) {
RateLimiterResult::Ok => {
println!("Request {} passed", i);
}
RateLimiterResult::Err(_) => {
println!("Request {} fell through", i);
}
};
// Let us assume that every 1 second, this function runs and makes a request.
std::thread::sleep(std::time::Duration::from_secs(1));
}
}
}
Conclusion:
This is a very basic example of a rate limiter. I had a lot of fun writing this tutorial, especially, setting and executing a function as a structure member. The granularity of this rate limiter can be controlled using as_millis()
instead of as_secs()
in the code to replace seconds with milliseconds.