dead_letter_exchanges in RabbitMQ
Olivier Michaud authored and infherny committed Nov 7, 2023
1 parent 76d3025 commit 7c02b1b
Showing 2 changed files with 98 additions and 14 deletions.
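
The new Options fields introduced below (exchange, routing_key, arguments) let a topic declare a named exchange, bind its queue to it, and pass extra queue arguments such as RabbitMQ's dead-lettering keys. A minimal usage sketch, not part of the diff, assuming Options is a dataclass taking these fields as keyword arguments and using illustrative names (work_queue, work_exchange, dlx_queue):

from saturn_engine.worker.topics.rabbitmq import Exchange, RabbitMQTopic

options = RabbitMQTopic.Options(
    queue_name="work_queue",
    # New in this commit: declare and bind a named exchange instead of
    # publishing through the default exchange.
    exchange=Exchange(name="work_exchange"),
    routing_key="work",
    # Extra queue arguments are forwarded to declare_queue(); these two
    # dead-letter rejected messages to "dlx_queue" via the default exchange.
    arguments={
        "x-dead-letter-exchange": "",
        "x-dead-letter-routing-key": "dlx_queue",
    },
)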
68 changes: 54 additions & 14 deletions src/saturn_engine/worker/topics/rabbitmq.py
@@ -52,6 +52,18 @@ def from_content_type(cls, content_type: str) -> "RabbitMQSerializer | None":
}


@dataclasses.dataclass
class Exchange:
name: str
type: aio_pika.abc.ExchangeType = aio_pika.abc.ExchangeType.DIRECT
durable: bool = True
auto_delete: bool = False
exclusive: bool = False
passive: bool = False
arguments: dict[str, t.Any] = dataclasses.field(default_factory=dict)
timeout: int | float | None = None


class RabbitMQTopic(Topic):
"""A queue that consume message from RabbitMQ"""

@@ -72,6 +84,9 @@ class Options:
log_above_size: t.Optional[int] = None
max_publish_concurrency: int = 8
max_retry: int | None = None
arguments: dict[str, t.Any] = dataclasses.field(default_factory=dict)
exchange: Exchange | None = None
routing_key: str | None = None

class TopicServices:
rabbitmq: RabbitMQService
@@ -89,6 +104,15 @@ def __init__(self, options: Options, services: Services, **kwargs: object) -> None:
self.attempt_by_message: LRUDefaultDict[str, int] = LRUDefaultDict(
cache_len=1024, default_factory=lambda: 0
)
self.queue_arguments: dict[str, t.Any] = self.options.arguments

if self.options.max_length:
self.queue_arguments.setdefault("x-max-length", self.options.max_length)
if self.options.max_length_bytes:
self.queue_arguments.setdefault(
"x-max-length-bytes", self.options.max_length_bytes
)
self.queue_arguments.setdefault("x-overflow", self.options.overflow)

async def run(self) -> AsyncGenerator[t.AsyncContextManager[TopicMessage], None]:
if self.is_closed:
@@ -130,19 +154,15 @@ async def publish(
body = self._serialize(message)
try:
await self.ensure_queue() # Ensure the queue is created.
channel = await self.channel
exchange = channel.default_exchange
if exchange is None:
raise ValueError("Channel has no exchange")

exchange = await self.exchange
await exchange.publish(
aio_pika.Message(
body=body,
delivery_mode=aio_pika.DeliveryMode.PERSISTENT,
content_type=self.options.serializer.content_type,
expiration=message.expire_after,
),
routing_key=self.options.queue_name,
routing_key=self.options.routing_key or self.options.queue_name,
)
return True
except aio_pika.exceptions.DeliveryError as e:
@@ -225,6 +245,25 @@ async def channel(self) -> aio_pika.abc.AbstractChannel:
channel.reopen_callbacks.add(self.channel_reopened)
return channel

@cached_property
async def exchange(self) -> aio_pika.abc.AbstractExchange:
channel = await self.channel
if self.options.exchange:
return await channel.declare_exchange(
name=self.options.exchange.name,
type=self.options.exchange.type,
durable=self.options.exchange.durable,
passive=self.options.exchange.passive,
exclusive=self.options.exchange.exclusive,
auto_delete=self.options.exchange.auto_delete,
arguments=self.options.exchange.arguments,
timeout=self.options.exchange.timeout,
)
return channel.default_exchange

async def ensure_exchange(self) -> aio_pika.abc.AbstractExchange:
return await self.exchange

def channel_closed(
self, channel: aio_pika.abc.AbstractChannel, reason: t.Optional[Exception]
) -> None:
@@ -241,20 +280,21 @@ def channel_reopened(self, channel: aio_pika.abc.AbstractChannel) -> None:

@cached_property
async def queue(self) -> aio_pika.abc.AbstractQueue:
arguments: dict[str, t.Any] = {}
if self.options.max_length:
arguments["x-max-length"] = self.options.max_length
if self.options.max_length_bytes:
arguments["x-max-length-bytes"] = self.options.max_length_bytes
arguments["x-overflow"] = self.options.overflow

channel = await self.channel
queue = await channel.declare_queue(
self.options.queue_name,
auto_delete=self.options.auto_delete,
durable=self.options.durable,
arguments=arguments,
arguments=self.queue_arguments,
)
await self.ensure_exchange()
if self.options.exchange:
await queue.bind(
self.options.exchange.name,
routing_key=self.options.routing_key or self.options.queue_name,
)
elif self.options.routing_key:
await queue.bind("", routing_key=self.options.routing_key)

return queue

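For reference, a rough standalone aio_pika sketch, not from this repository, of the broker-side topology that the declare/bind code above produces when the dead-lettering arguments used in the test below are supplied; the broker URL and queue names are assumptions:

import asyncio

import aio_pika

async def declare_topology() -> None:
    connection = await aio_pika.connect_robust("amqp://guest:guest@localhost/")
    async with connection:
        channel = await connection.channel()
        # Queue that collects dead-lettered messages; it is reached through
        # the default exchange ("") with routing key "dlx_queue".
        await channel.declare_queue("dlx_queue", durable=True)
        # Work queue: messages rejected with requeue=False are re-published
        # by the broker to the configured dead-letter exchange.
        await channel.declare_queue(
            "work_queue",
            durable=True,
            arguments={
                "x-dead-letter-exchange": "",
                "x-dead-letter-routing-key": "dlx_queue",
            },
        )

asyncio.run(declare_topology())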
44 changes: 44 additions & 0 deletions tests/worker/topics/test_rabbitmq_topic.py
@@ -252,3 +252,47 @@ async def test_retry(
assert message.id == "1"

await topic.close()


@pytest.mark.asyncio
async def test_dead_letter_exchanges(
rabbitmq_topic_maker: t.Callable[..., Awaitable[RabbitMQTopic]]
) -> None:
topic = await rabbitmq_topic_maker(
RabbitMQTopic,
serializer=RabbitMQSerializer.PICKLE,
arguments={
"x-dead-letter-exchange": "",
"x-dead-letter-routing-key": "dlx_queue",
},
)
dlx_topic = await rabbitmq_topic_maker(
RabbitMQTopic,
serializer=RabbitMQSerializer.PICKLE,
queue_name="dlx_queue",
)

await dlx_topic.ensure_queue()

messages = [
TopicMessage(id=MessageId("0"), args={"n": b"1", "time": utcnow()}),
]

for message in messages:
await topic.publish(message, wait=True)

# We make the message fail
async with alib.scoped_iter(topic.run()) as topic_iter:
context = await alib.anext(topic_iter)
with pytest.raises(ValueError):
async with context as message:
raise ValueError("Exception")

# We iterate the dlx_topic and check that the failed message was dead-lettered there
async with alib.scoped_iter(dlx_topic.run()) as dlx_topic_iter:
context = await alib.anext(dlx_topic_iter)
async with context as message:
assert message.id == "0"

await topic.close()
await dlx_topic.close()
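
Beyond asserting on the message id, a consumer of the dead-letter queue could also inspect the "x-death" header that RabbitMQ attaches to dead-lettered messages. A hedged aio_pika sketch, not part of this commit; the broker URL and queue name are assumptions:

import asyncio

import aio_pika

async def inspect_dead_letters() -> None:
    connection = await aio_pika.connect_robust("amqp://guest:guest@localhost/")
    async with connection:
        channel = await connection.channel()
        queue = await channel.declare_queue("dlx_queue", durable=True)
        # Consume until cancelled; process() acks each message on exit.
        async with queue.iterator() as queue_iter:
            async for message in queue_iter:
                async with message.process():
                    # The broker records one entry per dead-lettering event.
                    for death in message.headers.get("x-death", []):
                        print(death.get("queue"), death.get("reason"), death.get("count"))

asyncio.run(inspect_dead_letters())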
