diff --git a/src/server.rs b/src/server.rs index b0206fc..a6481b9 100644 --- a/src/server.rs +++ b/src/server.rs @@ -51,12 +51,12 @@ impl Default impl MqttServer { /// Set client timeout reading protocol version. /// - /// Defines a timeout for reading `Connect` frame. If a client does not transmit + /// Defines a timeout for reading protocol version. If a client does not transmit /// version of the protocol within this time, the connection is terminated with /// Mqtt::Handshake(HandshakeError::Timeout) error. /// - /// By default, connect timeuot is 5 seconds. - pub fn connect_timeout(mut self, timeout: Seconds) -> Self { + /// By default, timeuot is 5 seconds. + pub fn protocol_version_timeout(mut self, timeout: Seconds) -> Self { self.connect_timeout = timeout.into(); self } diff --git a/src/v3/handshake.rs b/src/v3/handshake.rs index a469534..5ae0c71 100644 --- a/src/v3/handshake.rs +++ b/src/v3/handshake.rs @@ -7,7 +7,6 @@ use super::shared::MqttShared; use super::sink::MqttSink; const DEFAULT_KEEPALIVE: Seconds = Seconds(30); -const DEFAULT_OUTGOING_INFLIGHT: u16 = 16; /// Connect message pub struct Handshake { @@ -67,7 +66,7 @@ impl Handshake { keepalive, session_present, session: Some(st), - inflight: DEFAULT_OUTGOING_INFLIGHT, + outgoing: None, return_code: mqtt::ConnectAckReason::ConnectionAccepted, } } @@ -80,7 +79,7 @@ impl Handshake { session: None, session_present: false, keepalive: DEFAULT_KEEPALIVE, - inflight: DEFAULT_OUTGOING_INFLIGHT, + outgoing: None, return_code: mqtt::ConnectAckReason::IdentifierRejected, } } @@ -92,8 +91,8 @@ impl Handshake { shared: self.shared, session: None, session_present: false, + outgoing: None, keepalive: DEFAULT_KEEPALIVE, - inflight: DEFAULT_OUTGOING_INFLIGHT, return_code: mqtt::ConnectAckReason::BadUserNameOrPassword, } } @@ -105,8 +104,8 @@ impl Handshake { shared: self.shared, session: None, session_present: false, + outgoing: None, keepalive: DEFAULT_KEEPALIVE, - inflight: DEFAULT_OUTGOING_INFLIGHT, return_code: mqtt::ConnectAckReason::NotAuthorized, } } @@ -118,8 +117,8 @@ impl Handshake { shared: self.shared, session: None, session_present: false, + outgoing: None, keepalive: DEFAULT_KEEPALIVE, - inflight: DEFAULT_OUTGOING_INFLIGHT, return_code: mqtt::ConnectAckReason::ServiceUnavailable, } } @@ -139,7 +138,7 @@ pub struct HandshakeAck { pub(crate) return_code: mqtt::ConnectAckReason, pub(crate) shared: Rc, pub(crate) keepalive: Seconds, - pub(crate) inflight: u16, + pub(crate) outgoing: Option, } impl HandshakeAck { @@ -151,11 +150,11 @@ impl HandshakeAck { self } - /// Number of outgoing in-flight concurrent messages. + /// Number of outgoing concurrent messages. /// - /// By default in-flight is set to 16 messages - pub fn inflight(mut self, val: u16) -> Self { - self.inflight = val; + /// By default outgoing is set to 16 messages + pub fn max_outgoing(mut self, val: u16) -> Self { + self.outgoing = Some(val); self } } diff --git a/src/v3/server.rs b/src/v3/server.rs index 0655582..bc4c0a3 100644 --- a/src/v3/server.rs +++ b/src/v3/server.rs @@ -48,6 +48,8 @@ pub struct MqttServer { max_size: u32, max_inflight: u16, max_inflight_size: usize, + max_outgoing: u16, + max_outgoing_size: (u32, u32), handle_qos_after_disconnect: Option, connect_timeout: Seconds, config: DispatcherConfig, @@ -79,6 +81,8 @@ where max_size: 0, max_inflight: 16, max_inflight_size: 65535, + max_outgoing: 16, + max_outgoing_size: (65535, 512), handle_qos_after_disconnect: None, connect_timeout: Seconds::ZERO, pool: Default::default(), @@ -154,7 +158,7 @@ where /// Number of in-flight concurrent messages. /// /// By default in-flight is set to 16 messages - pub fn inflight(mut self, val: u16) -> Self { + pub fn max_inflight(mut self, val: u16) -> Self { self.max_inflight = val; self } @@ -162,11 +166,27 @@ where /// Total size of in-flight messages. /// /// By default total in-flight size is set to 64Kb - pub fn inflight_size(mut self, val: usize) -> Self { + pub fn max_inflight_size(mut self, val: usize) -> Self { self.max_inflight_size = val; self } + /// Number of outgoing concurrent messages. + /// + /// By default outgoing is set to 16 messages + pub fn max_outgoing(mut self, val: u16) -> Self { + self.max_outgoing = val; + self + } + + /// Total size of outgoing messages. + /// + /// By default total outgoing size is set to 64Kb + pub fn max_outgoing_size(mut self, val: u32) -> Self { + self.max_outgoing_size = (val, val / 10); + self + } + /// Handle max received QoS messages after client disconnect. /// /// By default, messages received before dispatched to the publish service will be dropped if @@ -212,6 +232,8 @@ where max_size: self.max_size, max_inflight: self.max_inflight, max_inflight_size: self.max_inflight_size, + max_outgoing: self.max_outgoing, + max_outgoing_size: self.max_outgoing_size, handle_qos_after_disconnect: self.handle_qos_after_disconnect, connect_timeout: self.connect_timeout, pool: self.pool, @@ -235,8 +257,10 @@ where max_size: self.max_size, max_inflight: self.max_inflight, max_inflight_size: self.max_inflight_size, - connect_timeout: self.connect_timeout, + max_outgoing: self.max_outgoing, + max_outgoing_size: self.max_outgoing_size, handle_qos_after_disconnect: self.handle_qos_after_disconnect, + connect_timeout: self.connect_timeout, pool: self.pool, _t: PhantomData, } @@ -266,6 +290,8 @@ where HandshakeFactory { factory: self.handshake, max_size: self.max_size, + max_outgoing: self.max_outgoing, + max_outgoing_size: self.max_outgoing_size, connect_timeout: self.connect_timeout, pool: self.pool.clone(), _t: PhantomData, @@ -286,6 +312,8 @@ where struct HandshakeFactory { factory: H, max_size: u32, + max_outgoing: u16, + max_outgoing_size: (u32, u32), connect_timeout: Seconds, pool: Rc, _t: PhantomData, @@ -305,6 +333,8 @@ where async fn create(&self, _: ()) -> Result { Ok(HandshakeService { max_size: self.max_size, + max_outgoing: self.max_outgoing, + max_outgoing_size: self.max_outgoing_size, pool: self.pool.clone(), service: self.factory.create(()).await?, connect_timeout: self.connect_timeout.into(), @@ -316,6 +346,8 @@ where struct HandshakeService { service: H, max_size: u32, + max_outgoing: u16, + max_outgoing_size: (u32, u32), pool: Rc, connect_timeout: Millis, _t: PhantomData, @@ -339,6 +371,9 @@ where ) -> Result { log::trace!("Starting mqtt v3 handshake"); + let (h, l) = self.max_outgoing_size; + io.memory_pool().set_write_params(h, l); + let codec = mqtt::Codec::default(); codec.set_max_size(self.max_size); let shared = Rc::new(MqttShared::new(io.get_ref(), codec, false, self.pool.clone())); @@ -373,7 +408,7 @@ where log::trace!("Sending success handshake ack: {:#?}", pkt); - ack.shared.set_cap(ack.inflight as usize); + ack.shared.set_cap(ack.outgoing.unwrap_or(self.max_outgoing) as usize); ack.io.encode(pkt, &ack.shared.codec)?; Ok(( ack.io, diff --git a/tests/test_server_v5.rs b/tests/test_server_v5.rs index 342f511..a42a702 100644 --- a/tests/test_server_v5.rs +++ b/tests/test_server_v5.rs @@ -1070,10 +1070,10 @@ async fn handle_or_drop_publish_after_disconnect( ) .unwrap(); io.flush(true).await.unwrap(); - sleep(Millis(1750)).await; + sleep(Millis(2750)).await; io.close(); drop(io); - sleep(Millis(1000)).await; + sleep(Millis(1500)).await; assert!(disconnect.load(Relaxed));