Skip to content

Commit

Permalink
Use new ntex-io apis
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 committed Nov 10, 2023
1 parent a8c37e6 commit ceb70ad
Show file tree
Hide file tree
Showing 15 changed files with 208 additions and 181 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changes

## [0.12.8] - 2023-11-11

* Use new ntex-io apis

## [0.12.7] - 2023-11-04

* Fix v5::Subscribe/Unsubscribe packet properties encoding
Expand Down
11 changes: 7 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ntex-mqtt"
version = "0.12.7"
version = "0.12.8"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Client and Server framework for MQTT v5 and v3.1.1 protocols"
documentation = "https://docs.rs/ntex-mqtt"
Expand All @@ -11,9 +11,12 @@ license = "MIT"
exclude = [".gitignore", ".travis.yml", ".cargo/config"]
edition = "2021"

[package.metadata.docs.rs]
features = ["ntex/tokio"]

[dependencies]
ntex = "0.7.4"
ntex-util = "0.3.2"
ntex = "0.7.9"
ntex-bytes = "0.1.21"
bitflags = "2.4"
log = "0.4"
pin-project-lite = "0.2"
Expand All @@ -28,4 +31,4 @@ rustls = "0.21"
rustls-pemfile = "1.0"
openssl = "0.10"
test-case = "3.2"
ntex = { version = "0.7.4", features = ["tokio", "rustls", "openssl"] }
ntex = { version = "0.7", features = ["tokio", "rustls", "openssl"] }
3 changes: 3 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ pub enum ProtocolError {
/// Keep alive timeout
#[error("Keep Alive timeout")]
KeepAliveTimeout,
/// Read frame timeout
#[error("Read frame timeout")]
ReadTimeout,
}

#[derive(Debug, thiserror::Error)]
Expand Down
137 changes: 86 additions & 51 deletions src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,26 @@ use std::task::{Context, Poll};
use std::{cell::RefCell, collections::VecDeque, future::Future, pin::Pin, rc::Rc, time};

use ntex::codec::{Decoder, Encoder};
use ntex::io::{DispatchItem, IoBoxed, IoRef, IoStatusUpdate, RecvError};
use ntex::io::{
Decoded, DispatchItem, DispatcherConfig, IoBoxed, IoRef, IoStatusUpdate, RecvError,
};
use ntex::service::{IntoService, Pipeline, PipelineCall, Service};
use ntex::time::Seconds;
use ntex::time::{now, Seconds};
use ntex::util::{ready, Pool};

type Response<U> = <U as Encoder>::Item;

const ONE_SEC: time::Duration = time::Duration::from_secs(1);

pin_project_lite::pin_project! {
/// Dispatcher for mqtt protocol
pub(crate) struct Dispatcher<S, U>
where
S: Service<DispatchItem<U>, Response = Option<Response<U>>>,
S: 'static,
U: Encoder,
U: Decoder,
U: 'static,
U: Decoder,
U: 'static,
{
codec: U,
service: Pipeline<S>,
Expand All @@ -31,9 +35,11 @@ pin_project_lite::pin_project! {
}

bitflags::bitflags! {
#[derive(Copy, Clone, Eq, PartialEq)]
struct Flags: u8 {
const READY_ERR = 0b0001;
const IO_ERR = 0b0010;
const READY_ERR = 0b0001;
const IO_ERR = 0b0010;
const READ_TIMEOUT = 0b0100;
}
}

Expand All @@ -42,6 +48,9 @@ struct DispatcherInner<S: Service<DispatchItem<U>>, U: Encoder + Decoder> {
flags: Flags,
st: IoDispatcherState,
state: Rc<RefCell<DispatcherState<S, U>>>,
config: DispatcherConfig,
read_bytes: u32,
read_max_timeout: time::Instant,
keepalive_timeout: time::Duration,
}

Expand Down Expand Up @@ -102,11 +111,11 @@ where
io: IoBoxed,
codec: U,
service: F,
config: &DispatcherConfig,
) -> Self {
let keepalive_timeout = Seconds(30).into();

// register keepalive timer
io.start_keepalive_timer(keepalive_timeout);
io.start_timer(Seconds(30).into());
io.set_disconnect_timeout(config.disconnect_timeout());

let state = Rc::new(RefCell::new(DispatcherState {
error: None,
Expand All @@ -124,9 +133,12 @@ where
inner: DispatcherInner {
io,
state,
keepalive_timeout,
config: config.clone(),
flags: Flags::empty(),
st: IoDispatcherState::Processing,
read_bytes: 0,
read_max_timeout: now(),
keepalive_timeout: time::Duration::from_secs(30),
},
}
}
Expand All @@ -137,44 +149,12 @@ where
///
/// By default keep-alive timeout is set to 30 seconds.
pub(crate) fn keepalive_timeout(mut self, timeout: Seconds) -> Self {
let timeout = timeout.into();
self.inner.io.start_keepalive_timer(timeout);
self.inner.keepalive_timeout = timeout;
self
}

/// Set connection disconnect timeout.
///
/// Defines a timeout for disconnect connection. If a disconnect procedure does not complete
/// within this time, the connection get dropped.
///
/// To disable timeout set value to 0.
///
/// By default disconnect timeout is set to 1 seconds.
pub(crate) fn disconnect_timeout(self, val: Seconds) -> Self {
self.inner.io.set_disconnect_timeout(val.into());
self.inner.io.start_timer(timeout.into());
self.inner.keepalive_timeout = timeout.into();
self
}
}

impl<S, U> DispatcherInner<S, U>
where
S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
U: Decoder + Encoder + Clone + 'static,
<U as Encoder>::Item: 'static,
{
fn update_keepalive(&self) {
// update keep-alive timer
self.io.start_keepalive_timer(self.keepalive_timeout);
}

fn unregister_keepalive(&mut self) {
// unregister keep-alive timer
self.io.stop_keepalive_timer();
self.keepalive_timeout = time::Duration::ZERO;
}
}

impl<S, U> DispatcherState<S, U>
where
S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
Expand Down Expand Up @@ -275,11 +255,15 @@ where
let item = match ready!(inner.poll_service(this.service, cx)) {
PollService::Ready => {
// decode incoming bytes stream
match ready!(inner.io.poll_recv(this.codec, cx)) {
Ok(el) => {
match inner.io.poll_recv_decode(this.codec, cx) {
Ok(decoded) => {
// update keep-alive timer
inner.update_keepalive();
Some(DispatchItem::Item(el))
inner.update_timer(&decoded);
if let Some(el) = decoded.item {
Some(DispatchItem::Item(el))
} else {
return Poll::Pending;
}
}
Err(RecvError::Stop) => {
log::trace!("dispatcher is instructed to stop");
Expand Down Expand Up @@ -377,7 +361,7 @@ where
}
// drain service responses and shutdown io
IoDispatcherState::Stop => {
inner.unregister_keepalive();
inner.io.stop_timer();

// service may relay on poll_ready for response results
if !inner.flags.contains(Flags::READY_ERR) {
Expand Down Expand Up @@ -502,6 +486,52 @@ where
}
}
}

fn update_timer(&mut self, decoded: &Decoded<<U as Decoder>::Item>) {
// we got parsed frame
if decoded.item.is_some() {
// remove all timers
if self.flags.contains(Flags::READ_TIMEOUT) {
self.flags.remove(Flags::READ_TIMEOUT);
}
self.io.stop_timer();
} else if self.flags.contains(Flags::READ_TIMEOUT) {
// update read timer
if let Some((_, max, rate)) = self.config.frame_read_rate() {
let bytes = decoded.remains as u32;

let delta = (bytes - self.read_bytes).try_into().unwrap_or(u16::MAX);

if delta >= rate {
let n = now();
let next = self.io.timer_deadline() + ONE_SEC;
let new_timeout = if n >= next { ONE_SEC } else { next - n };

// max timeout
if max.is_zero() || (n + new_timeout) <= self.read_max_timeout {
self.read_bytes = bytes;
self.io.stop_timer();
self.io.start_timer(new_timeout);
}
}
}
} else {
// no new data then start keep-alive timer
if decoded.remains == 0 {
self.io.start_timer(self.keepalive_timeout);
} else if let Some((period, max, _)) = self.config.frame_read_rate() {
// we got new data but not enough to parse single frame
// start read timer
self.flags.insert(Flags::READ_TIMEOUT);

self.io.start_timer(period);
self.read_bytes = decoded.remains as u32;
if !max.is_zero() {
self.read_max_timeout = now() + max;
}
}
}
}
}

#[cfg(test)]
Expand Down Expand Up @@ -529,9 +559,11 @@ mod tests {
service: F,
) -> (Self, nio::IoRef) {
let keepalive_timeout = Seconds(30).into();
io.start_keepalive_timer(keepalive_timeout);
io.start_timer(keepalive_timeout);
let rio = io.get_ref();

let config = DispatcherConfig::default();

let state = Rc::new(RefCell::new(DispatcherState {
error: None,
base: 0,
Expand All @@ -547,10 +579,13 @@ mod tests {
pool: io.memory_pool().pool(),
inner: DispatcherInner {
state,
config,
keepalive_timeout,
io: IoBoxed::from(io),
st: IoDispatcherState::Processing,
flags: Flags::empty(),
read_bytes: 0,
read_max_timeout: now(),
},
},
rio,
Expand Down Expand Up @@ -652,7 +687,7 @@ mod tests {
}),
);
ntex::rt::spawn(async move {
let _ = disp.disconnect_timeout(Seconds(1)).await;
let _ = disp.await;
});

let buf = client.read().await.unwrap();
Expand Down
25 changes: 8 additions & 17 deletions src/service.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{fmt, marker::PhantomData, rc::Rc};

use ntex::codec::{Decoder, Encoder};
use ntex::io::{DispatchItem, Filter, Io, IoBoxed};
use ntex::io::{DispatchItem, DispatcherConfig, Filter, Io, IoBoxed};
use ntex::service::{Service, ServiceCtx, ServiceFactory};
use ntex::time::{Deadline, Seconds};
use ntex::util::{select, BoxFuture, Either};
Expand All @@ -13,13 +13,13 @@ type ResponseItem<U> = Option<<U as Encoder>::Item>;
pub struct MqttServer<St, C, T, Codec> {
connect: C,
handler: Rc<T>,
disconnect_timeout: Seconds,
config: DispatcherConfig,
_t: PhantomData<(St, Codec)>,
}

impl<St, C, T, Codec> MqttServer<St, C, T, Codec> {
pub(crate) fn new(connect: C, service: T, disconnect_timeout: Seconds) -> Self {
MqttServer { connect, disconnect_timeout, handler: Rc::new(service), _t: PhantomData }
pub(crate) fn new(connect: C, service: T, config: DispatcherConfig) -> Self {
MqttServer { connect, config, handler: Rc::new(service), _t: PhantomData }
}
}

Expand All @@ -32,8 +32,8 @@ where
) -> Result<MqttHandler<St, C::Service, T, Codec>, C::InitError> {
// create connect service and then create service impl
Ok(MqttHandler {
config: self.config.clone(),
handler: self.handler.clone(),
disconnect_timeout: self.disconnect_timeout,
connect: self.connect.create(()).await?,
_t: PhantomData,
})
Expand Down Expand Up @@ -119,7 +119,7 @@ where
pub struct MqttHandler<St, C, T, Codec> {
connect: C,
handler: Rc<T>,
disconnect_timeout: Seconds,
config: DispatcherConfig,
_t: PhantomData<(St, Codec)>,
}

Expand Down Expand Up @@ -147,7 +147,6 @@ where
#[inline]
fn call<'a>(&'a self, req: IoBoxed, ctx: ServiceCtx<'a, Self>) -> Self::Future<'a> {
Box::pin(async move {
let timeout = self.disconnect_timeout;
let handshake = ctx.call(&self.connect, req).await;

let (io, codec, session, keepalive) = handshake.map_err(|e| {
Expand All @@ -159,10 +158,7 @@ where
let handler = self.handler.create(session).await?;
log::trace!("Connection handler is created, starting dispatcher");

Dispatcher::new(io, codec, handler)
.keepalive_timeout(keepalive)
.disconnect_timeout(timeout)
.await
Dispatcher::new(io, codec, handler, &self.config).keepalive_timeout(keepalive).await
})
}
}
Expand Down Expand Up @@ -223,8 +219,6 @@ where
ctx: ServiceCtx<'a, Self>,
) -> Self::Future<'a> {
Box::pin(async move {
let timeout = self.disconnect_timeout;

let (io, codec, ka, handler) = {
let res = select(
delay,
Expand Down Expand Up @@ -253,10 +247,7 @@ where
}
};

Dispatcher::new(io, codec, handler)
.keepalive_timeout(ka)
.disconnect_timeout(timeout)
.await
Dispatcher::new(io, codec, handler, &self.config).keepalive_timeout(ka).await
})
}
}
Loading

0 comments on commit ceb70ad

Please sign in to comment.