diff --git a/limitador/src/storage/distributed/cr_counter_value.rs b/limitador/src/storage/distributed/cr_counter_value.rs index e76b4b05..38220e5f 100644 --- a/limitador/src/storage/distributed/cr_counter_value.rs +++ b/limitador/src/storage/distributed/cr_counter_value.rs @@ -1,13 +1,15 @@ -use crate::storage::atomic_expiring_value::AtomicExpiryTime; use std::collections::btree_map::Entry; use std::collections::BTreeMap; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::RwLock; use std::time::{Duration, SystemTime}; +use crate::storage::atomic_expiring_value::AtomicExpiryTime; + #[derive(Debug)] pub struct CrCounterValue { ourselves: A, + max_value: u64, value: AtomicU64, others: RwLock>, expiry: AtomicExpiryTime, @@ -15,15 +17,20 @@ pub struct CrCounterValue { #[allow(dead_code)] impl CrCounterValue { - pub fn new(actor: A, time_window: Duration) -> Self { + pub fn new(actor: A, max_value: u64, time_window: Duration) -> Self { Self { ourselves: actor, + max_value, value: Default::default(), others: RwLock::default(), expiry: AtomicExpiryTime::new(SystemTime::now() + time_window), } } + pub fn max_value(&self) -> u64 { + self.max_value + } + pub fn read(&self) -> u64 { self.read_at(SystemTime::now()) } @@ -116,6 +123,7 @@ impl CrCounterValue { pub fn into_inner(self) -> (SystemTime, BTreeMap) { let Self { ourselves, + max_value: _, value, others, expiry, @@ -137,6 +145,7 @@ impl Clone for CrCounterValue { fn clone(&self) -> Self { Self { ourselves: self.ourselves.clone(), + max_value: self.max_value, value: AtomicU64::new(self.value.load(Ordering::SeqCst)), others: RwLock::new(self.others.read().unwrap().clone()), expiry: self.expiry.clone(), @@ -148,6 +157,7 @@ impl From<(SystemTime, BTreeMap)> for CrCounte fn from(value: (SystemTime, BTreeMap)) -> Self { Self { ourselves: A::default(), + max_value: 0, value: Default::default(), others: RwLock::new(value.1), expiry: value.0.into(), @@ -157,13 +167,14 @@ impl From<(SystemTime, BTreeMap)> for CrCounte #[cfg(test)] mod tests { - use crate::storage::distributed::cr_counter_value::CrCounterValue; use std::time::{Duration, SystemTime}; + use crate::storage::distributed::cr_counter_value::CrCounterValue; + #[test] fn local_increments_are_readable() { let window = Duration::from_secs(1); - let a = CrCounterValue::new('A', window); + let a = CrCounterValue::new('A', u64::MAX, window); a.inc(3, window); assert_eq!(3, a.read()); a.inc(2, window); @@ -173,7 +184,7 @@ mod tests { #[test] fn local_increments_expire() { let window = Duration::from_secs(1); - let a = CrCounterValue::new('A', window); + let a = CrCounterValue::new('A', u64::MAX, window); let now = SystemTime::now(); a.inc_at(3, window, now); assert_eq!(3, a.read()); @@ -184,7 +195,7 @@ mod tests { #[test] fn other_increments_are_readable() { let window = Duration::from_secs(1); - let a = CrCounterValue::new('A', window); + let a = CrCounterValue::new('A', u64::MAX, window); a.inc_actor('B', 3, window); assert_eq!(3, a.read()); a.inc_actor('B', 2, window); @@ -194,7 +205,7 @@ mod tests { #[test] fn other_increments_expire() { let window = Duration::from_secs(1); - let a = CrCounterValue::new('A', window); + let a = CrCounterValue::new('A', u64::MAX, window); let now = SystemTime::now(); a.inc_actor_at('B', 3, window, now); assert_eq!(3, a.read()); @@ -205,8 +216,8 @@ mod tests { #[test] fn merges() { let window = Duration::from_secs(1); - let a = CrCounterValue::new('A', window); - let b = CrCounterValue::new('B', window); + let a = CrCounterValue::new('A', u64::MAX, window); + let b = CrCounterValue::new('B', u64::MAX, window); a.inc(3, window); b.inc(2, window); a.merge(b); @@ -216,8 +227,8 @@ mod tests { #[test] fn merges_symetric() { let window = Duration::from_secs(1); - let a = CrCounterValue::new('A', window); - let b = CrCounterValue::new('B', window); + let a = CrCounterValue::new('A', u64::MAX, window); + let b = CrCounterValue::new('B', u64::MAX, window); a.inc(3, window); b.inc(2, window); b.merge(a); @@ -227,8 +238,8 @@ mod tests { #[test] fn merges_overrides_with_larger_value() { let window = Duration::from_secs(1); - let a = CrCounterValue::new('A', window); - let b = CrCounterValue::new('B', window); + let a = CrCounterValue::new('A', u64::MAX, window); + let b = CrCounterValue::new('B', u64::MAX, window); a.inc(3, window); b.inc(2, window); b.inc_actor('A', 2, window); // older value! @@ -239,8 +250,8 @@ mod tests { #[test] fn merges_ignore_lesser_values() { let window = Duration::from_secs(1); - let a = CrCounterValue::new('A', window); - let b = CrCounterValue::new('B', window); + let a = CrCounterValue::new('A', u64::MAX, window); + let b = CrCounterValue::new('B', u64::MAX, window); a.inc(3, window); b.inc(2, window); b.inc_actor('A', 5, window); // newer value! @@ -251,9 +262,9 @@ mod tests { #[test] fn merge_ignores_expired_sets() { let window = Duration::from_secs(1); - let a = CrCounterValue::new('A', Duration::ZERO); + let a = CrCounterValue::new('A', u64::MAX, Duration::ZERO); a.inc(3, Duration::ZERO); - let b = CrCounterValue::new('B', window); + let b = CrCounterValue::new('B', u64::MAX, window); b.inc(2, window); b.merge(a); assert_eq!(b.read(), 2); @@ -262,9 +273,9 @@ mod tests { #[test] fn merge_ignores_expired_sets_symmetric() { let window = Duration::from_secs(1); - let a = CrCounterValue::new('A', Duration::ZERO); + let a = CrCounterValue::new('A', u64::MAX, Duration::ZERO); a.inc(3, Duration::ZERO); - let b = CrCounterValue::new('B', window); + let b = CrCounterValue::new('B', u64::MAX, window); b.inc(2, window); a.merge(b); assert_eq!(a.read(), 2); @@ -273,9 +284,9 @@ mod tests { #[test] fn merge_uses_earliest_expiry() { let later = Duration::from_secs(1); - let a = CrCounterValue::new('A', later); + let a = CrCounterValue::new('A', u64::MAX, later); let sooner = Duration::from_millis(200); - let b = CrCounterValue::new('B', sooner); + let b = CrCounterValue::new('B', u64::MAX, sooner); a.inc(3, later); b.inc(2, later); a.merge(b); diff --git a/limitador/src/storage/distributed/mod.rs b/limitador/src/storage/distributed/mod.rs index 65c1394c..452b3aa2 100644 --- a/limitador/src/storage/distributed/mod.rs +++ b/limitador/src/storage/distributed/mod.rs @@ -44,6 +44,7 @@ impl CounterStorage for CrInMemoryStorage { let key = encode_limit_to_key(limit); limits.entry(key).or_insert(CrCounterValue::new( self.identifier.clone(), + limit.max_value(), Duration::from_secs(limit.seconds()), )); } @@ -59,7 +60,8 @@ impl CounterStorage for CrInMemoryStorage { match limits.entry(key.clone()) { Entry::Vacant(entry) => { let duration = counter.window(); - let store_value = CrCounterValue::new(self.identifier.clone(), duration); + let store_value = + CrCounterValue::new(self.identifier.clone(), counter.max_value(), duration); self.increment_counter(counter, key, &store_value, delta, now); entry.insert(store_value); } @@ -129,6 +131,7 @@ impl CounterStorage for CrInMemoryStorage { let mut limits = self.limits.write().unwrap(); let store_value = limits.entry(key.clone()).or_insert(CrCounterValue::new( self.identifier.clone(), + counter.max_value(), counter.window(), )); @@ -161,10 +164,22 @@ impl CounterStorage for CrInMemoryStorage { fn get_counters(&self, limits: &HashSet) -> Result, StorageErr> { let mut res = HashSet::new(); + let limits: HashSet<_> = limits.iter().map(encode_limit_to_key).collect(); + let limits_map = self.limits.read().unwrap(); for (key, counter_value) in limits_map.iter() { - let mut counter: Counter = decode_counter_key(key).unwrap().into(); - if limits.contains(counter.limit()) { + let counter_key = decode_counter_key(key).unwrap(); + let limit_key = if !counter_key.vars.is_empty() { + let mut cloned = counter_key.clone(); + cloned.vars = HashMap::default(); + cloned.encode() + } else { + key.clone() + }; + + if limits.contains(&limit_key) { + let counter = (&counter_key, counter_value); + let mut counter: Counter = counter.into(); counter.set_remaining(counter.max_value() - counter_value.read()); counter.set_expires_in(counter_value.ttl()); if counter.expires_in().unwrap() > Duration::ZERO { @@ -264,56 +279,61 @@ impl CrInMemoryStorage { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] struct CounterKey { namespace: Namespace, seconds: u64, - max_value: u64, conditions: HashSet, variables: HashSet, vars: HashMap, } +impl CounterKey { + fn new(limit: &Limit, vars: HashMap) -> Self { + CounterKey { + namespace: limit.namespace().clone(), + seconds: limit.seconds(), + variables: limit.variables().clone(), + conditions: limit.conditions().clone(), + vars, + } + } + + fn encode(&self) -> Vec { + postcard::to_stdvec(self).unwrap() + } +} + +impl From<(&CounterKey, &CrCounterValue)> for Counter { + fn from(value: (&CounterKey, &CrCounterValue)) -> Self { + let (counter_key, store_value) = value; + let max_value = store_value.max_value(); + let mut counter = Self::new( + Limit::new( + counter_key.namespace.clone(), + max_value, + counter_key.seconds, + counter_key.conditions.clone(), + counter_key.vars.keys(), + ), + counter_key.vars.clone(), + ); + counter.set_remaining(max_value - store_value.read()); + counter.set_expires_in(store_value.ttl()); + counter + } +} + fn encode_counter_to_key(counter: &Counter) -> Vec { - let limit = counter.limit(); - let key = CounterKey { - namespace: limit.namespace().clone(), - max_value: limit.max_value(), - seconds: limit.seconds(), - variables: limit.variables().clone(), - conditions: limit.conditions().clone(), - vars: counter.set_variables().clone(), - }; + let key = CounterKey::new(counter.limit(), counter.set_variables().clone()); postcard::to_stdvec(&key).unwrap() } fn encode_limit_to_key(limit: &Limit) -> Vec { - let key = CounterKey { - namespace: limit.namespace().clone(), - max_value: limit.max_value(), - seconds: limit.seconds(), - variables: limit.variables().clone(), - conditions: limit.conditions().clone(), - vars: HashMap::default(), - }; + let key = CounterKey::new(limit, HashMap::default()); postcard::to_stdvec(&key).unwrap() } fn decode_counter_key(key: &Vec) -> postcard::Result { postcard::from_bytes(key.as_slice()) } - -impl From for Counter { - fn from(value: CounterKey) -> Self { - Self::new( - Limit::new( - value.namespace, - value.max_value, - value.seconds, - value.conditions, - value.vars.keys(), - ), - value.vars, - ) - } -}