Skip to content

Commit

Permalink
Move the max_value field from the CounterKey to CrCounterValue.
Browse files Browse the repository at this point in the history
Signed-off-by: Hiram Chirino <hiram@hiramchirino.com>
  • Loading branch information
chirino committed May 23, 2024
1 parent 3895c8a commit 1571819
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 58 deletions.
53 changes: 32 additions & 21 deletions limitador/src/storage/distributed/cr_counter_value.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,36 @@
use crate::storage::atomic_expiring_value::AtomicExpiryTime;
use std::collections::btree_map::Entry;
use std::collections::BTreeMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::RwLock;
use std::time::{Duration, SystemTime};

use crate::storage::atomic_expiring_value::AtomicExpiryTime;

#[derive(Debug)]
pub struct CrCounterValue<A: Ord> {
ourselves: A,
max_value: u64,
value: AtomicU64,
others: RwLock<BTreeMap<A, u64>>,
expiry: AtomicExpiryTime,
}

#[allow(dead_code)]
impl<A: Ord> CrCounterValue<A> {
pub fn new(actor: A, time_window: Duration) -> Self {
pub fn new(actor: A, max_value: u64, time_window: Duration) -> Self {
Self {
ourselves: actor,
max_value,
value: Default::default(),
others: RwLock::default(),
expiry: AtomicExpiryTime::new(SystemTime::now() + time_window),
}
}

pub fn max_value(&self) -> u64 {
self.max_value
}

pub fn read(&self) -> u64 {
self.read_at(SystemTime::now())
}
Expand Down Expand Up @@ -116,6 +123,7 @@ impl<A: Ord> CrCounterValue<A> {
pub fn into_inner(self) -> (SystemTime, BTreeMap<A, u64>) {
let Self {
ourselves,
max_value: _,
value,
others,
expiry,
Expand All @@ -137,6 +145,7 @@ impl<A: Clone + Ord> Clone for CrCounterValue<A> {
fn clone(&self) -> Self {
Self {
ourselves: self.ourselves.clone(),
max_value: self.max_value,
value: AtomicU64::new(self.value.load(Ordering::SeqCst)),
others: RwLock::new(self.others.read().unwrap().clone()),
expiry: self.expiry.clone(),
Expand All @@ -148,6 +157,7 @@ impl<A: Clone + Ord + Default> From<(SystemTime, BTreeMap<A, u64>)> for CrCounte
fn from(value: (SystemTime, BTreeMap<A, u64>)) -> Self {
Self {
ourselves: A::default(),
max_value: 0,
value: Default::default(),
others: RwLock::new(value.1),
expiry: value.0.into(),
Expand All @@ -157,13 +167,14 @@ impl<A: Clone + Ord + Default> From<(SystemTime, BTreeMap<A, u64>)> for CrCounte

#[cfg(test)]
mod tests {
use crate::storage::distributed::cr_counter_value::CrCounterValue;
use std::time::{Duration, SystemTime};

use crate::storage::distributed::cr_counter_value::CrCounterValue;

#[test]
fn local_increments_are_readable() {
let window = Duration::from_secs(1);
let a = CrCounterValue::new('A', window);
let a = CrCounterValue::new('A', u64::MAX, window);
a.inc(3, window);
assert_eq!(3, a.read());
a.inc(2, window);
Expand All @@ -173,7 +184,7 @@ mod tests {
#[test]
fn local_increments_expire() {
let window = Duration::from_secs(1);
let a = CrCounterValue::new('A', window);
let a = CrCounterValue::new('A', u64::MAX, window);
let now = SystemTime::now();
a.inc_at(3, window, now);
assert_eq!(3, a.read());
Expand All @@ -184,7 +195,7 @@ mod tests {
#[test]
fn other_increments_are_readable() {
let window = Duration::from_secs(1);
let a = CrCounterValue::new('A', window);
let a = CrCounterValue::new('A', u64::MAX, window);
a.inc_actor('B', 3, window);
assert_eq!(3, a.read());
a.inc_actor('B', 2, window);
Expand All @@ -194,7 +205,7 @@ mod tests {
#[test]
fn other_increments_expire() {
let window = Duration::from_secs(1);
let a = CrCounterValue::new('A', window);
let a = CrCounterValue::new('A', u64::MAX, window);
let now = SystemTime::now();
a.inc_actor_at('B', 3, window, now);
assert_eq!(3, a.read());
Expand All @@ -205,8 +216,8 @@ mod tests {
#[test]
fn merges() {
let window = Duration::from_secs(1);
let a = CrCounterValue::new('A', window);
let b = CrCounterValue::new('B', window);
let a = CrCounterValue::new('A', u64::MAX, window);
let b = CrCounterValue::new('B', u64::MAX, window);
a.inc(3, window);
b.inc(2, window);
a.merge(b);
Expand All @@ -216,8 +227,8 @@ mod tests {
#[test]
fn merges_symetric() {
let window = Duration::from_secs(1);
let a = CrCounterValue::new('A', window);
let b = CrCounterValue::new('B', window);
let a = CrCounterValue::new('A', u64::MAX, window);
let b = CrCounterValue::new('B', u64::MAX, window);
a.inc(3, window);
b.inc(2, window);
b.merge(a);
Expand All @@ -227,8 +238,8 @@ mod tests {
#[test]
fn merges_overrides_with_larger_value() {
let window = Duration::from_secs(1);
let a = CrCounterValue::new('A', window);
let b = CrCounterValue::new('B', window);
let a = CrCounterValue::new('A', u64::MAX, window);
let b = CrCounterValue::new('B', u64::MAX, window);
a.inc(3, window);
b.inc(2, window);
b.inc_actor('A', 2, window); // older value!
Expand All @@ -239,8 +250,8 @@ mod tests {
#[test]
fn merges_ignore_lesser_values() {
let window = Duration::from_secs(1);
let a = CrCounterValue::new('A', window);
let b = CrCounterValue::new('B', window);
let a = CrCounterValue::new('A', u64::MAX, window);
let b = CrCounterValue::new('B', u64::MAX, window);
a.inc(3, window);
b.inc(2, window);
b.inc_actor('A', 5, window); // newer value!
Expand All @@ -251,9 +262,9 @@ mod tests {
#[test]
fn merge_ignores_expired_sets() {
let window = Duration::from_secs(1);
let a = CrCounterValue::new('A', Duration::ZERO);
let a = CrCounterValue::new('A', u64::MAX, Duration::ZERO);
a.inc(3, Duration::ZERO);
let b = CrCounterValue::new('B', window);
let b = CrCounterValue::new('B', u64::MAX, window);
b.inc(2, window);
b.merge(a);
assert_eq!(b.read(), 2);
Expand All @@ -262,9 +273,9 @@ mod tests {
#[test]
fn merge_ignores_expired_sets_symmetric() {
let window = Duration::from_secs(1);
let a = CrCounterValue::new('A', Duration::ZERO);
let a = CrCounterValue::new('A', u64::MAX, Duration::ZERO);
a.inc(3, Duration::ZERO);
let b = CrCounterValue::new('B', window);
let b = CrCounterValue::new('B', u64::MAX, window);
b.inc(2, window);
a.merge(b);
assert_eq!(a.read(), 2);
Expand All @@ -273,9 +284,9 @@ mod tests {
#[test]
fn merge_uses_earliest_expiry() {
let later = Duration::from_secs(1);
let a = CrCounterValue::new('A', later);
let a = CrCounterValue::new('A', u64::MAX, later);
let sooner = Duration::from_millis(200);
let b = CrCounterValue::new('B', sooner);
let b = CrCounterValue::new('B', u64::MAX, sooner);
a.inc(3, later);
b.inc(2, later);
a.merge(b);
Expand Down
94 changes: 57 additions & 37 deletions limitador/src/storage/distributed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ impl CounterStorage for CrInMemoryStorage {
let key = encode_limit_to_key(limit);
limits.entry(key).or_insert(CrCounterValue::new(
self.identifier.clone(),
limit.max_value(),
Duration::from_secs(limit.seconds()),
));
}
Expand All @@ -59,7 +60,8 @@ impl CounterStorage for CrInMemoryStorage {
match limits.entry(key.clone()) {
Entry::Vacant(entry) => {
let duration = counter.window();
let store_value = CrCounterValue::new(self.identifier.clone(), duration);
let store_value =
CrCounterValue::new(self.identifier.clone(), counter.max_value(), duration);
self.increment_counter(counter, key, &store_value, delta, now);
entry.insert(store_value);
}
Expand Down Expand Up @@ -129,6 +131,7 @@ impl CounterStorage for CrInMemoryStorage {
let mut limits = self.limits.write().unwrap();
let store_value = limits.entry(key.clone()).or_insert(CrCounterValue::new(
self.identifier.clone(),
counter.max_value(),
counter.window(),
));

Expand Down Expand Up @@ -161,10 +164,22 @@ impl CounterStorage for CrInMemoryStorage {
fn get_counters(&self, limits: &HashSet<Limit>) -> Result<HashSet<Counter>, StorageErr> {
let mut res = HashSet::new();

let limits: HashSet<_> = limits.iter().map(encode_limit_to_key).collect();

let limits_map = self.limits.read().unwrap();
for (key, counter_value) in limits_map.iter() {
let mut counter: Counter = decode_counter_key(key).unwrap().into();
if limits.contains(counter.limit()) {
let counter_key = decode_counter_key(key).unwrap();
let limit_key = if !counter_key.vars.is_empty() {
let mut cloned = counter_key.clone();
cloned.vars = HashMap::default();
cloned.encode()
} else {
key.clone()
};

if limits.contains(&limit_key) {
let counter = (&counter_key, counter_value);
let mut counter: Counter = counter.into();
counter.set_remaining(counter.max_value() - counter_value.read());
counter.set_expires_in(counter_value.ttl());
if counter.expires_in().unwrap() > Duration::ZERO {
Expand Down Expand Up @@ -264,56 +279,61 @@ impl CrInMemoryStorage {
}
}

#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
struct CounterKey {
namespace: Namespace,
seconds: u64,
max_value: u64,
conditions: HashSet<String>,
variables: HashSet<String>,
vars: HashMap<String, String>,
}

impl CounterKey {
fn new(limit: &Limit, vars: HashMap<String, String>) -> Self {
CounterKey {
namespace: limit.namespace().clone(),
seconds: limit.seconds(),
variables: limit.variables().clone(),
conditions: limit.conditions().clone(),
vars,
}
}

fn encode(&self) -> Vec<u8> {
postcard::to_stdvec(self).unwrap()
}
}

impl From<(&CounterKey, &CrCounterValue<String>)> for Counter {
fn from(value: (&CounterKey, &CrCounterValue<String>)) -> Self {
let (counter_key, store_value) = value;
let max_value = store_value.max_value();
let mut counter = Self::new(
Limit::new(
counter_key.namespace.clone(),
max_value,
counter_key.seconds,
counter_key.conditions.clone(),
counter_key.vars.keys(),
),
counter_key.vars.clone(),
);
counter.set_remaining(max_value - store_value.read());
counter.set_expires_in(store_value.ttl());
counter
}
}

fn encode_counter_to_key(counter: &Counter) -> Vec<u8> {
let limit = counter.limit();
let key = CounterKey {
namespace: limit.namespace().clone(),
max_value: limit.max_value(),
seconds: limit.seconds(),
variables: limit.variables().clone(),
conditions: limit.conditions().clone(),
vars: counter.set_variables().clone(),
};
let key = CounterKey::new(counter.limit(), counter.set_variables().clone());
postcard::to_stdvec(&key).unwrap()
}

fn encode_limit_to_key(limit: &Limit) -> Vec<u8> {
let key = CounterKey {
namespace: limit.namespace().clone(),
max_value: limit.max_value(),
seconds: limit.seconds(),
variables: limit.variables().clone(),
conditions: limit.conditions().clone(),
vars: HashMap::default(),
};
let key = CounterKey::new(limit, HashMap::default());
postcard::to_stdvec(&key).unwrap()
}

fn decode_counter_key(key: &Vec<u8>) -> postcard::Result<CounterKey> {
postcard::from_bytes(key.as_slice())
}

impl From<CounterKey> for Counter {
fn from(value: CounterKey) -> Self {
Self::new(
Limit::new(
value.namespace,
value.max_value,
value.seconds,
value.conditions,
value.vars.keys(),
),
value.vars,
)
}
}

0 comments on commit 1571819

Please sign in to comment.