Merge pull request #548 from laughingman7743/#547
Adjusted size of the last part of a multipart request (fix #547)
laughingman7743 committed May 26, 2024
2 parents 8c72f89 + 2e48aa2 commit ea39ffb
Showing 2 changed files with 77 additions and 31 deletions.
70 changes: 48 additions & 22 deletions pyathena/filesystem/s3.py
@@ -215,12 +215,9 @@ def _head_object(

     def _ls_buckets(self, refresh: bool = False) -> List[S3Object]:
         if "" not in self.dircache or refresh:
-            try:
-                response = self._call(
-                    self._client.list_buckets,
-                )
-            except botocore.exceptions.ClientError as e:
-                raise
+            response = self._call(
+                self._client.list_buckets,
+            )
             buckets = [
                 S3Object(
                     init={
@@ -550,6 +547,8 @@ def cp_file(self, path1: str, path2: str, **kwargs):
         bucket2, key2, version_id2 = self.parse_path(path2)
         if version_id2:
             raise ValueError("Cannot copy to a versioned file.")
+        if not key1 or not key2:
+            raise ValueError("Cannot copy buckets.")

         info1 = self.info(path1)
         size1 = info1.get("size", 0)
@@ -662,7 +661,7 @@ def _copy_object_with_multipart_upload(
                 }
             )

-        parts.sort(key=lambda x: x["PartNumber"])
+        parts.sort(key=lambda x: x["PartNumber"])  # type: ignore
         self._complete_multipart_upload(
             bucket=bucket2,
             key=key2,
@@ -677,14 +676,20 @@ def cat_file(
         if start is not None or end is not None:
             size = self.info(path).get("size", 0)
             if start is None:
-                start = 0
+                range_start = 0
             elif start < 0:
-                start = size + start
+                range_start = size + start
+            else:
+                range_start = start
+
             if end is None:
-                end = size
+                range_end = size
             elif end < 0:
-                end = size + end
-            ranges = (start, end)
+                range_end = size + end
+            else:
+                range_end = end
+
+            ranges = (range_start, range_end)
         else:
             ranges = None
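The cat_file change above resolves optional or negative start/end offsets into a concrete byte range using new local variables instead of reassigning the caller's arguments. Below is a small sketch of just that normalization, read in isolation; the helper name normalize_range is hypothetical and not part of pyathena's API.

from typing import Optional, Tuple

def normalize_range(start: Optional[int], end: Optional[int], size: int) -> Tuple[int, int]:
    """Resolve None or negative offsets against the object's size."""
    if start is None:
        range_start = 0
    elif start < 0:
        range_start = size + start  # count back from the end of the object
    else:
        range_start = start

    if end is None:
        range_end = size
    elif end < 0:
        range_end = size + end
    else:
        range_end = end

    return range_start, range_end

# e.g. normalize_range(None, -1, 100) == (0, 99)
#      normalize_range(-10, None, 100) == (90, 100)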

@@ -1082,17 +1087,38 @@ def _upload_chunk(self, final: bool = False) -> bool:
         part_number = len(self.multipart_upload_parts)
         self.buffer.seek(0)
         while data := self.buffer.read(self.blocksize):
-            part_number += 1
-            self.multipart_upload_parts.append(
-                self._executor.submit(
-                    self.fs._upload_part,
-                    bucket=self.bucket,
-                    key=self.key,
-                    upload_id=cast(str, self.multipart_upload.upload_id),
-                    part_number=part_number,
-                    body=data,
+            # The last part of a multipart request should be adjusted
+            # to be larger than the minimum part size.
+            next_data = self.buffer.read(self.blocksize)
+            next_data_size = len(next_data)
+            if 0 < next_data_size < self.fs.MULTIPART_UPLOAD_MIN_PART_SIZE:
+                upload_data = data + next_data
+                upload_data_size = len(upload_data)
+                if upload_data_size < self.fs.MULTIPART_UPLOAD_MAX_PART_SIZE:
+                    uploads = [upload_data]
+                else:
+                    split_size = upload_data_size // 2
+                    uploads = [upload_data[:split_size], upload_data[split_size:]]
+            else:
+                uploads = [data]
+                if next_data:
+                    uploads.append(next_data)
+
+            for upload in uploads:
+                part_number += 1
+                self.multipart_upload_parts.append(
+                    self._executor.submit(
+                        self.fs._upload_part,
+                        bucket=self.bucket,
+                        key=self.key,
+                        upload_id=cast(str, self.multipart_upload.upload_id),
+                        part_number=part_number,
+                        body=upload,
+                    )
                 )
-            )
+
+            if not next_data:
+                break

         if self.autocommit and final:
             self.commit()
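Amazon S3 rejects multipart upload parts smaller than 5 MiB unless they are the final part of the whole upload, so a short trailing read flushed mid-stream can leave the service with a part it refuses at completion time. The new loop above looks one block ahead and folds an undersized tail into the current part, splitting the combined payload in two only if it would exceed the maximum part size. Below is a minimal, self-contained sketch of that policy; the iter_parts helper and the concrete constant values are illustrative assumptions standing in for pyathena's MULTIPART_UPLOAD_MIN_PART_SIZE and MULTIPART_UPLOAD_MAX_PART_SIZE, not pyathena's API.

from typing import BinaryIO, Iterator

MIN_PART_SIZE = 5 * 2**20  # assumed value; S3's minimum size for non-final parts
MAX_PART_SIZE = 5 * 2**30  # assumed value; S3's maximum size for a single part

def iter_parts(buffer: BinaryIO, blocksize: int) -> Iterator[bytes]:
    """Yield upload payloads, folding a short trailing read into its predecessor."""
    while data := buffer.read(blocksize):
        next_data = buffer.read(blocksize)  # look one block ahead
        if 0 < len(next_data) < MIN_PART_SIZE:
            merged = data + next_data
            if len(merged) < MAX_PART_SIZE:
                yield merged  # upload the short tail together with the current block
            else:
                half = len(merged) // 2  # both halves stay far above the minimum
                yield merged[:half]
                yield merged[half:]
        else:
            yield data
            if next_data:
                yield next_data  # the look-ahead block is large enough to stand alone
        if not next_data:
            break

For example, with a 5 MiB blocksize and about 7 MiB buffered, the sketch yields a single 7 MiB part instead of a 5 MiB part followed by a 2 MiB part that S3 could reject.
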
38 changes: 29 additions & 9 deletions tests/pyathena/filesystem/test_s3.py
@@ -733,18 +733,38 @@ def test_pandas_read_csv(self):
         )
         assert [(row["col"],) for _, row in df.iterrows()] == [(123456789,)]

-    def test_pandas_write_csv(self):
+    @pytest.mark.parametrize(
+        ["line_count"],
+        [
+            (1 * (2**20),),  # Generates files of about 2 MB.
+            (2 * (2**20),),  # 4MB
+            (3 * (2**20),),  # 6MB
+            (4 * (2**20),),  # 8MB
+            (5 * (2**20),),  # 10MB
+            (6 * (2**20),),  # 12MB
+        ],
+    )
+    def test_pandas_write_csv(self, line_count):
         import pandas

-        df = pandas.DataFrame({"a": [1], "b": [2]})
-        path = (
-            f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/"
-            f"filesystem/test_pandas_write_csv/{uuid.uuid4()}.csv"
-        )
-        df.to_csv(path, index=False)
+        with tempfile.NamedTemporaryFile("w") as tmp:
+            tmp.write("col1")
+            tmp.write("\n")
+            for i in range(0, line_count):
+                tmp.write("a")
+                tmp.write("\n")
+            tmp.flush()
+
+            tmp.seek(0)
+            df = pandas.read_csv(tmp.name)
+            path = (
+                f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/"
+                f"filesystem/test_pandas_write_csv/{uuid.uuid4()}.csv"
+            )
+            df.to_csv(path, index=False)

-        actual = pandas.read_csv(path)
-        pandas.testing.assert_frame_equal(df, actual)
+            actual = pandas.read_csv(path)
+            pandas.testing.assert_frame_equal(actual, df)


 class TestS3File:
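The parametrized line counts are chosen so the generated CSVs span several multiples of a typical multipart block size; each data row is just "a\n" (2 bytes), so the size comments in the decorator can be checked with a quick arithmetic sketch (standalone, not part of the test suite):

# Rough size implied by each parametrized case: 2 bytes per row plus a 5-byte "col1\n" header.
for n in range(1, 7):
    line_count = n * 2**20
    approx_bytes = 5 + 2 * line_count
    print(f"line_count={line_count:>8} -> ~{approx_bytes / 2**20:.1f} MiB")

This matches the inline comments (about 2 MB, 4 MB, ... 12 MB); the larger cases are big enough to push writes through the multipart upload path that the fix changes.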
