diff --git a/src/saturn_engine/worker/topics/rabbitmq.py b/src/saturn_engine/worker/topics/rabbitmq.py index 83843b92..1227b740 100644 --- a/src/saturn_engine/worker/topics/rabbitmq.py +++ b/src/saturn_engine/worker/topics/rabbitmq.py @@ -52,6 +52,18 @@ def from_content_type(cls, content_type: str) -> "RabbitMQSerializer | None": } +@dataclasses.dataclass +class Exchange: + name: str + type: aio_pika.abc.ExchangeType = aio_pika.abc.ExchangeType.DIRECT + durable: bool = True + auto_delete: bool = False + exclusive: bool = False + passive: bool = False + arguments: dict[str, t.Any] = dataclasses.field(default_factory=dict) + timeout: int | float | None = None + + class RabbitMQTopic(Topic): """A queue that consume message from RabbitMQ""" @@ -72,6 +84,9 @@ class Options: log_above_size: t.Optional[int] = None max_publish_concurrency: int = 8 max_retry: int | None = None + arguments: dict[str, t.Any] = dataclasses.field(default_factory=dict) + exchange: Exchange | None = None + routing_key: str | None = None class TopicServices: rabbitmq: RabbitMQService @@ -89,6 +104,15 @@ def __init__(self, options: Options, services: Services, **kwargs: object) -> No self.attempt_by_message: LRUDefaultDict[str, int] = LRUDefaultDict( cache_len=1024, default_factory=lambda: 0 ) + self.queue_arguments: dict[str, t.Any] = self.options.arguments + + if self.options.max_length: + self.queue_arguments.setdefault("x-max-length", self.options.max_length) + if self.options.max_length_bytes: + self.queue_arguments.setdefault( + "x-max-length-bytes", self.options.max_length_bytes + ) + self.queue_arguments.setdefault("x-overflow", self.options.overflow) async def run(self) -> AsyncGenerator[t.AsyncContextManager[TopicMessage], None]: if self.is_closed: @@ -130,11 +154,7 @@ async def publish( body = self._serialize(message) try: await self.ensure_queue() # Ensure the queue is created. - channel = await self.channel - exchange = channel.default_exchange - if exchange is None: - raise ValueError("Channel has no exchange") - + exchange = self.exchange await exchange.publish( aio_pika.Message( body=body, @@ -142,7 +162,7 @@ async def publish( content_type=self.options.serializer.content_type, expiration=message.expire_after, ), - routing_key=self.options.queue_name, + routing_key=self.options.routing_key or self.options.queue_name, ) return True except aio_pika.exceptions.DeliveryError as e: @@ -225,6 +245,25 @@ async def channel(self) -> aio_pika.abc.AbstractChannel: channel.reopen_callbacks.add(self.channel_reopened) return channel + @cached_property + async def exchange(self) -> aio_pika.abc.AbstractExchange: + channel = await self.channel + if self.options.exchange: + return await channel.declare_exchange( + name=self.options.exchange.name, + type=self.options.exchange.type, + durable=self.options.exchange.durable, + passive=self.options.exchange.passive, + exclusive=self.options.exchange.exclusive, + auto_delete=self.options.exchange.auto_delete, + arguments=self.options.exchange.arguments, + timeout=self.options.exchange.timeout, + ) + return channel.default_exchange + + async def ensure_exchange(self) -> aio_pika.abc.AbstractExchange: + return await self.exchange + def channel_closed( self, channel: aio_pika.abc.AbstractChannel, reason: t.Optional[Exception] ) -> None: @@ -241,20 +280,21 @@ def channel_reopened(self, channel: aio_pika.abc.AbstractChannel) -> None: @cached_property async def queue(self) -> aio_pika.abc.AbstractQueue: - arguments: dict[str, t.Any] = {} - if self.options.max_length: - arguments["x-max-length"] = self.options.max_length - if self.options.max_length_bytes: - arguments["x-max-length-bytes"] = self.options.max_length_bytes - arguments["x-overflow"] = self.options.overflow - channel = await self.channel queue = await channel.declare_queue( self.options.queue_name, auto_delete=self.options.auto_delete, durable=self.options.durable, - arguments=arguments, + arguments=self.queue_arguments, ) + await self.ensure_exchange() + if self.options.exchange: + await queue.bind( + self.options.exchange.name, + routing_key=self.options.routing_key or self.options.queue_name, + ) + elif self.options.routing_key: + await queue.bind("", routing_key=self.options.routing_key) return queue diff --git a/tests/worker/topics/test_rabbitmq_topic.py b/tests/worker/topics/test_rabbitmq_topic.py index 450946d1..f550ed03 100644 --- a/tests/worker/topics/test_rabbitmq_topic.py +++ b/tests/worker/topics/test_rabbitmq_topic.py @@ -252,3 +252,47 @@ async def test_retry( assert message.id == "1" await topic.close() + + +@pytest.mark.asyncio +async def test_dead_letter_exchanges( + rabbitmq_topic_maker: t.Callable[..., Awaitable[RabbitMQTopic]] +) -> None: + topic = await rabbitmq_topic_maker( + RabbitMQTopic, + serializer=RabbitMQSerializer.PICKLE, + arguments={ + "x-dead-letter-exchange": "", + "x-dead-letter-routing-key": "dlx_queue", + }, + ) + dlx_topic = await rabbitmq_topic_maker( + RabbitMQTopic, + serializer=RabbitMQSerializer.PICKLE, + queue_name="dlx_queue", + ) + + await dlx_topic.ensure_queue() + + messages = [ + TopicMessage(id=MessageId("0"), args={"n": b"1", "time": utcnow()}), + ] + + for message in messages: + await topic.publish(message, wait=True) + + # We make the message fail + async with alib.scoped_iter(topic.run()) as topic_iter: + context = await alib.anext(topic_iter) + with pytest.raises(ValueError): + async with context as message: + raise ValueError("Exception") + + # We iter the dlx_topic, ensure the failed message + async with alib.scoped_iter(dlx_topic.run()) as dlx_topic_iter: + context = await alib.anext(dlx_topic_iter) + async with context as message: + assert message.id == "0" + + await topic.close() + await dlx_topic.close()