Skip to content

Commit

Permalink
[FIX] fastapi: Avoid process stuck in case of retry of Post request w…
Browse files Browse the repository at this point in the history
…ith body content.

In case of retry we must ensure that the stream pass to the Fastapi application is reset to the beginning to be sure it can be consumed again. Unfortunately , the stream object from the werkzeug request is not always seekable. In such a case, we wrap the stream into a new SeekableStream object that it become possible to reset the stream at the begining without having to read the stream first into memory.
  • Loading branch information
lmignon committed Jun 26, 2024
1 parent a64cfa7 commit e871a22
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 7 deletions.
18 changes: 17 additions & 1 deletion fastapi/fastapi_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from fastapi.utils import is_body_allowed_for_status_code

from .context import odoo_env_ctx
from .seekable_stream import SeekableStream


class FastApiDispatcher(Dispatcher):
Expand Down Expand Up @@ -112,7 +113,22 @@ def _get_environ(self):
# date odoo version. (EAFP: Easier to Ask for Forgiveness than Permission)
httprequest = self.request.httprequest
environ = httprequest.environ
environ["wsgi.input"] = httprequest._get_stream_for_parsing()
stream = httprequest._get_stream_for_parsing()
# Check if the stream supports seeking
if hasattr(stream, "seekable") and stream.seekable():
# Reset the stream to the beginning to ensure it can be consumed
# again by the application in case of a retry mechanism
stream.seek(0)
else:
# If the stream does not support seeking, we need wrap it
# in a SeekableStream object that will buffer the data read
# from the stream. This way we can seek back to the beginning
# of the stream to read the data again if needed.
if not hasattr(httprequest, "_cached_stream"):
httprequest._cached_stream = SeekableStream(stream)
stream = httprequest._cached_stream
stream.seek(0)
environ["wsgi.input"] = stream
return environ

@contextmanager
Expand Down
6 changes: 6 additions & 0 deletions fastapi/readme/newsfragments/440.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Fix issue with the retry of a POST request with a body content.

Prior to this fix the retry of a POST request with a body content would
stuck in a loop and never complete. This was due to the fact that the
request input stream was not reset after a failed attempt to process the
request.
27 changes: 25 additions & 2 deletions fastapi/routers/demo_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

from odoo.addons.base.models.res_partner import Partner

from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi import APIRouter, Depends, File, HTTPException, Query, status
from fastapi.responses import JSONResponse

from ..dependencies import authenticated_partner, fastapi_endpoint, odoo_env
from ..models import FastapiEndpoint
Expand Down Expand Up @@ -95,7 +96,7 @@ async def endpoint_app_info(

@router.get("/demo/retrying")
async def retrying(
nbr_retries: Annotated[int, Query(gt=1, lt=MAX_TRIES_ON_CONCURRENCY_FAILURE)]
nbr_retries: Annotated[int, Query(gt=1, lt=MAX_TRIES_ON_CONCURRENCY_FAILURE)],
) -> int:
"""This method is used in the test suite to check that the retrying
functionality in case of concurrency error on the database is working
Expand All @@ -114,6 +115,28 @@ async def retrying(
return tryno


@router.post("/demo/retrying")
async def retrying_post(
nbr_retries: Annotated[int, Query(gt=1, lt=MAX_TRIES_ON_CONCURRENCY_FAILURE)],
file: Annotated[bytes, File()],
) -> JSONResponse:
"""This method is used in the test suite to check that the retrying
functionality in case of concurrency error on the database is working
correctly for retryable exceptions.
The output will be the number of retries that have been done.
This method is mainly used to test the retrying functionality
"""
global _CPT
if _CPT < nbr_retries:
_CPT += 1
raise FakeConcurrentUpdateError("fake error")
tryno = _CPT
_CPT = 0
return JSONResponse(content={"retries": tryno, "file": file.decode("utf-8")})


class FakeConcurrentUpdateError(OperationalError):
@property
def pgcode(self):
Expand Down
101 changes: 101 additions & 0 deletions fastapi/seekable_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2024 ACSONE SA/NV
# License LGPL-3.0 or later (http://www.gnu.org/licenses/LGPL).
import io


class SeekableStream(io.RawIOBase):
"""A seekable stream that wraps another stream and buffers read data.
This class allows to seek and read data from the original stream.
It buffers read data to allow seeking back to read data again.
This class is useful to handle the case where the original stream does not
support seeking, but the data could eventually be read multiple times.
To avoid reading the original stream a first time to buffer the data, we
buffer the data as it is read. In this way we do not add delay when the
data is read only once.
"""

def __init__(self, original_stream):
super().__init__()
self.original_stream = original_stream
self.buffer = bytearray()
self.buffer_position = 0
self.seek_position = 0
self.end_of_stream = False

def read(self, size=-1): # pylint: disable=method-required-super
if size == -1:
# Read all remaining data
size = len(self.buffer) - self.buffer_position
data_from_buffer = bytes(self.buffer[self.buffer_position :])
self.buffer_position = len(self.buffer)

# Read remaining data from the original stream if not already buffered
remaining_data = self.original_stream.read()
self.buffer.extend(remaining_data)
self.end_of_stream = True
return data_from_buffer + remaining_data

buffer_len = len(self.buffer)
remaining_buffer = buffer_len - self.buffer_position

if remaining_buffer >= size:
# Read from the buffer if there is enough data
data = self.buffer[self.buffer_position : self.buffer_position + size]
self.buffer_position += size
return bytes(data)
else:
# Read remaining buffer data
data = self.buffer[self.buffer_position :]
self.buffer_position = buffer_len

# Read the rest from the original stream
additional_data = self.original_stream.read(size - remaining_buffer)
if additional_data is None:
additional_data = b""

Check warning on line 56 in fastapi/seekable_stream.py

View check run for this annotation

Codecov / codecov/patch

fastapi/seekable_stream.py#L56

Added line #L56 was not covered by tests

# Store read data in the buffer
self.buffer.extend(additional_data)
self.buffer_position += len(additional_data)
if len(additional_data) < (size - remaining_buffer):
self.end_of_stream = True
return bytes(data + additional_data)

def seek(self, offset, whence=io.SEEK_SET):
if whence == io.SEEK_SET:
new_position = offset
elif whence == io.SEEK_CUR:
new_position = self.buffer_position + offset
elif whence == io.SEEK_END:
if not self.end_of_stream:
# Read the rest of the stream to buffer it
# This is needed to know the total size of the stream
self.read()
new_position = len(self.buffer) + offset

if new_position < 0:
raise ValueError("Negative seek position {}".format(new_position))

Check warning on line 78 in fastapi/seekable_stream.py

View check run for this annotation

Codecov / codecov/patch

fastapi/seekable_stream.py#L78

Added line #L78 was not covered by tests

if new_position <= len(self.buffer):
self.buffer_position = new_position
else:
# Read from the original stream to fill the buffer up to the new position
to_read = new_position - len(self.buffer)
additional_data = self.original_stream.read(to_read)
if additional_data is None:
additional_data = b""

Check warning on line 87 in fastapi/seekable_stream.py

View check run for this annotation

Codecov / codecov/patch

fastapi/seekable_stream.py#L87

Added line #L87 was not covered by tests
self.buffer.extend(additional_data)
if len(self.buffer) < new_position:
raise io.UnsupportedOperation(

Check warning on line 90 in fastapi/seekable_stream.py

View check run for this annotation

Codecov / codecov/patch

fastapi/seekable_stream.py#L90

Added line #L90 was not covered by tests
"Cannot seek beyond the end of the stream"
)
self.buffer_position = new_position

return self.buffer_position

def tell(self):
return self.buffer_position

def readable(self):
return True

Check warning on line 101 in fastapi/seekable_stream.py

View check run for this annotation

Codecov / codecov/patch

fastapi/seekable_stream.py#L101

Added line #L101 was not covered by tests
11 changes: 7 additions & 4 deletions fastapi/static/description/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

/*
:Author: David Goodger (goodger@python.org)
:Id: $Id: html4css1.css 8954 2022-01-20 10:10:25Z milde $
:Id: $Id: html4css1.css 9511 2024-01-13 09:50:07Z milde $
:Copyright: This stylesheet has been placed in the public domain.

Default cascading style sheet for the HTML output of Docutils.
Despite the name, some widely supported CSS2 features are used.

See https://docutils.sourceforge.io/docs/howto/html-stylesheets.html for how to
customize this style sheet.
Expand Down Expand Up @@ -274,7 +275,7 @@
margin-left: 2em ;
margin-right: 2em }

pre.code .ln { color: grey; } /* line numbers */
pre.code .ln { color: gray; } /* line numbers */
pre.code, code { background-color: #eeeeee }
pre.code .comment, code .comment { color: #5C6576 }
pre.code .keyword, code .keyword { color: #3B0D06; font-weight: bold }
Expand All @@ -300,7 +301,7 @@
span.pre {
white-space: pre }

span.problematic {
span.problematic, pre.problematic {
color: red }

span.section-subtitle {
Expand Down Expand Up @@ -1794,7 +1795,9 @@ <h2><a class="toc-backref" href="#toc-entry-33">Contributors</a></h2>
<div class="section" id="maintainers">
<h2><a class="toc-backref" href="#toc-entry-34">Maintainers</a></h2>
<p>This module is maintained by the OCA.</p>
<a class="reference external image-reference" href="https://odoo-community.org"><img alt="Odoo Community Association" src="https://odoo-community.org/logo.png" /></a>
<a class="reference external image-reference" href="https://odoo-community.org">
<img alt="Odoo Community Association" src="https://odoo-community.org/logo.png" />
</a>
<p>OCA, or the Odoo Community Association, is a nonprofit organization whose
mission is to support the collaborative development of Odoo features and
promote its widespread use.</p>
Expand Down
1 change: 1 addition & 0 deletions fastapi/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from . import test_fastapi
from . import test_fastapi_demo
from . import test_seekable_stream
12 changes: 12 additions & 0 deletions fastapi/tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,18 @@ def test_retrying(self):
self.assertEqual(response.status_code, 200)
self.assertEqual(int(response.content), nbr_retries)

def test_retrying_post(self):
"""Test that the retrying mechanism is working as expected with the
FastAPI endpoints in case of POST request with a file.
"""
nbr_retries = 3
route = f"/fastapi_demo/demo/retrying?nbr_retries={nbr_retries}"
response = self.url_open(
route, timeout=20, files={"file": ("test.txt", b"test")}
)
self.assertEqual(response.status_code, 200)
self.assertDictEqual(response.json(), {"retries": nbr_retries, "file": "test"})

@mute_logger("odoo.http")
def assert_exception_processed(
self,
Expand Down
86 changes: 86 additions & 0 deletions fastapi/tests/test_seekable_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import io
import random

from odoo.tests.common import TransactionCase

from ..seekable_stream import SeekableStream


class TestSeekableStream(TransactionCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
# create a random large content
cls.original_content = random.randbytes(1024 * 1024)

def setUp(self):
super().setUp()
self.original_stream = NonSeekableStream(self.original_content)

def test_read_all(self):
self.assertFalse(self.original_stream.seekable())
stream = SeekableStream(self.original_stream)
data = stream.read()
self.assertEqual(data, self.original_content)
stream.seek(0)
data = stream.read()
self.assertEqual(data, self.original_content)

def test_read_partial(self):
self.assertFalse(self.original_stream.seekable())
stream = SeekableStream(self.original_stream)
data = stream.read(10)
self.assertEqual(data, self.original_content[:10])
data = stream.read(10)
self.assertEqual(data, self.original_content[10:20])
# read the rest
data = stream.read()
self.assertEqual(data, self.original_content[20:])

def test_seek(self):
self.assertFalse(self.original_stream.seekable())
stream = SeekableStream(self.original_stream)
stream.seek(10)
self.assertEqual(stream.tell(), 10)
data = stream.read(10)
self.assertEqual(data, self.original_content[10:20])
stream.seek(0)
self.assertEqual(stream.tell(), 0)
data = stream.read(10)
self.assertEqual(data, self.original_content[:10])

def test_seek_relative(self):
self.assertFalse(self.original_stream.seekable())
stream = SeekableStream(self.original_stream)
stream.seek(10)
self.assertEqual(stream.tell(), 10)
stream.seek(5, io.SEEK_CUR)
self.assertEqual(stream.tell(), 15)
data = stream.read(10)
self.assertEqual(data, self.original_content[15:25])

def test_seek_end(self):
self.assertFalse(self.original_stream.seekable())
stream = SeekableStream(self.original_stream)
stream.seek(-10, io.SEEK_END)
self.assertEqual(stream.tell(), len(self.original_content) - 10)
data = stream.read(10)
self.assertEqual(data, self.original_content[-10:])
stream.seek(0, io.SEEK_END)
self.assertEqual(stream.tell(), len(self.original_content))
data = stream.read(10)
self.assertEqual(data, b"")
stream.seek(-len(self.original_content), io.SEEK_END)
self.assertEqual(stream.tell(), 0)
data = stream.read(10)


class NonSeekableStream(io.BytesIO):
def seekable(self):
return False

def seek(self, offset, whence=io.SEEK_SET):
raise io.UnsupportedOperation("seek")

Check warning on line 83 in fastapi/tests/test_seekable_stream.py

View check run for this annotation

Codecov / codecov/patch

fastapi/tests/test_seekable_stream.py#L83

Added line #L83 was not covered by tests

def tell(self):
raise io.UnsupportedOperation("tell")

Check warning on line 86 in fastapi/tests/test_seekable_stream.py

View check run for this annotation

Codecov / codecov/patch

fastapi/tests/test_seekable_stream.py#L86

Added line #L86 was not covered by tests

0 comments on commit e871a22

Please sign in to comment.