
Commit

Various cleanups (#155)
fafhrd91 committed Nov 13, 2023
1 parent ceb70ad commit 2ecc983
Showing 8 changed files with 224 additions and 83 deletions.
2 changes: 1 addition & 1 deletion CHANGES.md
@@ -1,6 +1,6 @@
# Changes

## [0.12.8] - 2023-11-11
## [0.12.8] - 2023-11-12

* Use new ntex-io apis

6 changes: 3 additions & 3 deletions src/error.rs
@@ -33,7 +33,7 @@ pub enum HandshakeError<E> {
}

/// Protocol level errors
#[derive(Debug, thiserror::Error)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
pub enum ProtocolError {
/// MQTT decoding error
#[error("Decoding error: {0:?}")]
@@ -52,13 +52,13 @@ pub enum ProtocolError {
ReadTimeout,
}

#[derive(Debug, thiserror::Error)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
#[error(transparent)]
pub struct ProtocolViolationError {
inner: ViolationInner,
}

#[derive(Debug, thiserror::Error)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
enum ViolationInner {
#[error("{message}")]
Common { reason: DisconnectReasonCode, message: &'static str },
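
Deriving `Copy`, `Clone`, `PartialEq`, and `Eq` on the error types lets callers compare and store protocol errors by value; the new read-rate test at the bottom of this diff relies on exactly that (`msg.get_ref() == &ProtocolError::ReadTimeout`). A small sketch of what the derives enable (the helper functions are illustrative, not part of the crate):

```rust
use ntex_mqtt::error::ProtocolError;

// PartialEq/Eq: errors can now be compared directly instead of being matched
// structurally.
fn is_read_timeout(err: &ProtocolError) -> bool {
    *err == ProtocolError::ReadTimeout
}

// Copy/Clone: an error can be kept around or passed on without ownership gymnastics.
fn duplicate(err: ProtocolError) -> (ProtocolError, ProtocolError) {
    (err, err)
}
```
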
114 changes: 51 additions & 63 deletions src/io.rs
@@ -83,7 +83,6 @@ enum IoDispatcherState {
}

pub(crate) enum IoDispatcherError<S, U> {
KeepAlive,
Encoder(U),
Service(S),
}
@@ -257,106 +256,98 @@ where
// decode incoming bytes stream
match inner.io.poll_recv_decode(this.codec, cx) {
Ok(decoded) => {
// update keep-alive timer
inner.update_timer(&decoded);
if let Some(el) = decoded.item {
Some(DispatchItem::Item(el))
DispatchItem::Item(el)
} else {
return Poll::Pending;
}
}
Err(RecvError::Stop) => {
log::trace!("dispatcher is instructed to stop");
inner.st = IoDispatcherState::Stop;
None
continue;
}
Err(RecvError::KeepAlive) => {
// check keepalive timeout
log::trace!("keepalive timeout");
log::trace!("keep-alive error, stopping dispatcher");
inner.st = IoDispatcherState::Stop;
let mut state = inner.state.borrow_mut();
if state.error.is_none() {
state.error = Some(IoDispatcherError::KeepAlive);
if inner.flags.contains(Flags::READ_TIMEOUT) {
DispatchItem::ReadTimeout
} else {
DispatchItem::KeepAliveTimeout
}
Some(DispatchItem::KeepAliveTimeout)
}
Err(RecvError::WriteBackpressure) => {
if let Err(err) = ready!(inner.io.poll_flush(cx, false)) {
inner.st = IoDispatcherState::Stop;
Some(DispatchItem::Disconnect(Some(err)))
DispatchItem::Disconnect(Some(err))
} else {
continue;
}
}
Err(RecvError::Decoder(err)) => {
inner.st = IoDispatcherState::Stop;
Some(DispatchItem::DecoderError(err))
DispatchItem::DecoderError(err)
}
Err(RecvError::PeerGone(err)) => {
inner.st = IoDispatcherState::Stop;
Some(DispatchItem::Disconnect(err))
DispatchItem::Disconnect(err)
}
}
}
PollService::Item(item) => Some(item),
PollService::Item(item) => item,
PollService::Continue => continue,
};

// call service
if let Some(item) = item {
// optimize first call
if this.response.is_none() {
this.response.set(Some(this.service.call_static(item)));
let res = this.response.as_mut().as_pin_mut().unwrap().poll(cx);

let mut state = inner.state.borrow_mut();

if let Poll::Ready(res) = res {
// check if current result is only response
if state.queue.is_empty() {
match res {
Err(err) => {
state.error = Some(err.into());
}
Ok(Some(item)) => {
if let Err(err) = inner.io.encode(item, this.codec)
{
state.error =
Some(IoDispatcherError::Encoder(err));
}
// optimize first call
if this.response.is_none() {
this.response.set(Some(this.service.call_static(item)));

let res = this.response.as_mut().as_pin_mut().unwrap().poll(cx);
let mut state = inner.state.borrow_mut();

if let Poll::Ready(res) = res {
// check if current result is only response
if state.queue.is_empty() {
match res {
Err(err) => {
state.error = Some(err.into());
}
Ok(Some(item)) => {
if let Err(err) = inner.io.encode(item, this.codec) {
state.error = Some(IoDispatcherError::Encoder(err));
}
Ok(None) => (),
}
} else {
*this.response_idx =
state.base.wrapping_add(state.queue.len());
state.queue.push_back(ServiceResult::Ready(res));
Ok(None) => (),
}
this.response.set(None);
} else {
*this.response_idx = state.base.wrapping_add(state.queue.len());
state.queue.push_back(ServiceResult::Pending);
state.queue.push_back(ServiceResult::Ready(res));
}
this.response.set(None);
} else {
let mut state = inner.state.borrow_mut();
let response_idx = state.base.wrapping_add(state.queue.len());
*this.response_idx = state.base.wrapping_add(state.queue.len());
state.queue.push_back(ServiceResult::Pending);

let st = inner.io.get_ref();
let codec = this.codec.clone();
let state = inner.state.clone();
let fut = this.service.call_static(item);
ntex::rt::spawn(async move {
let item = fut.await;
state.borrow_mut().handle_result(
item,
response_idx,
&st,
&codec,
true,
);
});
}
} else {
let mut state = inner.state.borrow_mut();
let response_idx = state.base.wrapping_add(state.queue.len());
state.queue.push_back(ServiceResult::Pending);

let st = inner.io.get_ref();
let codec = this.codec.clone();
let state = inner.state.clone();
let fut = this.service.call_static(item);
ntex::rt::spawn(async move {
let item = fut.await;
state.borrow_mut().handle_result(
item,
response_idx,
&st,
&codec,
true,
);
});
}
}
// drain service responses and shutdown io
@@ -443,9 +434,6 @@ where
state.error = Some(IoDispatcherError::Service(err));
PollService::Continue
}
IoDispatcherError::KeepAlive => {
PollService::Item(DispatchItem::KeepAliveTimeout)
}
}
} else {
PollService::Ready
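
With the dedicated `IoDispatcherError::KeepAlive` variant gone, a `RecvError::KeepAlive` from the io layer is translated straight into a dispatch item, picking `ReadTimeout` when the frame read-rate timer is active and `KeepAliveTimeout` otherwise. A condensed sketch of that mapping (the `Flags::READ_TIMEOUT` check and the surrounding dispatcher state are as shown in the diff above; this is not the full poll loop):

```rust
use ntex::codec::{Decoder, Encoder};
use ntex::io::DispatchItem;

// Decide which timeout item the service should see after a keep-alive wake-up.
fn keep_alive_item<U>(read_rate_active: bool) -> DispatchItem<U>
where
    U: Encoder + Decoder,
{
    if read_rate_active {
        // The frame read-rate budget was exhausted while a frame was in flight.
        DispatchItem::ReadTimeout
    } else {
        // No inbound traffic at all within the keep-alive window.
        DispatchItem::KeepAliveTimeout
    }
}
```
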
16 changes: 7 additions & 9 deletions src/v3/dispatcher.rs
@@ -1,5 +1,5 @@
use std::cell::RefCell;
use std::task::{Context, Poll};
use std::task::{ready, Context, Poll};
use std::{future::Future, marker::PhantomData, num::NonZeroU16, pin::Pin, rc::Rc};

use ntex::io::DispatchItem;
@@ -386,8 +386,8 @@ where
let mut this = self.as_mut().project();

match this.state.as_mut().project() {
PublishResponseStateProject::Publish { fut } => match fut.poll(cx) {
Poll::Ready(Ok(_)) => {
PublishResponseStateProject::Publish { fut } => match ready!(fut.poll(cx)) {
Ok(_) => {
log::trace!("Publish result for packet {:?} is ready", this.packet_id);

if let Some(packet_id) = this.packet_id {
@@ -399,7 +399,7 @@ where
Poll::Ready(Ok(None))
}
}
Poll::Ready(Err(e)) => {
Err(e) => {
this.state.set(PublishResponseState::Control {
fut: ControlResponse::new(
ControlMessage::error(e.into()),
@@ -409,7 +409,6 @@ where
});
self.poll(cx)
}
Poll::Pending => Poll::Pending,
},
PublishResponseStateProject::Control { fut } => fut.poll(cx),
}
@@ -453,8 +452,8 @@ where
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut().project();

match this.fut.poll(cx) {
Poll::Ready(Ok(item)) => {
match ready!(this.fut.poll(cx)) {
Ok(item) => {
let packet = match item.result {
ControlResultKind::Ping => Some(codec::Packet::PingResponse),
ControlResultKind::Subscribe(res) => {
@@ -478,7 +477,7 @@ where
};
Poll::Ready(Ok(packet))
}
Poll::Ready(Err(err)) => {
Err(err) => {
// do not handle nested error
if *this.error {
Poll::Ready(Err(err))
@@ -496,7 +495,6 @@ where
}
}
}
Poll::Pending => Poll::Pending,
}
}
}
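
The removed `Poll::Pending => Poll::Pending` arms are now handled by `std::task::ready!`, which returns `Poll::Pending` from the enclosing function on its own. A standalone illustration of the equivalence (unrelated to the crate's types):

```rust
use std::task::{ready, Poll};

// Equivalent to writing an explicit `Poll::Pending => return Poll::Pending` arm.
fn add_one(inner: Poll<Result<u32, &'static str>>) -> Poll<Result<u32, &'static str>> {
    match ready!(inner) {
        Ok(v) => Poll::Ready(Ok(v + 1)),
        Err(e) => Poll::Ready(Err(e)),
    }
}

fn main() {
    assert_eq!(add_one(Poll::Ready(Ok(1))), Poll::Ready(Ok(2)));
    assert_eq!(add_one(Poll::Pending), Poll::Pending);
}
```
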
12 changes: 12 additions & 0 deletions src/v3/server.rs
@@ -136,6 +136,18 @@ where
self
}

/// Set the read rate parameters for a single frame.
///
/// Sets the maximum timeout for reading a single frame. For every `rate`
/// bytes received from the client, the timeout is extended by 1 second,
/// but it never exceeds `max_timeout` in total.
///
/// By default the frame read rate is disabled.
pub fn frame_read_rate(self, timeout: Seconds, max_timeout: Seconds, rate: u16) -> Self {
self.config.set_frame_read_rate(timeout, max_timeout, rate);
self
}

/// Set max allowed QoS.
///
/// If peer sends publish with higher qos then ProtocolError::MaxQoSViolated(..)
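
The identical `frame_read_rate` builder method is added to the v5 server in the next file. A rough configuration sketch mirroring the new test at the bottom of this diff (`handshake` is the connect handler defined in that test module and the error plumbing is elided, so this is a fragment rather than a complete program): with `frame_read_rate(Seconds(1), Seconds(2), 10)`, a partially received frame gets a 1-second base deadline, every further 10 bytes extend it by a second, and once the 2-second cap is exceeded the control service is handed a `ProtocolError::ReadTimeout`.

```rust
use ntex::time::Seconds;
use ntex::util::Ready;
use ntex_mqtt::v3::{ControlMessage, MqttServer};

// `handshake` is the test module's connect handler (not shown in this diff).
MqttServer::new(handshake)
    .frame_read_rate(Seconds(1), Seconds(2), 10) // 1s base, +1s per 10 bytes, 2s cap
    .publish(|_| Ready::Ok(()))
    .control(|msg| match msg {
        // A too-slowly received frame surfaces here as ProtocolError::ReadTimeout.
        ControlMessage::ProtocolError(err) => Ready::Ok(err.ack()),
        other => Ready::Ok(other.disconnect()),
    })
    .finish();
```
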
12 changes: 12 additions & 0 deletions src/v5/server.rs
@@ -111,6 +111,18 @@ where
self
}

/// Set the read rate parameters for a single frame.
///
/// Sets the maximum timeout for reading a single frame. For every `rate`
/// bytes received from the client, the timeout is extended by 1 second,
/// but it never exceeds `max_timeout` in total.
///
/// By default the frame read rate is disabled.
pub fn frame_read_rate(self, timeout: Seconds, max_timeout: Seconds, rate: u16) -> Self {
self.config.set_frame_read_rate(timeout, max_timeout, rate);
self
}

/// Set max inbound frame size.
///
/// If max size is set to `0`, size is unlimited.
72 changes: 69 additions & 3 deletions tests/test_server.rs
@@ -3,8 +3,8 @@ use std::{cell::RefCell, future::Future, num::NonZeroU16, pin::Pin, rc::Rc, time

use ntex::service::{fn_service, Pipeline, ServiceFactory};
use ntex::time::{sleep, Millis, Seconds};
use ntex::util::{join_all, lazy, ByteString, Bytes, Ready};
use ntex::{server, service::chain_factory};
use ntex::util::{join_all, lazy, ByteString, Bytes, BytesMut, Ready};
use ntex::{codec::Encoder, server, service::chain_factory};

use ntex_mqtt::v3::{
client, codec, ControlMessage, Handshake, HandshakeAck, MqttServer, Publish, Session,
@@ -447,7 +446,6 @@ fn ssl_acceptor() -> openssl::ssl::SslAcceptor {
#[ntex::test]
async fn test_large_publish_openssl() -> std::io::Result<()> {
use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
env_logger::init();

let srv = server::test_server(move || {
chain_factory(server::openssl::Acceptor::new(ssl_acceptor()).map_err(|_| ())).and_then(
@@ -613,3 +612,70 @@ async fn test_sink_publish_noblock() -> std::io::Result<()> {
sink.close();
Ok(())
}

// Slow frame rate
#[ntex::test]
async fn test_frame_read_rate() -> std::io::Result<()> {
let _ = env_logger::try_init();
let check = Arc::new(AtomicBool::new(false));
let check2 = check.clone();

let srv = server::test_server(move || {
let check = check2.clone();

MqttServer::new(handshake)
.frame_read_rate(Seconds(1), Seconds(2), 10)
.publish(|_| Ready::Ok(()))
.control(move |msg| {
let check = check.clone();
match msg {
ControlMessage::ProtocolError(msg) => {
if msg.get_ref() == &ProtocolError::ReadTimeout {
check.store(true, Relaxed);
}
Ready::Ok(msg.ack())
}
_ => Ready::Ok(msg.disconnect()),
}
})
.finish()
.map_err(|_| ())
.map_init_err(|_| ())
});

let io = srv.connect().await.unwrap();
let codec = codec::Codec::default();
io.encode(codec::Connect::default().client_id("user").into(), &codec).unwrap();
io.recv(&codec).await.unwrap();

let p = codec::Publish {
dup: false,
retain: false,
qos: codec::QoS::AtLeastOnce,
topic: ByteString::from("test"),
packet_id: Some(NonZeroU16::new(3).unwrap()),
payload: Bytes::from(vec![b'*'; 270 * 1024]),
}
.into();

let mut buf = BytesMut::new();
codec.encode(p, &mut buf).unwrap();

io.write(&buf[..5]).unwrap();
buf.split_to(5);
sleep(Millis(100)).await;
io.write(&buf[..10]).unwrap();
buf.split_to(10);
sleep(Millis(500)).await;
assert!(!check.load(Relaxed));

io.write(&buf[..12]).unwrap();
buf.split_to(12);
sleep(Millis(500)).await;
assert!(!check.load(Relaxed));

sleep(Millis(1200)).await;
assert!(check.load(Relaxed));

Ok(())
}