diff --git a/limitador/proto/distributed.proto b/limitador/proto/distributed.proto index ff470931..d6456f92 100644 --- a/limitador/proto/distributed.proto +++ b/limitador/proto/distributed.proto @@ -5,18 +5,20 @@ package limitador.service.distributed.v1; // A packet defines all the types of messages that can be sent between replication peers. message Packet { oneof message { - // the Hello message is used to introduce a peer to another peer. It is the first message sent by a peer. + // the hello message is used to introduce a peer to another peer. It is the first message sent by a peer. Hello hello = 1; - // the MembershipUpdate message is used to gossip about the other peers in the cluster: + // the membership_update message is used to gossip about the other peers in the cluster: // 1) sent after the first Hello message // 2) sent when the membership state changes MembershipUpdate membership_update = 2; - // the Ping message is used to request a pong from the other peer. - Ping ping = 3; - // the Pong message is used to respond to a ping. + // the ping message is used to request a pong from the other peer. + Empty ping = 3; + // the pong message is used to respond to a ping. Pong pong = 4; - // the CounterUpdate message is used to send counter updates. + // the counter_update message is used to send counter updates. CounterUpdate counter_update = 5; + // the re_sync_end message is used to signal that the re-sync process has ended. + Empty re_sync_end = 6; } } @@ -30,8 +32,8 @@ message Hello { optional string receiver_url = 3; } -// A request to a peer to respond with a Pong message. -message Ping {} +// A packet message that does not have any additional data. +message Empty {} // Pong is the response to a Ping and Hello message. message Pong { diff --git a/limitador/src/storage/distributed/cr_counter_value.rs b/limitador/src/storage/distributed/cr_counter_value.rs index 38220e5f..ebf0216c 100644 --- a/limitador/src/storage/distributed/cr_counter_value.rs +++ b/limitador/src/storage/distributed/cr_counter_value.rs @@ -133,6 +133,17 @@ impl CrCounterValue { (expiry.into_inner(), map) } + pub fn into_ourselves_inner(self) -> (SystemTime, A, u64) { + let Self { + ourselves, + max_value: _, + value, + others: _, + expiry, + } = self; + (expiry.into_inner(), ourselves, value.into_inner()) + } + fn reset(&self, expiry: SystemTime) { let mut guard = self.others.write().unwrap(); self.expiry.update(expiry); diff --git a/limitador/src/storage/distributed/grpc/mod.rs b/limitador/src/storage/distributed/grpc/mod.rs index 79d3d985..1e04fef0 100644 --- a/limitador/src/storage/distributed/grpc/mod.rs +++ b/limitador/src/storage/distributed/grpc/mod.rs @@ -5,10 +5,10 @@ use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::{error::Error, io::ErrorKind, pin::Pin}; +use tokio::sync::mpsc::error::TrySendError; use tokio::sync::mpsc::Sender; use tokio::sync::{broadcast, mpsc, RwLock}; use tokio::time::sleep; - use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt}; use tonic::{Code, Request, Response, Status, Streaming}; use tracing::debug; @@ -17,7 +17,7 @@ use crate::storage::distributed::grpc::v1::packet::Message; use crate::storage::distributed::grpc::v1::replication_client::ReplicationClient; use crate::storage::distributed::grpc::v1::replication_server::{Replication, ReplicationServer}; use crate::storage::distributed::grpc::v1::{ - CounterUpdate, Hello, MembershipUpdate, Packet, Peer, Pong, + CounterUpdate, Empty, Hello, MembershipUpdate, Packet, Peer, Pong, }; // clippy will barf on protobuff generated code for enum variants in @@ -84,8 +84,8 @@ struct Session { impl Session { async fn close(&mut self) { - let mut state = self.replication_state.write().await; - if let Some(peer) = state.peer_trackers.get_mut(&self.peer_id) { + let mut replication_state = self.replication_state.write().await; + if let Some(peer) = replication_state.peer_trackers.get_mut(&self.peer_id) { peer.session = None; } } @@ -105,11 +105,50 @@ impl Session { })) .await?; - let mut udpates_to_send = self.broker_state.publisher.subscribe(); - + // start the re-sync process with the peer, start sending him all the local counter values + let (tx, mut rx) = mpsc::channel::>(1); + let peer_id = self.peer_id.clone(); + let out_stream = self.out_stream.clone(); + tokio::spawn(async move { + let mut counter = 0u64; + while let Some(rsync_message) = rx.recv().await { + match rsync_message { + Some(update) => { + counter += 1; + if let Err(err) = out_stream + .clone() + .send(Ok(Message::CounterUpdate(update))) + .await + { + debug!("peer: '{}': ReSyncRequest: send error: {:?}", peer_id, err); + return; + } + } + None => { + debug!( + "peer: '{}': rysnc completed, sent %d updates: {:?}", + peer_id, counter + ); + _ = out_stream + .clone() + .send(Ok(Message::ReSyncEnd(Empty::default()))) + .await; + } + } + } + }); + self.broker_state + .on_re_sync + .try_send(tx) + .map_err(|err| match err { + TrySendError::Full(_) => Status::resource_exhausted("re-sync channel full"), + TrySendError::Closed(_) => Status::unavailable("re-sync channel closed"), + })?; + + let mut updates = self.broker_state.publisher.subscribe(); loop { tokio::select! { - update = udpates_to_send.recv() => { + update = updates.recv() => { let update = update.map_err(|_| Status::unknown("broadcast error"))?; self.send(Message::CounterUpdate(update)).await?; } @@ -123,11 +162,11 @@ impl Session { self.process_packet(packet).await?; }, Some(Err(err)) => { - if is_disconnect(&err) { + return if is_disconnect(&err) { debug!("peer: '{}': disconnected: {:?}", self.peer_id, err); - return Ok(()); + Ok(()) } else { - return Err(err); + Err(err) } }, } @@ -289,13 +328,13 @@ fn is_disconnect(err: &Status) -> bool { // MessageSender is used to abstract the difference between the server and client sender streams... #[derive(Clone)] -enum MessageSender { +pub enum MessageSender { Server(Sender>), Client(Sender), } impl MessageSender { - async fn send(self, message: Result) -> Result<(), Status> { + pub async fn send(self, message: Result) -> Result<(), Status> { match self { MessageSender::Server(sender) => { let value = message.map(|x| Packet { message: Some(x) }); @@ -324,6 +363,7 @@ struct BrokerState { id: String, publisher: broadcast::Sender, on_counter_update: Arc, + on_re_sync: Arc>>>, } #[derive(Clone)] @@ -340,6 +380,7 @@ impl Broker { listen_address: SocketAddr, peer_urls: Vec, on_counter_update: CounterUpdateFn, + on_re_sync: Sender>>, ) -> Broker { let (tx, _) = broadcast::channel(16); let publisher: broadcast::Sender = tx; @@ -351,6 +392,7 @@ impl Broker { id, publisher, on_counter_update: Arc::new(on_counter_update), + on_re_sync: Arc::new(on_re_sync), }, replication_state: Arc::new(RwLock::new(ReplicationState { discovered_urls: HashSet::new(), diff --git a/limitador/src/storage/distributed/mod.rs b/limitador/src/storage/distributed/mod.rs index 452b3aa2..e2ca8c42 100644 --- a/limitador/src/storage/distributed/mod.rs +++ b/limitador/src/storage/distributed/mod.rs @@ -5,6 +5,9 @@ use std::sync::{Arc, RwLock}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use serde::{Deserialize, Serialize}; +use tokio::sync::mpsc; +use tokio::sync::mpsc::Sender; +use tracing::debug; use crate::counter::Counter; use crate::limit::{Limit, Namespace}; @@ -217,6 +220,8 @@ impl CrInMemoryStorage { let limits = Arc::new(RwLock::new(LimitsMap::new())); let limits_clone = limits.clone(); + + let (re_sync_queue_tx, mut re_sync_queue_rx) = mpsc::channel(100); let broker = grpc::Broker::new( identifier.clone(), listen_address, @@ -232,6 +237,7 @@ impl CrInMemoryStorage { let value = limits.get(&update.key).unwrap(); value.merge((UNIX_EPOCH + Duration::from_secs(update.expires_at), values).into()); }), + re_sync_queue_tx, ); { @@ -241,6 +247,17 @@ impl CrInMemoryStorage { }); } + // process the re-sync requests... + { + let limits = limits.clone(); + tokio::spawn(async move { + let limits = limits.clone(); + while let Some(sender) = re_sync_queue_rx.recv().await { + process_re_sync(&limits, sender).await; + } + }); + } + Self { identifier, limits, @@ -279,6 +296,49 @@ impl CrInMemoryStorage { } } +async fn process_re_sync( + limits: &Arc, CrCounterValue>>>, + sender: Sender>, +) { + // sending all the counters to the peer might take a while, so we don't want to lock + // the limits map for too long, lets figure first get the list of keys that needs to be sent. + let keys: Vec<_> = { + let limits = limits.read().unwrap(); + limits.keys().cloned().collect() + }; + + for key in keys { + let update = { + let limits = limits.read().unwrap(); + limits.get(&key).and_then(|store_value| { + let (expiry, ourself, value) = store_value.clone().into_ourselves_inner(); + if value == 0 { + None // no point in sending a counter that is empty + } else { + let values = HashMap::from([(ourself, value)]); + Some(CounterUpdate { + key: key.clone(), + values, + expires_at: expiry.duration_since(UNIX_EPOCH).unwrap().as_secs(), + }) + } + }) + }; + // skip None, it means the counter was deleted. + if let Some(update) = update { + match sender.send(Some(update)).await { + Ok(_) => {} + Err(err) => { + debug!("Failed to send re-sync counter update to peer: {:?}", err); + break; + } + } + } + } + // signal the end of the re-sync + _ = sender.send(None).await; +} + #[derive(Clone, Debug, Serialize, Deserialize)] struct CounterKey { namespace: Namespace, diff --git a/limitador/tests/integration_tests.rs b/limitador/tests/integration_tests.rs index 7e0e6595..da20237a 100644 --- a/limitador/tests/integration_tests.rs +++ b/limitador/tests/integration_tests.rs @@ -1184,8 +1184,6 @@ mod test { Fut: Future>, { let rate_limiters = create_distributed_limiters(2).await; - tokio::time::sleep(Duration::from_secs(1)).await; - let namespace = "test_namespace"; let max_hits = 3; let limit = Limit::new(