-
-
Notifications
You must be signed in to change notification settings - Fork 297
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FIX] fastapi: Avoid process stuck in case of retry of Post request w…
…ith body content. In case of retry we must ensure that the stream pass to the Fastapi application is reset to the beginning to be sure it can be consumed again. Unfortunately , the stream object from the werkzeug request is not always seekable. In such a case, we wrap the stream into a new SeekableStream object that it become possible to reset the stream at the begining without having to read the stream first into memory.
- Loading branch information
Showing
8 changed files
with
255 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
Fix issue with the retry of a POST request with a body content. | ||
|
||
Prior to this fix the retry of a POST request with a body content would | ||
stuck in a loop and never complete. This was due to the fact that the | ||
request input stream was not reset after a failed attempt to process the | ||
request. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# Copyright 2024 ACSONE SA/NV | ||
# License LGPL-3.0 or later (http://www.gnu.org/licenses/LGPL). | ||
import io | ||
|
||
|
||
class SeekableStream(io.RawIOBase): | ||
"""A seekable stream that wraps another stream and buffers read data. | ||
This class allows to seek and read data from the original stream. | ||
It buffers read data to allow seeking back to read data again. | ||
This class is useful to handle the case where the original stream does not | ||
support seeking, but the data could eventually be read multiple times. | ||
To avoid reading the original stream a first time to buffer the data, we | ||
buffer the data as it is read. In this way we do not add delay when the | ||
data is read only once. | ||
""" | ||
|
||
def __init__(self, original_stream): | ||
super().__init__() | ||
self.original_stream = original_stream | ||
self.buffer = bytearray() | ||
self.buffer_position = 0 | ||
self.seek_position = 0 | ||
self.end_of_stream = False | ||
|
||
def read(self, size=-1): # pylint: disable=method-required-super | ||
if size == -1: | ||
# Read all remaining data | ||
size = len(self.buffer) - self.buffer_position | ||
data_from_buffer = bytes(self.buffer[self.buffer_position :]) | ||
self.buffer_position = len(self.buffer) | ||
|
||
# Read remaining data from the original stream if not already buffered | ||
remaining_data = self.original_stream.read() | ||
self.buffer.extend(remaining_data) | ||
self.end_of_stream = True | ||
return data_from_buffer + remaining_data | ||
|
||
buffer_len = len(self.buffer) | ||
remaining_buffer = buffer_len - self.buffer_position | ||
|
||
if remaining_buffer >= size: | ||
# Read from the buffer if there is enough data | ||
data = self.buffer[self.buffer_position : self.buffer_position + size] | ||
self.buffer_position += size | ||
return bytes(data) | ||
else: | ||
# Read remaining buffer data | ||
data = self.buffer[self.buffer_position :] | ||
self.buffer_position = buffer_len | ||
|
||
# Read the rest from the original stream | ||
additional_data = self.original_stream.read(size - remaining_buffer) | ||
if additional_data is None: | ||
additional_data = b"" | ||
|
||
# Store read data in the buffer | ||
self.buffer.extend(additional_data) | ||
self.buffer_position += len(additional_data) | ||
if len(additional_data) < (size - remaining_buffer): | ||
self.end_of_stream = True | ||
return bytes(data + additional_data) | ||
|
||
def seek(self, offset, whence=io.SEEK_SET): | ||
if whence == io.SEEK_SET: | ||
new_position = offset | ||
elif whence == io.SEEK_CUR: | ||
new_position = self.buffer_position + offset | ||
elif whence == io.SEEK_END: | ||
if not self.end_of_stream: | ||
# Read the rest of the stream to buffer it | ||
# This is needed to know the total size of the stream | ||
self.read() | ||
new_position = len(self.buffer) + offset | ||
|
||
if new_position < 0: | ||
raise ValueError("Negative seek position {}".format(new_position)) | ||
|
||
if new_position <= len(self.buffer): | ||
self.buffer_position = new_position | ||
else: | ||
# Read from the original stream to fill the buffer up to the new position | ||
to_read = new_position - len(self.buffer) | ||
additional_data = self.original_stream.read(to_read) | ||
if additional_data is None: | ||
additional_data = b"" | ||
self.buffer.extend(additional_data) | ||
if len(self.buffer) < new_position: | ||
raise io.UnsupportedOperation( | ||
"Cannot seek beyond the end of the stream" | ||
) | ||
self.buffer_position = new_position | ||
|
||
return self.buffer_position | ||
|
||
def tell(self): | ||
return self.buffer_position | ||
|
||
def readable(self): | ||
return True | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from . import test_fastapi | ||
from . import test_fastapi_demo | ||
from . import test_seekable_stream |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import io | ||
import random | ||
|
||
from odoo.tests.common import TransactionCase | ||
|
||
from ..seekable_stream import SeekableStream | ||
|
||
|
||
class TestSeekableStream(TransactionCase): | ||
@classmethod | ||
def setUpClass(cls): | ||
super().setUpClass() | ||
# create a random large content | ||
cls.original_content = random.randbytes(1024 * 1024) | ||
|
||
def setUp(self): | ||
super().setUp() | ||
self.original_stream = NonSeekableStream(self.original_content) | ||
|
||
def test_read_all(self): | ||
self.assertFalse(self.original_stream.seekable()) | ||
stream = SeekableStream(self.original_stream) | ||
data = stream.read() | ||
self.assertEqual(data, self.original_content) | ||
stream.seek(0) | ||
data = stream.read() | ||
self.assertEqual(data, self.original_content) | ||
|
||
def test_read_partial(self): | ||
self.assertFalse(self.original_stream.seekable()) | ||
stream = SeekableStream(self.original_stream) | ||
data = stream.read(10) | ||
self.assertEqual(data, self.original_content[:10]) | ||
data = stream.read(10) | ||
self.assertEqual(data, self.original_content[10:20]) | ||
# read the rest | ||
data = stream.read() | ||
self.assertEqual(data, self.original_content[20:]) | ||
|
||
def test_seek(self): | ||
self.assertFalse(self.original_stream.seekable()) | ||
stream = SeekableStream(self.original_stream) | ||
stream.seek(10) | ||
self.assertEqual(stream.tell(), 10) | ||
data = stream.read(10) | ||
self.assertEqual(data, self.original_content[10:20]) | ||
stream.seek(0) | ||
self.assertEqual(stream.tell(), 0) | ||
data = stream.read(10) | ||
self.assertEqual(data, self.original_content[:10]) | ||
|
||
def test_seek_relative(self): | ||
self.assertFalse(self.original_stream.seekable()) | ||
stream = SeekableStream(self.original_stream) | ||
stream.seek(10) | ||
self.assertEqual(stream.tell(), 10) | ||
stream.seek(5, io.SEEK_CUR) | ||
self.assertEqual(stream.tell(), 15) | ||
data = stream.read(10) | ||
self.assertEqual(data, self.original_content[15:25]) | ||
|
||
def test_seek_end(self): | ||
self.assertFalse(self.original_stream.seekable()) | ||
stream = SeekableStream(self.original_stream) | ||
stream.seek(-10, io.SEEK_END) | ||
self.assertEqual(stream.tell(), len(self.original_content) - 10) | ||
data = stream.read(10) | ||
self.assertEqual(data, self.original_content[-10:]) | ||
stream.seek(0, io.SEEK_END) | ||
self.assertEqual(stream.tell(), len(self.original_content)) | ||
data = stream.read(10) | ||
self.assertEqual(data, b"") | ||
stream.seek(-len(self.original_content), io.SEEK_END) | ||
self.assertEqual(stream.tell(), 0) | ||
data = stream.read(10) | ||
|
||
|
||
class NonSeekableStream(io.BytesIO): | ||
def seekable(self): | ||
return False | ||
|
||
def seek(self, offset, whence=io.SEEK_SET): | ||
raise io.UnsupportedOperation("seek") | ||
|
||
def tell(self): | ||
raise io.UnsupportedOperation("tell") | ||