Skip to content

Commit

Permalink
ensure we reconnect on failure (#173)
Browse files Browse the repository at this point in the history
* ensure we reconnect on failure

* refactor

* fix test
  • Loading branch information
xlc committed May 10, 2024
1 parent 648a3c0 commit e61fa69
Show file tree
Hide file tree
Showing 5 changed files with 231 additions and 73 deletions.
91 changes: 63 additions & 28 deletions src/extensions/client/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ pub struct Endpoint {
url: String,
health: Arc<Health>,
client_rx: tokio::sync::watch::Receiver<Option<Arc<Client>>>,
reconnect_tx: tokio::sync::mpsc::Sender<()>,
on_client_ready: Arc<tokio::sync::Notify>,
background_tasks: Vec<tokio::task::JoinHandle<()>>,
connect_counter: Arc<AtomicU32>,
}

impl Drop for Endpoint {
Expand All @@ -35,22 +37,26 @@ impl Endpoint {
url: String,
request_timeout: Option<Duration>,
connection_timeout: Option<Duration>,
health_config: HealthCheckConfig,
health_config: Option<HealthCheckConfig>,
) -> Self {
let (client_tx, client_rx) = tokio::sync::watch::channel(None);
let (reconnect_tx, mut reconnect_rx) = tokio::sync::mpsc::channel(1);
let on_client_ready = Arc::new(tokio::sync::Notify::new());
let health = Arc::new(Health::new(url.clone(), health_config));
let connect_counter = Arc::new(AtomicU32::new(0));

let url_ = url.clone();
let health_ = health.clone();
let on_client_ready_ = on_client_ready.clone();
let connect_counter_ = connect_counter.clone();

// This task will try to connect to the endpoint and keep the connection alive
let connection_task = tokio::spawn(async move {
let connect_backoff_counter = Arc::new(AtomicU32::new(0));

loop {
tracing::info!("Connecting endpoint: {url_}");
connect_counter_.fetch_add(1, Ordering::Relaxed);

let client = WsClientBuilder::default()
.request_timeout(request_timeout.unwrap_or(Duration::from_secs(30)))
Expand All @@ -68,7 +74,15 @@ impl Endpoint {
on_client_ready_.notify_waiters();
tracing::info!("Endpoint connected: {url_}");
connect_backoff_counter.store(0, Ordering::Relaxed);
client.on_disconnect().await;

tokio::select! {
_ = reconnect_rx.recv() => {
tracing::debug!("Endpoint reconnect requested: {url_}");
},
_ = client.on_disconnect() => {
tracing::debug!("Endpoint disconnected: {url_}");
}
}
}
Err(err) => {
health_.on_error(&err);
Expand All @@ -88,8 +102,10 @@ impl Endpoint {
url,
health,
client_rx,
reconnect_tx,
on_client_ready,
background_tasks: vec![connection_task, health_checker],
connect_counter,
}
}

Expand All @@ -108,24 +124,34 @@ impl Endpoint {
self.on_client_ready.notified().await;
}

/// Number of connection attempts made by the background connection task
/// since this `Endpoint` was created (the task bumps the counter on each
/// loop iteration before dialing, so this also counts failed attempts).
pub fn connect_counter(&self) -> u32 {
    // Relaxed ordering is sufficient: this is a monotonic diagnostic
    // counter, not used to synchronize any other memory access.
    self.connect_counter.load(Ordering::Relaxed)
}

pub async fn request(
&self,
method: &str,
params: Vec<serde_json::Value>,
timeout: Duration,
) -> Result<serde_json::Value, jsonrpsee::core::Error> {
let client = self
.client_rx
.borrow()
.clone()
.ok_or(errors::failed("client not connected"))?;

match tokio::time::timeout(timeout, client.request(method, params.clone())).await {
Ok(Ok(response)) => Ok(response),
Ok(Err(err)) => {
self.health.on_error(&err);
Err(err)
match tokio::time::timeout(timeout, async {
self.connected().await;
let client = self
.client_rx
.borrow()
.clone()
.ok_or(errors::failed("client not connected"))?;
match client.request(method, params.clone()).await {
Ok(resp) => Ok(resp),
Err(err) => {
self.health.on_error(&err);
Err(err)
}
}
})
.await
{
Ok(res) => res,
Err(_) => {
tracing::error!("request timed out method: {method} params: {params:?}");
self.health.on_error(&jsonrpsee::core::Error::RequestTimeout);
Expand All @@ -141,28 +167,37 @@ impl Endpoint {
unsubscribe_method: &str,
timeout: Duration,
) -> Result<Subscription<serde_json::Value>, jsonrpsee::core::Error> {
let client = self
.client_rx
.borrow()
.clone()
.ok_or(errors::failed("client not connected"))?;

match tokio::time::timeout(
timeout,
client.subscribe(subscribe_method, params.clone(), unsubscribe_method),
)
match tokio::time::timeout(timeout, async {
self.connected().await;
let client = self
.client_rx
.borrow()
.clone()
.ok_or(errors::failed("client not connected"))?;
match client
.subscribe(subscribe_method, params.clone(), unsubscribe_method)
.await
{
Ok(resp) => Ok(resp),
Err(err) => {
self.health.on_error(&err);
Err(err)
}
}
})
.await
{
Ok(Ok(response)) => Ok(response),
Ok(Err(err)) => {
self.health.on_error(&err);
Err(err)
}
Ok(res) => res,
Err(_) => {
tracing::error!("subscribe timed out subscribe: {subscribe_method} params: {params:?}");
self.health.on_error(&jsonrpsee::core::Error::RequestTimeout);
Err(jsonrpsee::core::Error::RequestTimeout)
}
}
}

/// Asks the background connection task to drop the current ws client and
/// establish a fresh connection (the task's `select!` races this signal
/// against `client.on_disconnect()`).
///
/// # Panics
/// Panics if the background connection task has terminated and dropped
/// `reconnect_rx`, because the channel send is `unwrap()`ed.
/// NOTE(review): presumably the task lives as long as `self`, so the
/// receiver can only be gone after `Drop` aborts it — confirm.
pub async fn reconnect(&self) {
    // notify the client to reconnect
    self.reconnect_tx.send(()).await.unwrap();
}
}
20 changes: 10 additions & 10 deletions src/extensions/client/health.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl Event {
#[derive(Debug, Default)]
pub struct Health {
url: String,
config: HealthCheckConfig,
config: Option<HealthCheckConfig>,
score: AtomicU32,
unhealthy: tokio::sync::Notify,
}
Expand All @@ -44,7 +44,7 @@ const MAX_SCORE: u32 = 100;
const THRESHOLD: u32 = MAX_SCORE / 2;

impl Health {
pub fn new(url: String, config: HealthCheckConfig) -> Self {
pub fn new(url: String, config: Option<HealthCheckConfig>) -> Self {
Self {
url,
config,
Expand Down Expand Up @@ -104,18 +104,18 @@ impl Health {
on_client_ready: Arc<tokio::sync::Notify>,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
// no health method
if health.config.health_method.is_none() {
return;
}
let config = match health.config {
Some(ref config) => config,
None => return,
};

// Wait for the client to be ready before starting the health check
on_client_ready.notified().await;

let method_name = health.config.health_method.as_ref().expect("checked above");
let health_response = health.config.response.clone();
let interval = Duration::from_secs(health.config.interval_sec);
let healthy_response_time = Duration::from_millis(health.config.healthy_response_time_ms);
let method_name = config.health_method.as_ref().expect("Invalid health config");
let health_response = config.response.clone();
let interval = Duration::from_secs(config.interval_sec);
let healthy_response_time = Duration::from_millis(config.healthy_response_time_ms);

let client = match client_rx_.borrow().clone() {
Some(client) => client,
Expand Down
56 changes: 37 additions & 19 deletions src/extensions/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ impl Client {
retries: Option<u32>,
health_config: Option<HealthCheckConfig>,
) -> Result<Self, anyhow::Error> {
let health_config = health_config.unwrap_or_default();
let endpoints: Vec<_> = endpoints.into_iter().map(|e| e.as_ref().to_string()).collect();

if endpoints.is_empty() {
Expand Down Expand Up @@ -240,34 +239,48 @@ impl Client {

let rotation_notify = Arc::new(Notify::new());
let rotation_notify_bg = rotation_notify.clone();
let endpoints_ = endpoints.clone();
let endpoints2 = endpoints.clone();
let has_health_method = health_config.is_some();

let mut current_endpoint_idx = 0;
let mut selected_endpoint = endpoints[0].clone();

let background_task = tokio::spawn(async move {
let request_backoff_counter = Arc::new(AtomicU32::new(0));

// Select next endpoint with the highest health score, excluding the current one if provided
let healthiest_endpoint = |exclude: Option<Arc<Endpoint>>| async {
// Select next endpoint with the highest health score, excluding the current one if possible
let select_healtiest = |endpoints: Vec<Arc<Endpoint>>, current_idx: usize| async move {
if endpoints.len() == 1 {
let selected_endpoint = endpoints[0].clone();
// Ensure it's connected
selected_endpoint.connected().await;
return selected_endpoint;
return (selected_endpoint, 0);
}

let mut endpoints = endpoints.clone();
// Remove the current endpoint from the list
if let Some(exclude) = exclude {
endpoints.retain(|e| e.url() != exclude.url());
}
// wait for at least one endpoint to connect
futures::future::select_all(endpoints.iter().map(|x| x.connected().boxed())).await;
// Sort by health score
endpoints.sort_by_key(|endpoint| std::cmp::Reverse(endpoint.health().score()));
// Pick the first one
endpoints[0].clone()

let (idx, endpoint) = endpoints
.iter()
.enumerate()
.filter(|(idx, _)| *idx != current_idx)
.max_by_key(|(_, endpoint)| endpoint.health().score())
.expect("No endpoints");
(endpoint.clone(), idx)
};

let mut selected_endpoint = healthiest_endpoint(None).await;
let select_next = |endpoints: Vec<Arc<Endpoint>>, current_idx: usize| async move {
let idx = (current_idx + 1) % endpoints.len();
(endpoints[idx].clone(), idx)
};

let next_endpoint = |current_idx| {
if has_health_method {
select_healtiest(endpoints2.clone(), current_idx).boxed()
} else {
select_next(endpoints2.clone(), current_idx).boxed()
}
};

let handle_message = |message: Message, endpoint: Arc<Endpoint>, rotation_notify: Arc<Notify>| {
let tx = message_tx_bg.clone();
Expand Down Expand Up @@ -422,10 +435,15 @@ impl Client {
_ = selected_endpoint.health().unhealthy() => {
// Current selected endpoint is unhealthy, try to rotate to another one.
// In case of all endpoints are unhealthy, we don't want to keep rotating but stick with the healthiest one.
let new_selected_endpoint = healthiest_endpoint(None).await;
if new_selected_endpoint.url() != selected_endpoint.url() {

// The ws client maybe in a state that requires a reconnect
selected_endpoint.reconnect().await;

let (new_selected_endpoint, new_current_endpoint_idx) = next_endpoint(current_endpoint_idx).await;
if new_current_endpoint_idx != current_endpoint_idx {
tracing::warn!("Switch to endpoint: {new_url}", new_url=new_selected_endpoint.url());
selected_endpoint = new_selected_endpoint;
current_endpoint_idx = new_current_endpoint_idx;
rotation_notify_bg.notify_waiters();
}
}
Expand All @@ -434,7 +452,7 @@ impl Client {
match message {
Some(Message::RotateEndpoint) => {
tracing::info!("Rotating endpoint ...");
selected_endpoint = healthiest_endpoint(Some(selected_endpoint.clone())).await;
(selected_endpoint, current_endpoint_idx) = next_endpoint(current_endpoint_idx).await;
rotation_notify_bg.notify_waiters();
}
Some(message) => handle_message(message, selected_endpoint.clone(), rotation_notify_bg.clone()),
Expand All @@ -449,7 +467,7 @@ impl Client {
});

Ok(Self {
endpoints: endpoints_,
endpoints,
sender: message_tx,
rotation_notify,
retries: retries.unwrap_or(3),
Expand Down
Loading

0 comments on commit e61fa69

Please sign in to comment.