diff --git a/limitador-server/src/envoy_rls/server.rs b/limitador-server/src/envoy_rls/server.rs index e6d5f08f..fafefa2c 100644 --- a/limitador-server/src/envoy_rls/server.rs +++ b/limitador-server/src/envoy_rls/server.rs @@ -3,13 +3,12 @@ use opentelemetry::propagation::Extractor; use std::collections::HashMap; use std::sync::Arc; +use limitador::CheckResult; use tonic::codegen::http::HeaderMap; use tonic::{transport, transport::Server, Request, Response, Status}; use tracing::Span; use tracing_opentelemetry::OpenTelemetrySpanExt; -use limitador::counter::Counter; - use crate::envoy_rls::server::envoy::config::core::v3::HeaderValue; use crate::envoy_rls::server::envoy::service::ratelimit::v3::rate_limit_response::Code; use crate::envoy_rls::server::envoy::service::ratelimit::v3::rate_limit_service_server::{ @@ -29,6 +28,21 @@ pub enum RateLimitHeaders { DraftVersion03, } +impl RateLimitHeaders { + pub fn headers(&self, response: &mut CheckResult) -> Vec { + let mut headers = match self { + RateLimitHeaders::None => Vec::default(), + RateLimitHeaders::DraftVersion03 => response + .response_header() + .into_iter() + .map(|(key, value)| HeaderValue { key, value }) + .collect(), + }; + headers.sort_by(|a, b| a.key.cmp(&b.key)); + headers + } +} + pub struct MyRateLimiter { limiter: Arc, rate_limit_headers: RateLimitHeaders, @@ -142,10 +156,7 @@ impl RateLimitService for MyRateLimiter { overall_code: resp_code.into(), statuses: vec![], request_headers_to_add: vec![], - response_headers_to_add: to_response_header( - &self.rate_limit_headers, - &mut rate_limited_resp.counters, - ), + response_headers_to_add: self.rate_limit_headers.headers(&mut rate_limited_resp), raw_body: vec![], dynamic_metadata: None, quota: None, @@ -155,58 +166,6 @@ impl RateLimitService for MyRateLimiter { } } -pub fn to_response_header( - rate_limit_headers: &RateLimitHeaders, - counters: &mut [Counter], -) -> Vec { - let mut headers = Vec::new(); - match rate_limit_headers { - RateLimitHeaders::None => {} - - // creates response headers per https://datatracker.ietf.org/doc/id/draft-polli-ratelimit-headers-03.html - RateLimitHeaders::DraftVersion03 => { - // sort by the limit remaining.. - counters.sort_by(|a, b| { - let a_remaining = a.remaining().unwrap_or(a.max_value()); - let b_remaining = b.remaining().unwrap_or(b.max_value()); - a_remaining.cmp(&b_remaining) - }); - - let mut all_limits_text = String::with_capacity(20 * counters.len()); - counters.iter_mut().for_each(|counter| { - all_limits_text.push_str( - format!(", {};w={}", counter.max_value(), counter.window().as_secs()).as_str(), - ); - if let Some(name) = counter.limit().name() { - all_limits_text - .push_str(format!(";name=\"{}\"", name.replace('"', "'")).as_str()); - } - }); - - if let Some(counter) = counters.first() { - headers.push(HeaderValue { - key: "X-RateLimit-Limit".to_string(), - value: format!("{}{}", counter.max_value(), all_limits_text), - }); - - let remaining = counter.remaining().unwrap_or(counter.max_value()); - headers.push(HeaderValue { - key: "X-RateLimit-Remaining".to_string(), - value: format!("{}", remaining), - }); - - if let Some(duration) = counter.expires_in() { - headers.push(HeaderValue { - key: "X-RateLimit-Reset".to_string(), - value: format!("{}", duration.as_secs()), - }); - } - } - } - }; - headers -} - struct RateLimitRequestHeaders { inner: HeaderMap, } diff --git a/limitador-server/src/http_api/server.rs b/limitador-server/src/http_api/server.rs index 97937d69..bc3b91e6 100644 --- a/limitador-server/src/http_api/server.rs +++ b/limitador-server/src/http_api/server.rs @@ -3,6 +3,7 @@ use crate::prometheus_metrics::PrometheusMetrics; use crate::Limiter; use actix_web::{http::StatusCode, HttpResponse, HttpResponseBuilder, ResponseError}; use actix_web::{App, HttpServer}; +use limitador::CheckResult; use paperclip::actix::{ api_v2_errors, api_v2_operation, @@ -209,7 +210,7 @@ async fn check_and_report( add_response_header( &mut resp, response_headers.as_str(), - &mut is_rate_limited.counters, + &mut is_rate_limited, ); resp.json(()) } @@ -224,7 +225,7 @@ async fn check_and_report( add_response_header( &mut resp, response_headers.as_str(), - &mut is_rate_limited.counters, + &mut is_rate_limited, ); resp.json(()) } @@ -238,48 +239,21 @@ async fn check_and_report( pub fn add_response_header( resp: &mut HttpResponseBuilder, rate_limit_headers: &str, - counters: &mut [limitador::counter::Counter], + result: &mut CheckResult, ) { - match rate_limit_headers { + if rate_limit_headers == "DraftVersion03" { // creates response headers per https://datatracker.ietf.org/doc/id/draft-polli-ratelimit-headers-03.html - "DraftVersion03" => { - // sort by the limit remaining.. - counters.sort_by(|a, b| { - let a_remaining = a.remaining().unwrap_or(a.max_value()); - let b_remaining = b.remaining().unwrap_or(b.max_value()); - a_remaining.cmp(&b_remaining) - }); - - let mut all_limits_text = String::with_capacity(20 * counters.len()); - counters.iter_mut().for_each(|counter| { - all_limits_text.push_str( - format!(", {};w={}", counter.max_value(), counter.window().as_secs()).as_str(), - ); - if let Some(name) = counter.limit().name() { - all_limits_text - .push_str(format!(";name=\"{}\"", name.replace('"', "'")).as_str()); - } - }); - - if let Some(counter) = counters.first() { - resp.insert_header(( - "X-RateLimit-Limit", - format!("{}{}", counter.max_value(), all_limits_text), - )); - - let remaining = counter.remaining().unwrap_or(counter.max_value()); - resp.insert_header(( - "X-RateLimit-Remaining".to_string(), - format!("{}", remaining), - )); - - if let Some(duration) = counter.expires_in() { - resp.insert_header(("X-RateLimit-Reset", format!("{}", duration.as_secs()))); - } + let headers = result.response_header(); + if let Some(limit) = headers.get("X-RateLimit-Limit") { + resp.insert_header(("X-RateLimit-Limit", limit.clone())); + } + if let Some(remaining) = headers.get("X-RateLimit-Remaining") { + resp.insert_header(("X-RateLimit-Remaining".to_string(), remaining.clone())); + if let Some(duration) = headers.get("X-RateLimit-Reset") { + resp.insert_header(("X-RateLimit-Reset", duration.clone())); } } - _default => {} - }; + } } pub async fn run_http_server( diff --git a/limitador/src/counter.rs b/limitador/src/counter.rs index 9763d627..5f5bac49 100644 --- a/limitador/src/counter.rs +++ b/limitador/src/counter.rs @@ -2,11 +2,12 @@ use crate::limit::{Limit, Namespace}; use serde::{Deserialize, Serialize, Serializer}; use std::collections::{BTreeMap, HashMap}; use std::hash::{Hash, Hasher}; +use std::sync::Arc; use std::time::Duration; #[derive(Eq, Clone, Debug, Serialize, Deserialize)] pub struct Counter { - limit: Limit, + limit: Arc, // Need to sort to generate the same object when using the JSON as a key or // value in Redis. @@ -26,9 +27,10 @@ where } impl Counter { - pub fn new(limit: Limit, set_variables: HashMap) -> Self { + pub fn new>>(limit: L, set_variables: HashMap) -> Self { // TODO: check that all the variables defined in the limit are set. + let limit = limit.into(); let mut vars = set_variables; vars.retain(|var, _| limit.has_variable(var)); @@ -43,7 +45,7 @@ impl Counter { #[cfg(any(feature = "redis_storage", feature = "disk_storage"))] pub(crate) fn key(&self) -> Self { Self { - limit: self.limit.clone(), + limit: Arc::clone(&self.limit), set_variables: self.set_variables.clone(), remaining: None, expires_in: None, @@ -58,12 +60,9 @@ impl Counter { self.limit.max_value() } - pub fn update_to_limit(&mut self, limit: &Limit) -> bool { - if limit == &self.limit { - self.limit.set_max_value(limit.max_value()); - if let Some(name) = limit.name() { - self.limit.set_name(name.to_string()); - } + pub fn update_to_limit(&mut self, limit: Arc) -> bool { + if limit == self.limit { + self.limit = limit; return true; } false diff --git a/limitador/src/lib.rs b/limitador/src/lib.rs index 59f07a67..a71de204 100644 --- a/limitador/src/lib.rs +++ b/limitador/src/lib.rs @@ -193,6 +193,7 @@ #![allow(clippy::multiple_crate_versions)] use std::collections::{HashMap, HashSet}; +use std::sync::Arc; use crate::counter::Counter; use crate::errors::LimitadorError; @@ -226,6 +227,49 @@ pub struct CheckResult { pub limit_name: Option, } +impl CheckResult { + pub fn response_header(&mut self) -> HashMap { + let mut headers = HashMap::new(); + // sort by the limit remaining.. + self.counters.sort_by(|a, b| { + let a_remaining = a.remaining().unwrap_or(a.max_value()); + let b_remaining = b.remaining().unwrap_or(b.max_value()); + a_remaining.cmp(&b_remaining) + }); + + let mut all_limits_text = String::with_capacity(20 * self.counters.len()); + self.counters.iter_mut().for_each(|counter| { + all_limits_text.push_str( + format!(", {};w={}", counter.max_value(), counter.window().as_secs()).as_str(), + ); + if let Some(name) = counter.limit().name() { + all_limits_text.push_str(format!(";name=\"{}\"", name.replace('"', "'")).as_str()); + } + }); + + if let Some(counter) = self.counters.first() { + headers.insert( + "X-RateLimit-Limit".to_string(), + format!("{}{}", counter.max_value(), all_limits_text), + ); + + let remaining = counter.remaining().unwrap_or(counter.max_value()); + headers.insert( + "X-RateLimit-Remaining".to_string(), + format!("{}", remaining), + ); + + if let Some(duration) = counter.expires_in() { + headers.insert( + "X-RateLimit-Reset".to_string(), + format!("{}", duration.as_secs()), + ); + } + } + headers + } +} + impl From for bool { fn from(value: CheckResult) -> Self { value.limited @@ -298,7 +342,11 @@ impl RateLimiter { } pub fn get_limits(&self, namespace: &Namespace) -> HashSet { - self.storage.get_limits(namespace) + self.storage + .get_limits(namespace) + .iter() + .map(|l| (**l).clone()) + .collect() } pub fn delete_limits(&self, namespace: &Namespace) -> Result<(), LimitadorError> { @@ -432,12 +480,12 @@ impl RateLimiter { namespace: &Namespace, values: &HashMap, ) -> Result, LimitadorError> { - let limits = self.get_limits(namespace); + let limits = self.storage.get_limits(namespace); let counters = limits .iter() .filter(|lim| lim.applies(values)) - .map(|lim| Counter::new(lim.clone(), values.clone())) + .map(|lim| Counter::new(Arc::clone(lim), values.clone())) .collect(); Ok(counters) @@ -470,7 +518,11 @@ impl AsyncRateLimiter { } pub fn get_limits(&self, namespace: &Namespace) -> HashSet { - self.storage.get_limits(namespace) + self.storage + .get_limits(namespace) + .iter() + .map(|l| (**l).clone()) + .collect() } pub async fn delete_limits(&self, namespace: &Namespace) -> Result<(), LimitadorError> { @@ -610,12 +662,12 @@ impl AsyncRateLimiter { namespace: &Namespace, values: &HashMap, ) -> Result, LimitadorError> { - let limits = self.get_limits(namespace); + let limits = self.storage.get_limits(namespace); let counters = limits .iter() .filter(|lim| lim.applies(values)) - .map(|lim| Counter::new(lim.clone(), values.clone())) + .map(|lim| Counter::new(Arc::clone(lim), values.clone())) .collect(); Ok(counters) diff --git a/limitador/src/storage/disk/rocksdb_storage.rs b/limitador/src/storage/disk/rocksdb_storage.rs index 1e19c2c6..148af984 100644 --- a/limitador/src/storage/disk/rocksdb_storage.rs +++ b/limitador/src/storage/disk/rocksdb_storage.rs @@ -11,6 +11,8 @@ use rocksdb::{ DB, }; use std::collections::{BTreeSet, HashSet}; +use std::ops::Deref; +use std::sync::Arc; use std::time::{Duration, SystemTime}; use tracing::trace_span; @@ -91,7 +93,7 @@ impl CounterStorage for RocksDbStorage { } #[tracing::instrument(skip_all)] - fn get_counters(&self, limits: &HashSet) -> Result, StorageErr> { + fn get_counters(&self, limits: &HashSet>) -> Result, StorageErr> { let mut counters = HashSet::default(); let namepaces: BTreeSet<&str> = limits.iter().map(|l| l.namespace().as_ref()).collect(); for ns in namepaces { @@ -113,8 +115,8 @@ impl CounterStorage for RocksDbStorage { } let value: ExpiringValue = value.as_ref().try_into()?; for limit in limits { - if limit == counter.limit() { - counter.update_to_limit(limit); + if limit.deref() == counter.limit() { + counter.update_to_limit(Arc::clone(limit)); let ttl = value.ttl(); counter.set_expires_in(ttl); counter.set_remaining(limit.max_value() - value.value()); @@ -133,8 +135,8 @@ impl CounterStorage for RocksDbStorage { } #[tracing::instrument(skip_all)] - fn delete_counters(&self, limits: HashSet) -> Result<(), StorageErr> { - let counters = self.get_counters(&limits)?; + fn delete_counters(&self, limits: &HashSet>) -> Result<(), StorageErr> { + let counters = self.get_counters(limits)?; for counter in &counters { let span = trace_span!("datastore"); let _entered = span.enter(); diff --git a/limitador/src/storage/distributed/mod.rs b/limitador/src/storage/distributed/mod.rs index 428f5178..6deda533 100644 --- a/limitador/src/storage/distributed/mod.rs +++ b/limitador/src/storage/distributed/mod.rs @@ -170,10 +170,10 @@ impl CounterStorage for CrInMemoryStorage { } #[tracing::instrument(skip_all)] - fn get_counters(&self, limits: &HashSet) -> Result, StorageErr> { + fn get_counters(&self, limits: &HashSet>) -> Result, StorageErr> { let mut res = HashSet::new(); - let limits: HashSet<_> = limits.iter().map(encode_limit_to_key).collect(); + let limits: HashSet<_> = limits.iter().map(|l| encode_limit_to_key(l)).collect(); let limits_map = self.limits.read().unwrap(); for (key, counter_value) in limits_map.iter() { @@ -200,9 +200,9 @@ impl CounterStorage for CrInMemoryStorage { } #[tracing::instrument(skip_all)] - fn delete_counters(&self, limits: HashSet) -> Result<(), StorageErr> { + fn delete_counters(&self, limits: &HashSet>) -> Result<(), StorageErr> { for limit in limits { - self.delete_counters_of_limit(&limit); + self.delete_counters_of_limit(limit); } Ok(()) } diff --git a/limitador/src/storage/in_memory.rs b/limitador/src/storage/in_memory.rs index f32e2a22..2fedb7e8 100644 --- a/limitador/src/storage/in_memory.rs +++ b/limitador/src/storage/in_memory.rs @@ -168,10 +168,10 @@ impl CounterStorage for InMemoryStorage { } #[tracing::instrument(skip_all)] - fn get_counters(&self, limits: &HashSet) -> Result, StorageErr> { + fn get_counters(&self, limits: &HashSet>) -> Result, StorageErr> { let mut res = HashSet::new(); - let namespaces: HashSet<&Namespace> = limits.iter().map(Limit::namespace).collect(); + let namespaces: HashSet<&Namespace> = limits.iter().map(|l| l.namespace()).collect(); let limits_by_namespace = self.limits_for_namespace.read().unwrap(); for namespace in namespaces { @@ -209,9 +209,9 @@ impl CounterStorage for InMemoryStorage { } #[tracing::instrument(skip_all)] - fn delete_counters(&self, limits: HashSet) -> Result<(), StorageErr> { + fn delete_counters(&self, limits: &HashSet>) -> Result<(), StorageErr> { for limit in limits { - self.delete_counters_of_limit(&limit); + self.delete_counters_of_limit(limit); } Ok(()) } diff --git a/limitador/src/storage/keys.rs b/limitador/src/storage/keys.rs index 6d32977c..81d818c6 100644 --- a/limitador/src/storage/keys.rs +++ b/limitador/src/storage/keys.rs @@ -14,6 +14,7 @@ use crate::counter::Counter; use crate::limit::Limit; +use std::sync::Arc; pub fn key_for_counter(counter: &Counter) -> String { if counter.remaining().is_some() || counter.expires_in().is_some() { @@ -43,9 +44,9 @@ pub fn prefix_for_namespace(namespace: &str) -> String { format!("namespace:{{{namespace}}},") } -pub fn counter_from_counter_key(key: &str, limit: &Limit) -> Counter { +pub fn counter_from_counter_key(key: &str, limit: Arc) -> Counter { let mut counter = partial_counter_from_counter_key(key); - if !counter.update_to_limit(limit) { + if !counter.update_to_limit(Arc::clone(&limit)) { // this means some kind of data corruption _or_ most probably // an out of sync `impl PartialEq for Limit` vs `pub fn key_for_counter(counter: &Counter) -> String` panic!( diff --git a/limitador/src/storage/mod.rs b/limitador/src/storage/mod.rs index 22abd33a..403d21a6 100644 --- a/limitador/src/storage/mod.rs +++ b/limitador/src/storage/mod.rs @@ -3,7 +3,7 @@ use crate::limit::{Limit, Namespace}; use crate::InMemoryStorage; use async_trait::async_trait; use std::collections::{HashMap, HashSet}; -use std::sync::RwLock; +use std::sync::{Arc, RwLock}; use thiserror::Error; #[cfg(feature = "disk_storage")] @@ -28,12 +28,12 @@ pub enum Authorization { } pub struct Storage { - limits: RwLock>>, + limits: RwLock>>>, counters: Box, } pub struct AsyncStorage { - limits: RwLock>>, + limits: RwLock>>>, counters: Box, } @@ -60,7 +60,7 @@ impl Storage { let namespace = limit.namespace().clone(); let mut limits = self.limits.write().unwrap(); self.counters.add_counter(&limit).unwrap(); - limits.entry(namespace).or_default().insert(limit) + limits.entry(namespace).or_default().insert(Arc::new(limit)) } pub fn update_limit(&self, update: &Limit) -> bool { @@ -74,24 +74,33 @@ impl Storage { }; if req_update { limits.remove(update); - limits.insert(update.clone()); + limits.insert(Arc::new(update.clone())); return true; } } false } - pub fn get_limits(&self, namespace: &Namespace) -> HashSet { + pub fn get_limits(&self, namespace: &Namespace) -> HashSet> { match self.limits.read().unwrap().get(namespace) { - Some(limits) => limits.clone(), + // todo revise typing here? + Some(limits) => limits.iter().map(Arc::clone).collect(), None => HashSet::new(), } } pub fn delete_limit(&self, limit: &Limit) -> Result<(), StorageErr> { + let arc = match self.limits.read().unwrap().get(limit.namespace()) { + None => Arc::new(limit.clone()), + Some(limits) => limits + .iter() + .find(|l| ***l == *limit) + .cloned() + .unwrap_or_else(|| Arc::new(limit.clone())), + }; let mut limits = HashSet::new(); - limits.insert(limit.clone()); - self.counters.delete_counters(limits)?; + limits.insert(arc); + self.counters.delete_counters(&limits)?; let mut limits = self.limits.write().unwrap(); @@ -107,7 +116,7 @@ impl Storage { pub fn delete_limits(&self, namespace: &Namespace) -> Result<(), StorageErr> { if let Some(data) = self.limits.write().unwrap().remove(namespace) { - self.counters.delete_counters(data)?; + self.counters.delete_counters(&data)?; } Ok(()) } @@ -161,10 +170,10 @@ impl AsyncStorage { let mut limits_for_namespace = self.limits.write().unwrap(); match limits_for_namespace.get_mut(&namespace) { - Some(limits) => limits.insert(limit), + Some(limits) => limits.insert(Arc::new(limit)), None => { let mut limits = HashSet::new(); - limits.insert(limit); + limits.insert(Arc::new(limit)); limits_for_namespace.insert(namespace, limits); true } @@ -182,24 +191,32 @@ impl AsyncStorage { }; if req_update { limits.remove(update); - limits.insert(update.clone()); + limits.insert(Arc::new(update.clone())); return true; } } false } - pub fn get_limits(&self, namespace: &Namespace) -> HashSet { + pub fn get_limits(&self, namespace: &Namespace) -> HashSet> { match self.limits.read().unwrap().get(namespace) { - Some(limits) => limits.iter().cloned().collect(), + Some(limits) => limits.iter().map(Arc::clone).collect(), None => HashSet::new(), } } pub async fn delete_limit(&self, limit: &Limit) -> Result<(), StorageErr> { + let arc = match self.limits.read().unwrap().get(limit.namespace()) { + None => Arc::new(limit.clone()), + Some(limits) => limits + .iter() + .find(|l| ***l == *limit) + .cloned() + .unwrap_or_else(|| Arc::new(limit.clone())), + }; let mut limits = HashSet::new(); - limits.insert(limit.clone()); - self.counters.delete_counters(limits).await?; + limits.insert(arc); + self.counters.delete_counters(&limits).await?; let mut limits_for_namespace = self.limits.write().unwrap(); @@ -216,8 +233,7 @@ impl AsyncStorage { pub async fn delete_limits(&self, namespace: &Namespace) -> Result<(), StorageErr> { let option = { self.limits.write().unwrap().remove(namespace) }; if let Some(data) = option { - let limits = data.iter().cloned().collect(); - self.counters.delete_counters(limits).await?; + self.counters.delete_counters(&data).await?; } Ok(()) } @@ -250,7 +266,7 @@ impl AsyncStorage { namespace: &Namespace, ) -> Result, StorageErr> { let limits = self.get_limits(namespace); - self.counters.get_counters(limits).await + self.counters.get_counters(&limits).await } pub async fn clear(&self) -> Result<(), StorageErr> { @@ -269,8 +285,8 @@ pub trait CounterStorage: Sync + Send { delta: u64, load_counters: bool, ) -> Result; - fn get_counters(&self, limits: &HashSet) -> Result, StorageErr>; - fn delete_counters(&self, limits: HashSet) -> Result<(), StorageErr>; + fn get_counters(&self, limits: &HashSet>) -> Result, StorageErr>; // todo revise typing here? + fn delete_counters(&self, limits: &HashSet>) -> Result<(), StorageErr>; // todo revise typing here? fn clear(&self) -> Result<(), StorageErr>; } @@ -284,8 +300,11 @@ pub trait AsyncCounterStorage: Sync + Send { delta: u64, load_counters: bool, ) -> Result; - async fn get_counters(&self, limits: HashSet) -> Result, StorageErr>; - async fn delete_counters(&self, limits: HashSet) -> Result<(), StorageErr>; + async fn get_counters( + &self, + limits: &HashSet>, + ) -> Result, StorageErr>; + async fn delete_counters(&self, limits: &HashSet>) -> Result<(), StorageErr>; async fn clear(&self) -> Result<(), StorageErr>; } diff --git a/limitador/src/storage/redis/redis_async.rs b/limitador/src/storage/redis/redis_async.rs index 18175c75..d29e7b3a 100644 --- a/limitador/src/storage/redis/redis_async.rs +++ b/limitador/src/storage/redis/redis_async.rs @@ -11,7 +11,9 @@ use crate::storage::{AsyncCounterStorage, Authorization, StorageErr}; use async_trait::async_trait; use redis::{AsyncCommands, RedisError}; use std::collections::HashSet; +use std::ops::Deref; use std::str::FromStr; +use std::sync::Arc; use std::time::Duration; use tracing::{debug_span, Instrument}; @@ -127,20 +129,24 @@ impl AsyncCounterStorage for AsyncRedisStorage { } #[tracing::instrument(skip_all)] - async fn get_counters(&self, limits: HashSet) -> Result, StorageErr> { + async fn get_counters( + &self, + limits: &HashSet>, + ) -> Result, StorageErr> { let mut res = HashSet::new(); let mut con = self.conn_manager.clone(); for limit in limits { let counter_keys = { - con.smembers::>(key_for_counters_of_limit(&limit)) + con.smembers::>(key_for_counters_of_limit(limit)) .instrument(debug_span!("datastore")) .await? }; for counter_key in counter_keys { - let mut counter: Counter = counter_from_counter_key(&counter_key, &limit); + let mut counter: Counter = + counter_from_counter_key(&counter_key, Arc::clone(limit)); // If the key does not exist, it means that the counter expired, // so we don't have to return it. @@ -172,9 +178,9 @@ impl AsyncCounterStorage for AsyncRedisStorage { } #[tracing::instrument(skip_all)] - async fn delete_counters(&self, limits: HashSet) -> Result<(), StorageErr> { + async fn delete_counters(&self, limits: &HashSet>) -> Result<(), StorageErr> { for limit in limits { - self.delete_counters_associated_with_limit(&limit) + self.delete_counters_associated_with_limit(limit.deref()) .instrument(debug_span!("datastore")) .await? } diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index fbf4aa89..9a3ae681 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -132,12 +132,15 @@ impl AsyncCounterStorage for CachedRedisStorage { } #[tracing::instrument(skip_all)] - async fn get_counters(&self, limits: HashSet) -> Result, StorageErr> { + async fn get_counters( + &self, + limits: &HashSet>, + ) -> Result, StorageErr> { self.async_redis_storage.get_counters(limits).await } #[tracing::instrument(skip_all)] - async fn delete_counters(&self, limits: HashSet) -> Result<(), StorageErr> { + async fn delete_counters(&self, limits: &HashSet>) -> Result<(), StorageErr> { self.async_redis_storage.delete_counters(limits).await } diff --git a/limitador/src/storage/redis/redis_sync.rs b/limitador/src/storage/redis/redis_sync.rs index 002e9444..81eb3f11 100644 --- a/limitador/src/storage/redis/redis_sync.rs +++ b/limitador/src/storage/redis/redis_sync.rs @@ -9,6 +9,8 @@ use crate::storage::redis::scripts::{SCRIPT_UPDATE_COUNTER, VALUES_AND_TTLS}; use crate::storage::{Authorization, CounterStorage, StorageErr}; use r2d2::{ManageConnection, Pool}; use std::collections::HashSet; +use std::ops::Deref; +use std::sync::Arc; use std::time::Duration; const DEFAULT_REDIS_URL: &str = "redis://127.0.0.1:6379"; @@ -106,7 +108,7 @@ impl CounterStorage for RedisStorage { } #[tracing::instrument(skip_all)] - fn get_counters(&self, limits: &HashSet) -> Result, StorageErr> { + fn get_counters(&self, limits: &HashSet>) -> Result, StorageErr> { let mut res = HashSet::new(); let mut con = self.conn_pool.get()?; @@ -116,7 +118,8 @@ impl CounterStorage for RedisStorage { con.smembers::>(key_for_counters_of_limit(limit))?; for counter_key in counter_keys { - let mut counter: Counter = counter_from_counter_key(&counter_key, limit); + let mut counter: Counter = + counter_from_counter_key(&counter_key, Arc::clone(limit)); // If the key does not exist, it means that the counter expired, // so we don't have to return it. @@ -143,12 +146,12 @@ impl CounterStorage for RedisStorage { } #[tracing::instrument(skip_all)] - fn delete_counters(&self, limits: HashSet) -> Result<(), StorageErr> { + fn delete_counters(&self, limits: &HashSet>) -> Result<(), StorageErr> { let mut con = self.conn_pool.get()?; for limit in limits { let counter_keys = - con.smembers::>(key_for_counters_of_limit(&limit))?; + con.smembers::>(key_for_counters_of_limit(limit.deref()))?; for counter_key in counter_keys { con.del(counter_key)?;