Skip to content

Commit

Permalink
adlfs: add support for timeout/connection_timeout/read_timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
efiop committed Oct 10, 2023
1 parent 01e91b1 commit 7e02b10
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 0 deletions.
25 changes: 25 additions & 0 deletions adlfs/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,14 @@ class AzureBlobFileSystem(AsyncFileSystem):
max_concurrency:
The number of concurrent connections to use when uploading or downloading a blob.
If None it will be inferred from fsspec.asyn._get_batch_size().
timeout: int
Sets the server-side timeout when uploading or downloading a blob.
connection_timeout: int
The number of seconds the client will wait to establish a connection to the server
when uploading or downloading a blob.
read_timeout: int
The number of seconds the client will wait, between consecutive read operations,
for a response from the server while uploading or downloading a blob.
Pass on to fsspec:
Expand Down Expand Up @@ -237,6 +245,9 @@ def __init__(
version_aware: bool = False,
assume_container_exists: Optional[bool] = None,
max_concurrency: Optional[int] = None,
timeout: Optional[int] = None,
connection_timeout: Optional[int] = None,
read_timeout: Optional[int] = None,
**kwargs,
):
super_kwargs = {
Expand Down Expand Up @@ -270,6 +281,15 @@ def __init__(
self.default_fill_cache = default_fill_cache
self.default_cache_type = default_cache_type
self.version_aware = version_aware

self._timeout_kwargs = {}
if timeout is not None:
self._timeout_kwargs["timeout"] = timeout
if connection_timeout is not None:
self._timeout_kwargs["connection_timeout"] = connection_timeout
if read_timeout is not None:
self._timeout_kwargs["read_timeout"] = read_timeout

if (
self.credential is None
and self.account_key is None
Expand Down Expand Up @@ -1347,6 +1367,7 @@ async def _pipe_file(
overwrite=overwrite,
metadata={"is_directory": "false"},
max_concurrency=max_concurrency or self.max_concurrency,
**self._timeout_kwargs,
**kwargs,
)
self.invalidate_cache(self._parent(path))
Expand Down Expand Up @@ -1379,6 +1400,7 @@ async def _cat_file(
length=length,
version_id=version_id,
max_concurrency=max_concurrency or self.max_concurrency,
**self._timeout_kwargs,
)
except ResourceNotFoundError as e:
raise FileNotFoundError from e
Expand Down Expand Up @@ -1557,6 +1579,7 @@ async def _put_file(
"upload_stream_current", callback
),
max_concurrency=max_concurrency or self.max_concurrency,
**self._timeout_kwargs,
)
self.invalidate_cache()
except ResourceExistsError:
Expand Down Expand Up @@ -1633,6 +1656,7 @@ async def _get_file(
),
version_id=version_id,
max_concurrency=max_concurrency or self.max_concurrency,
**self._timeout_kwargs,
)
with open(lpath, "wb") as my_blob:
await stream.readinto(my_blob)
Expand Down Expand Up @@ -2048,6 +2072,7 @@ async def _async_upload_chunk(self, final: bool = False, **kwargs):
length=length,
blob_type=BlobType.AppendBlob,
metadata=self.metadata,
**self.fs._timeout_kwargs,
)
else:
raise ValueError(
Expand Down
103 changes: 103 additions & 0 deletions adlfs/tests/test_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1757,3 +1757,106 @@ async def test_get_file_versioned(storage, mocker, tmp_path):
with pytest.raises(FileNotFoundError):
await fs._get_file("data/root/a/file.txt?versionid=invalid_version", dest)
assert not dest.exists()


async def test_cat_file_timeout(storage, mocker):
from azure.storage.blob.aio import BlobClient

fs = AzureBlobFileSystem(
account_name=storage.account_name,
connection_string=CONN_STR,
skip_instance_cache=True,
timeout=11,
connection_timeout=12,
read_timeout=13,
)
download_blob = mocker.patch.object(BlobClient, "download_blob")

await fs._cat_file("data/root/a/file.txt")
download_blob.assert_called_once_with(
offset=None,
length=None,
max_concurrency=fs.max_concurrency,
version_id=None,
timeout=11,
connection_timeout=12,
read_timeout=13,
)


async def test_get_file_timeout(storage, mocker, tmp_path):
from azure.storage.blob.aio import BlobClient

fs = AzureBlobFileSystem(
account_name=storage.account_name,
connection_string=CONN_STR,
skip_instance_cache=True,
timeout=11,
connection_timeout=12,
read_timeout=13,
)
download_blob = mocker.patch.object(BlobClient, "download_blob")

await fs._get_file("data/root/a/file.txt", str(tmp_path / "out"))
download_blob.assert_called_once_with(
offset=None,
length=None,
max_concurrency=fs.max_concurrency,
version_id=None,
timeout=11,
connection_timeout=12,
read_timeout=13,
)


async def test_pipe_file_timeout(storage, mocker):
from azure.storage.blob.aio import BlobClient

fs = AzureBlobFileSystem(
account_name=storage.account_name,
connection_string=CONN_STR,
skip_instance_cache=True,
timeout=11,
connection_timeout=12,
read_timeout=13,
)
upload_blob = mocker.patch.object(BlobClient, "upload_blob")

await fs._pipe_file("pipefiletimeout", b"data")
upload_blob.assert_called_once_with(
offset=None,
length=None,
max_concurrency=fs.max_concurrency,
version_id=None,
timeout=11,
connection_timeout=12,
read_timeout=13,
)


async def test_put_file_timeout(storage, mocker, tmp_path):
from azure.storage.blob.aio import BlobClient

fs = AzureBlobFileSystem(
account_name=storage.account_name,
connection_string=CONN_STR,
skip_instance_cache=True,
timeout=11,
connection_timeout=12,
read_timeout=13,
)
upload_blob = mocker.patch.object(BlobClient, "upload_blob")

src = tmp_path / "putfiletimeout"
src.write_bytes(b"data")

await fs._pipe_file(str(src), "putfiletimeout")
upload_blob.assert_called_once_with(
offset=None,
length=None,
max_concurrency=fs.max_concurrency,
version_id=None,
timeout=11,
connection_timeout=12,
read_timeout=13,
)

0 comments on commit 7e02b10

Please sign in to comment.