diff --git a/adlfs/spec.py b/adlfs/spec.py
index 8c3d4cc1..cf9993dc 100644
--- a/adlfs/spec.py
+++ b/adlfs/spec.py
@@ -173,6 +173,14 @@ class AzureBlobFileSystem(AsyncFileSystem):
     max_concurrency:
         The number of concurrent connections to use when uploading or downloading a blob.
         If None it will be inferred from fsspec.asyn._get_batch_size().
+    timeout: int
+        Sets the server-side timeout, in seconds, when uploading or downloading a blob.
+    connection_timeout: int
+        The number of seconds the client will wait to establish a connection to the server
+        when uploading or downloading a blob.
+    read_timeout: int
+        The number of seconds the client will wait, between consecutive read operations,
+        for a response from the server while uploading or downloading a blob.
 
     Pass on to fsspec:
 
@@ -237,6 +245,9 @@ def __init__(
         version_aware: bool = False,
         assume_container_exists: Optional[bool] = None,
         max_concurrency: Optional[int] = None,
+        timeout: Optional[int] = None,
+        connection_timeout: Optional[int] = None,
+        read_timeout: Optional[int] = None,
         **kwargs,
     ):
         super_kwargs = {
@@ -270,6 +281,15 @@
         self.default_fill_cache = default_fill_cache
         self.default_cache_type = default_cache_type
         self.version_aware = version_aware
+
+        self._timeout_kwargs = {}
+        if timeout is not None:
+            self._timeout_kwargs["timeout"] = timeout
+        if connection_timeout is not None:
+            self._timeout_kwargs["connection_timeout"] = connection_timeout
+        if read_timeout is not None:
+            self._timeout_kwargs["read_timeout"] = read_timeout
+
         if (
             self.credential is None
             and self.account_key is None
@@ -1347,6 +1367,7 @@ async def _pipe_file(
                 overwrite=overwrite,
                 metadata={"is_directory": "false"},
                 max_concurrency=max_concurrency or self.max_concurrency,
+                **self._timeout_kwargs,
                 **kwargs,
             )
         self.invalidate_cache(self._parent(path))
@@ -1379,6 +1400,7 @@ async def _cat_file(
                     length=length,
                     version_id=version_id,
                     max_concurrency=max_concurrency or self.max_concurrency,
+                    **self._timeout_kwargs,
                 )
             except ResourceNotFoundError as e:
                 raise FileNotFoundError from e
@@ -1557,6 +1579,7 @@ async def _put_file(
                             "upload_stream_current", callback
                         ),
                         max_concurrency=max_concurrency or self.max_concurrency,
+                        **self._timeout_kwargs,
                     )
             self.invalidate_cache()
         except ResourceExistsError:
@@ -1633,6 +1656,7 @@ async def _get_file(
                 ),
                 version_id=version_id,
                 max_concurrency=max_concurrency or self.max_concurrency,
+                **self._timeout_kwargs,
             )
             with open(lpath, "wb") as my_blob:
                 await stream.readinto(my_blob)
@@ -2048,6 +2072,7 @@ async def _async_upload_chunk(self, final: bool = False, **kwargs):
                     length=length,
                     blob_type=BlobType.AppendBlob,
                     metadata=self.metadata,
+                    **self.fs._timeout_kwargs,
                 )
         else:
             raise ValueError(
diff --git a/adlfs/tests/test_spec.py b/adlfs/tests/test_spec.py
index 24b0300a..307282f0 100644
--- a/adlfs/tests/test_spec.py
+++ b/adlfs/tests/test_spec.py
@@ -1757,3 +1757,104 @@ async def test_get_file_versioned(storage, mocker, tmp_path):
     with pytest.raises(FileNotFoundError):
         await fs._get_file("data/root/a/file.txt?versionid=invalid_version", dest)
     assert not dest.exists()
+
+
+async def test_cat_file_timeout(storage, mocker):
+    from azure.storage.blob.aio import BlobClient
+
+    fs = AzureBlobFileSystem(
+        account_name=storage.account_name,
+        connection_string=CONN_STR,
+        skip_instance_cache=True,
+        timeout=11,
+        connection_timeout=12,
+        read_timeout=13,
+    )
+    download_blob = mocker.patch.object(BlobClient, "download_blob")
+    # the mocked stream must expose an awaitable readall() so _cat_file can finish
+    download_blob.return_value.readall = mocker.AsyncMock(return_value=b"data")
+
+    await fs._cat_file("data/root/a/file.txt")
+    download_blob.assert_called_once_with(
+        offset=None,
+        length=None,
+        max_concurrency=fs.max_concurrency,
+        version_id=None,
+        timeout=11,
+        connection_timeout=12,
+        read_timeout=13,
+    )
+
+
+async def test_get_file_timeout(storage, mocker, tmp_path):
+    from azure.storage.blob.aio import BlobClient
+
+    fs = AzureBlobFileSystem(
+        account_name=storage.account_name,
+        connection_string=CONN_STR,
+        skip_instance_cache=True,
+        timeout=11,
+        connection_timeout=12,
+        read_timeout=13,
+    )
+    download_blob = mocker.patch.object(BlobClient, "download_blob")
+    # _get_file writes the blob via stream.readinto(), so make it awaitable
+    download_blob.return_value.readinto = mocker.AsyncMock()
+
+    await fs._get_file("data/root/a/file.txt", str(tmp_path / "out"))
+    # _get_file does not pass offset/length; it downloads the whole blob
+    download_blob.assert_called_once_with(
+        raw_response_hook=mocker.ANY,
+        version_id=None,
+        max_concurrency=fs.max_concurrency,
+        timeout=11,
+        connection_timeout=12,
+        read_timeout=13,
+    )
+
+
+async def test_pipe_file_timeout(storage, mocker):
+    from azure.storage.blob.aio import BlobClient
+
+    fs = AzureBlobFileSystem(
+        account_name=storage.account_name,
+        connection_string=CONN_STR,
+        skip_instance_cache=True,
+        timeout=11,
+        connection_timeout=12,
+        read_timeout=13,
+    )
+    upload_blob = mocker.patch.object(BlobClient, "upload_blob")
+
+    await fs._pipe_file("data/pipefiletimeout", b"data")
+    # only check that the timeout kwargs are forwarded; asserting the full
+    # upload_blob signature would make the test brittle
+    upload_blob.assert_called_once()
+    call_kwargs = upload_blob.call_args.kwargs
+    assert call_kwargs["timeout"] == 11
+    assert call_kwargs["connection_timeout"] == 12
+    assert call_kwargs["read_timeout"] == 13
+
+
+async def test_put_file_timeout(storage, mocker, tmp_path):
+    from azure.storage.blob.aio import BlobClient
+
+    fs = AzureBlobFileSystem(
+        account_name=storage.account_name,
+        connection_string=CONN_STR,
+        skip_instance_cache=True,
+        timeout=11,
+        connection_timeout=12,
+        read_timeout=13,
+    )
+    upload_blob = mocker.patch.object(BlobClient, "upload_blob")
+
+    src = tmp_path / "putfiletimeout"
+    src.write_bytes(b"data")
+
+    await fs._put_file(str(src), "data/putfiletimeout")
+    upload_blob.assert_called_once()
+    call_kwargs = upload_blob.call_args.kwargs
+    assert call_kwargs["timeout"] == 11
+    assert call_kwargs["connection_timeout"] == 12
+    assert call_kwargs["read_timeout"] == 13
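
For reviewers, a minimal usage sketch of the new keyword arguments. The account name, connection string and blob path below are placeholders, not part of this change; because only non-None values are collected into _timeout_kwargs, any argument left unset keeps the Azure SDK's default behaviour.

    import adlfs

    # Timeout values here are purely illustrative; pick ones that fit your workload.
    fs = adlfs.AzureBlobFileSystem(
        account_name="myaccount",          # placeholder
        connection_string="<conn-str>",    # placeholder
        timeout=600,             # server-side timeout, in seconds, per blob operation
        connection_timeout=20,   # seconds allowed to establish a connection
        read_timeout=60,         # seconds allowed between consecutive reads
    )

    # The values are forwarded to every upload_blob/download_blob call made by
    # pipe_file, cat_file, put_file, get_file and append-blob uploads.
    data = fs.cat_file("mycontainer/some/blob.txt")  # placeholder path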