Merge pull request #352 from Kuadrant/limit_storage
First step: Limit storage
alexsnaps authored Jun 10, 2024
2 parents e24873f + 378b95f commit 72c2b12
Showing 12 changed files with 181 additions and 163 deletions.
75 changes: 17 additions & 58 deletions limitador-server/src/envoy_rls/server.rs
@@ -3,13 +3,12 @@ use opentelemetry::propagation::Extractor;
use std::collections::HashMap;
use std::sync::Arc;

use limitador::CheckResult;
use tonic::codegen::http::HeaderMap;
use tonic::{transport, transport::Server, Request, Response, Status};
use tracing::Span;
use tracing_opentelemetry::OpenTelemetrySpanExt;

use limitador::counter::Counter;

use crate::envoy_rls::server::envoy::config::core::v3::HeaderValue;
use crate::envoy_rls::server::envoy::service::ratelimit::v3::rate_limit_response::Code;
use crate::envoy_rls::server::envoy::service::ratelimit::v3::rate_limit_service_server::{
@@ -29,6 +28,21 @@ pub enum RateLimitHeaders {
DraftVersion03,
}

impl RateLimitHeaders {
pub fn headers(&self, response: &mut CheckResult) -> Vec<HeaderValue> {
let mut headers = match self {
RateLimitHeaders::None => Vec::default(),
RateLimitHeaders::DraftVersion03 => response
.response_header()
.into_iter()
.map(|(key, value)| HeaderValue { key, value })
.collect(),
};
headers.sort_by(|a, b| a.key.cmp(&b.key));
headers
}
}

pub struct MyRateLimiter {
limiter: Arc<Limiter>,
rate_limit_headers: RateLimitHeaders,
@@ -142,10 +156,7 @@ impl RateLimitService for MyRateLimiter {
overall_code: resp_code.into(),
statuses: vec![],
request_headers_to_add: vec![],
response_headers_to_add: to_response_header(
&self.rate_limit_headers,
&mut rate_limited_resp.counters,
),
response_headers_to_add: self.rate_limit_headers.headers(&mut rate_limited_resp),
raw_body: vec![],
dynamic_metadata: None,
quota: None,
@@ -155,58 +166,6 @@ impl RateLimitService for MyRateLimiter {
}
}

pub fn to_response_header(
rate_limit_headers: &RateLimitHeaders,
counters: &mut [Counter],
) -> Vec<HeaderValue> {
let mut headers = Vec::new();
match rate_limit_headers {
RateLimitHeaders::None => {}

// creates response headers per https://datatracker.ietf.org/doc/id/draft-polli-ratelimit-headers-03.html
RateLimitHeaders::DraftVersion03 => {
// sort by the limit remaining..
counters.sort_by(|a, b| {
let a_remaining = a.remaining().unwrap_or(a.max_value());
let b_remaining = b.remaining().unwrap_or(b.max_value());
a_remaining.cmp(&b_remaining)
});

let mut all_limits_text = String::with_capacity(20 * counters.len());
counters.iter_mut().for_each(|counter| {
all_limits_text.push_str(
format!(", {};w={}", counter.max_value(), counter.window().as_secs()).as_str(),
);
if let Some(name) = counter.limit().name() {
all_limits_text
.push_str(format!(";name=\"{}\"", name.replace('"', "'")).as_str());
}
});

if let Some(counter) = counters.first() {
headers.push(HeaderValue {
key: "X-RateLimit-Limit".to_string(),
value: format!("{}{}", counter.max_value(), all_limits_text),
});

let remaining = counter.remaining().unwrap_or(counter.max_value());
headers.push(HeaderValue {
key: "X-RateLimit-Remaining".to_string(),
value: format!("{}", remaining),
});

if let Some(duration) = counter.expires_in() {
headers.push(HeaderValue {
key: "X-RateLimit-Reset".to_string(),
value: format!("{}", duration.as_secs()),
});
}
}
}
};
headers
}

struct RateLimitRequestHeaders {
inner: HeaderMap,
}
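
A sketch, not part of this diff, of how the gRPC layer now derives Envoy headers with the new RateLimitHeaders::headers method; it uses the module's own RateLimitHeaders, HeaderValue and CheckResult types, and the helper name is illustrative:

// Sketch (illustrative helper, not in the diff): turning a CheckResult into Envoy headers.
fn envoy_rate_limit_headers(
    mode: &RateLimitHeaders,
    check: &mut CheckResult,
) -> Vec<HeaderValue> {
    // RateLimitHeaders::None yields an empty Vec; DraftVersion03 delegates the
    // draft-polli-ratelimit-headers-03 formatting to CheckResult::response_header().
    let headers = mode.headers(check);
    // The method hands the headers back already sorted by key.
    debug_assert!(headers.windows(2).all(|w| w[0].key <= w[1].key));
    headers
}
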
54 changes: 14 additions & 40 deletions limitador-server/src/http_api/server.rs
@@ -3,6 +3,7 @@ use crate::prometheus_metrics::PrometheusMetrics;
use crate::Limiter;
use actix_web::{http::StatusCode, HttpResponse, HttpResponseBuilder, ResponseError};
use actix_web::{App, HttpServer};
use limitador::CheckResult;
use paperclip::actix::{
api_v2_errors,
api_v2_operation,
@@ -209,7 +210,7 @@ async fn check_and_report(
add_response_header(
&mut resp,
response_headers.as_str(),
&mut is_rate_limited.counters,
&mut is_rate_limited,
);
resp.json(())
}
@@ -224,7 +225,7 @@ async fn check_and_report(
add_response_header(
&mut resp,
response_headers.as_str(),
&mut is_rate_limited.counters,
&mut is_rate_limited,
);
resp.json(())
}
@@ -238,48 +239,21 @@ pub fn add_response_header(
pub fn add_response_header(
resp: &mut HttpResponseBuilder,
rate_limit_headers: &str,
counters: &mut [limitador::counter::Counter],
result: &mut CheckResult,
) {
match rate_limit_headers {
if rate_limit_headers == "DraftVersion03" {
// creates response headers per https://datatracker.ietf.org/doc/id/draft-polli-ratelimit-headers-03.html
"DraftVersion03" => {
// sort by the limit remaining..
counters.sort_by(|a, b| {
let a_remaining = a.remaining().unwrap_or(a.max_value());
let b_remaining = b.remaining().unwrap_or(b.max_value());
a_remaining.cmp(&b_remaining)
});

let mut all_limits_text = String::with_capacity(20 * counters.len());
counters.iter_mut().for_each(|counter| {
all_limits_text.push_str(
format!(", {};w={}", counter.max_value(), counter.window().as_secs()).as_str(),
);
if let Some(name) = counter.limit().name() {
all_limits_text
.push_str(format!(";name=\"{}\"", name.replace('"', "'")).as_str());
}
});

if let Some(counter) = counters.first() {
resp.insert_header((
"X-RateLimit-Limit",
format!("{}{}", counter.max_value(), all_limits_text),
));

let remaining = counter.remaining().unwrap_or(counter.max_value());
resp.insert_header((
"X-RateLimit-Remaining".to_string(),
format!("{}", remaining),
));

if let Some(duration) = counter.expires_in() {
resp.insert_header(("X-RateLimit-Reset", format!("{}", duration.as_secs())));
}
let headers = result.response_header();
if let Some(limit) = headers.get("X-RateLimit-Limit") {
resp.insert_header(("X-RateLimit-Limit", limit.clone()));
}
if let Some(remaining) = headers.get("X-RateLimit-Remaining") {
resp.insert_header(("X-RateLimit-Remaining".to_string(), remaining.clone()));
if let Some(duration) = headers.get("X-RateLimit-Reset") {
resp.insert_header(("X-RateLimit-Reset", duration.clone()));
}
}
_default => {}
};
}
}

pub async fn run_http_server(
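
A sketch, not part of this diff, of how a handler in this file is expected to call the reworked add_response_header; the handler name and status code are illustrative. Note that X-RateLimit-Reset is only copied when X-RateLimit-Remaining is present, mirroring the nesting of the old per-counter code.

// Sketch (illustrative handler, not in the diff): using the reworked add_response_header.
fn too_many_requests(mut result: CheckResult) -> HttpResponse {
    let mut resp = HttpResponseBuilder::new(StatusCode::TOO_MANY_REQUESTS);
    // All draft-03 formatting now happens in CheckResult::response_header();
    // this call only copies the three values onto the actix response.
    add_response_header(&mut resp, "DraftVersion03", &mut result);
    resp.finish()
}
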
17 changes: 8 additions & 9 deletions limitador/src/counter.rs
@@ -2,11 +2,12 @@ use crate::limit::{Limit, Namespace};
use serde::{Deserialize, Serialize, Serializer};
use std::collections::{BTreeMap, HashMap};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::time::Duration;

#[derive(Eq, Clone, Debug, Serialize, Deserialize)]
pub struct Counter {
limit: Limit,
limit: Arc<Limit>,

// Need to sort to generate the same object when using the JSON as a key or
// value in Redis.
@@ -26,9 +27,10 @@ where
}

impl Counter {
pub fn new(limit: Limit, set_variables: HashMap<String, String>) -> Self {
pub fn new<L: Into<Arc<Limit>>>(limit: L, set_variables: HashMap<String, String>) -> Self {
// TODO: check that all the variables defined in the limit are set.

let limit = limit.into();
let mut vars = set_variables;
vars.retain(|var, _| limit.has_variable(var));

@@ -43,7 +45,7 @@ impl Counter {
#[cfg(any(feature = "redis_storage", feature = "disk_storage"))]
pub(crate) fn key(&self) -> Self {
Self {
limit: self.limit.clone(),
limit: Arc::clone(&self.limit),
set_variables: self.set_variables.clone(),
remaining: None,
expires_in: None,
@@ -58,12 +60,9 @@ impl Counter {
self.limit.max_value()
}

pub fn update_to_limit(&mut self, limit: &Limit) -> bool {
if limit == &self.limit {
self.limit.set_max_value(limit.max_value());
if let Some(name) = limit.name() {
self.limit.set_name(name.to_string());
}
pub fn update_to_limit(&mut self, limit: Arc<Limit>) -> bool {
if limit == self.limit {
self.limit = limit;
return true;
}
false
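
A sketch, not part of this diff, showing that the new `L: Into<Arc<Limit>>` bound accepts both an owned Limit and an existing Arc<Limit>; the Limit::new arguments follow the form used in the crate README and are assumptions here, not taken from this commit:

use std::collections::HashMap;
use std::sync::Arc;

use limitador::counter::Counter;
use limitador::limit::Limit;

fn main() {
    // Assumed constructor shape: namespace, max value, window seconds, conditions, variables.
    let limit = Limit::new(
        "my_namespace",
        10,
        60,
        vec!["req.method == 'GET'"],
        vec!["user_id"],
    );
    let shared = Arc::new(limit.clone());
    let vars: HashMap<String, String> =
        HashMap::from([("user_id".to_string(), "alice".to_string())]);

    let _from_owned = Counter::new(limit, vars.clone()); // Limit -> Arc<Limit> via the blanket From impl
    let _from_shared = Counter::new(Arc::clone(&shared), vars); // an existing Arc passes straight through
}
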
64 changes: 58 additions & 6 deletions limitador/src/lib.rs
@@ -193,6 +193,7 @@
#![allow(clippy::multiple_crate_versions)]

use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use crate::counter::Counter;
use crate::errors::LimitadorError;
@@ -226,6 +227,49 @@ pub struct CheckResult {
pub limit_name: Option<String>,
}

impl CheckResult {
pub fn response_header(&mut self) -> HashMap<String, String> {
let mut headers = HashMap::new();
// sort by the limit remaining..
self.counters.sort_by(|a, b| {
let a_remaining = a.remaining().unwrap_or(a.max_value());
let b_remaining = b.remaining().unwrap_or(b.max_value());
a_remaining.cmp(&b_remaining)
});

let mut all_limits_text = String::with_capacity(20 * self.counters.len());
self.counters.iter_mut().for_each(|counter| {
all_limits_text.push_str(
format!(", {};w={}", counter.max_value(), counter.window().as_secs()).as_str(),
);
if let Some(name) = counter.limit().name() {
all_limits_text.push_str(format!(";name=\"{}\"", name.replace('"', "'")).as_str());
}
});

if let Some(counter) = self.counters.first() {
headers.insert(
"X-RateLimit-Limit".to_string(),
format!("{}{}", counter.max_value(), all_limits_text),
);

let remaining = counter.remaining().unwrap_or(counter.max_value());
headers.insert(
"X-RateLimit-Remaining".to_string(),
format!("{}", remaining),
);

if let Some(duration) = counter.expires_in() {
headers.insert(
"X-RateLimit-Reset".to_string(),
format!("{}", duration.as_secs()),
);
}
}
headers
}
}
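
A worked example of the map this method builds (counter values are invented; the formatting follows the code above):

// Two counters, sorted by remaining: (max=10, w=60s, remaining=6, reset in 30s)
// and (max=100, w=3600s, remaining=80). response_header() then yields:
//   "X-RateLimit-Limit"     -> "10, 10;w=60, 100;w=3600"
//   "X-RateLimit-Remaining" -> "6"
//   "X-RateLimit-Reset"     -> "30"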

impl From<CheckResult> for bool {
fn from(value: CheckResult) -> Self {
value.limited
@@ -298,7 +342,11 @@ impl RateLimiter {
}

pub fn get_limits(&self, namespace: &Namespace) -> HashSet<Limit> {
self.storage.get_limits(namespace)
self.storage
.get_limits(namespace)
.iter()
.map(|l| (**l).clone())
.collect()
}

pub fn delete_limits(&self, namespace: &Namespace) -> Result<(), LimitadorError> {
@@ -432,12 +480,12 @@ impl RateLimiter {
namespace: &Namespace,
values: &HashMap<String, String>,
) -> Result<Vec<Counter>, LimitadorError> {
let limits = self.get_limits(namespace);
let limits = self.storage.get_limits(namespace);

let counters = limits
.iter()
.filter(|lim| lim.applies(values))
.map(|lim| Counter::new(lim.clone(), values.clone()))
.map(|lim| Counter::new(Arc::clone(lim), values.clone()))
.collect();

Ok(counters)
@@ -470,7 +518,11 @@ impl AsyncRateLimiter {
}

pub fn get_limits(&self, namespace: &Namespace) -> HashSet<Limit> {
self.storage.get_limits(namespace)
self.storage
.get_limits(namespace)
.iter()
.map(|l| (**l).clone())
.collect()
}

pub async fn delete_limits(&self, namespace: &Namespace) -> Result<(), LimitadorError> {
@@ -610,12 +662,12 @@ impl AsyncRateLimiter {
namespace: &Namespace,
values: &HashMap<String, String>,
) -> Result<Vec<Counter>, LimitadorError> {
let limits = self.get_limits(namespace);
let limits = self.storage.get_limits(namespace);

let counters = limits
.iter()
.filter(|lim| lim.applies(values))
.map(|lim| Counter::new(lim.clone(), values.clone()))
.map(|lim| Counter::new(Arc::clone(lim), values.clone()))
.collect();

Ok(counters)
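
A sketch, not part of this diff, of the two access patterns this file now distinguishes; String stands in for Limit so the snippet is self-contained, and nothing here is the actual storage API. Cloning a Limit copies its conditions and variables, so keeping the Arc on the per-request path avoids that cost.

use std::collections::HashSet;
use std::sync::Arc;

// Public callers still get owned values: one deep clone per entry.
fn public_view(stored: &HashSet<Arc<String>>) -> HashSet<String> {
    stored.iter().map(|l| (**l).clone()).collect()
}

// Internal paths (building counters per request) only bump the refcount.
fn internal_view(stored: &HashSet<Arc<String>>) -> Vec<Arc<String>> {
    stored.iter().map(Arc::clone).collect()
}

fn main() {
    let stored: HashSet<Arc<String>> = [Arc::new("limit-a".to_string())].into_iter().collect();
    assert_eq!(public_view(&stored).len(), 1);
    assert_eq!(internal_view(&stored).len(), 1);
}
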
12 changes: 7 additions & 5 deletions limitador/src/storage/disk/rocksdb_storage.rs
@@ -11,6 +11,8 @@ use rocksdb::{
DB,
};
use std::collections::{BTreeSet, HashSet};
use std::ops::Deref;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tracing::trace_span;

@@ -91,7 +93,7 @@ impl CounterStorage for RocksDbStorage {
}

#[tracing::instrument(skip_all)]
fn get_counters(&self, limits: &HashSet<Limit>) -> Result<HashSet<Counter>, StorageErr> {
fn get_counters(&self, limits: &HashSet<Arc<Limit>>) -> Result<HashSet<Counter>, StorageErr> {
let mut counters = HashSet::default();
let namepaces: BTreeSet<&str> = limits.iter().map(|l| l.namespace().as_ref()).collect();
for ns in namepaces {
@@ -113,8 +115,8 @@ impl CounterStorage for RocksDbStorage {
}
let value: ExpiringValue = value.as_ref().try_into()?;
for limit in limits {
if limit == counter.limit() {
counter.update_to_limit(limit);
if limit.deref() == counter.limit() {
counter.update_to_limit(Arc::clone(limit));
let ttl = value.ttl();
counter.set_expires_in(ttl);
counter.set_remaining(limit.max_value() - value.value());
@@ -133,8 +135,8 @@ impl CounterStorage for RocksDbStorage {
}

#[tracing::instrument(skip_all)]
fn delete_counters(&self, limits: HashSet<Limit>) -> Result<(), StorageErr> {
let counters = self.get_counters(&limits)?;
fn delete_counters(&self, limits: &HashSet<Arc<Limit>>) -> Result<(), StorageErr> {
let counters = self.get_counters(limits)?;
for counter in &counters {
let span = trace_span!("datastore");
let _entered = span.enter();
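
A sketch, not part of this diff, of why the lookup above goes through Deref: the stored set now holds Arc<Limit> while Counter::limit() returns a plain &Limit, so the Arc must be dereferenced before comparing; u64 stands in for Limit to keep the snippet runnable:

use std::collections::HashSet;
use std::ops::Deref;
use std::sync::Arc;

// The set holds Arc-wrapped values; a counter hands back a plain reference, so the
// Arc is dereferenced before the two can be compared.
fn find_matching<'a>(stored: &'a HashSet<Arc<u64>>, from_counter: &u64) -> Option<&'a Arc<u64>> {
    stored.iter().find(|l| l.deref() == from_counter)
}

fn main() {
    let stored: HashSet<Arc<u64>> = [Arc::new(10), Arc::new(100)].into_iter().collect();
    assert!(find_matching(&stored, &10).is_some());
    assert!(find_matching(&stored, &7).is_none());
}
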
(diffs for the remaining 7 changed files not shown)
