Skip to content

Commit

Permalink
Add unit test for download, encodings, hashing, and others (#72)
Browse files Browse the repository at this point in the history
* Add unit test for download, encodings, hashing, and others

* Fixed test and addressed review comments
  • Loading branch information
karan6181 committed Nov 13, 2022
1 parent 556db54 commit 467fc43
Show file tree
Hide file tree
Showing 14 changed files with 836 additions and 56 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ repos:
hooks:
- id: docformatter
args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99]
exclude: "(tests)"
- repo: https://github.com/PyCQA/pydocstyle
hooks:
- id: pydocstyle
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
'pytest-cov>=4,<5',
'toml==0.10.2',
'yamllint==1.28.0',
'moto>=4.0,<5',
]

extra_deps['docs'] = [
Expand Down
19 changes: 13 additions & 6 deletions streaming/base/format/json/encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,44 +10,51 @@


class Encoding(ABC):
"""JSON types."""
"""Encoding of an object of JSON type."""

@classmethod
@abstractmethod
def is_encoded(cls, obj: Any) -> bool:
"""Get whether the given object is of this type.
Args:
obj (Any): The object.
obj (Any): Encoded object.
Returns:
bool: Whether of this type.
"""
raise NotImplementedError

@staticmethod
def _validate(data: Any, expected_type: Any) -> bool:
if not isinstance(data, expected_type):
raise AttributeError(
f'data should be of type {expected_type}, but instead, found as {type(data)}')
return True


class Str(Encoding):
"""Store str."""

@classmethod
def is_encoded(cls, obj: Any) -> bool:
return isinstance(obj, str)
return cls._validate(obj, str)


class Int(Encoding):
"""Store int."""

@classmethod
def is_encoded(cls, obj: Any) -> bool:
return isinstance(obj, int)
return cls._validate(obj, int)


class Float(Encoding):
"""Store float."""

@classmethod
def is_encoded(cls, obj: Any) -> bool:
return isinstance(obj, float)
return cls._validate(obj, float)


_encodings = {'str': Str, 'int': Int, 'float': Float}
Expand All @@ -68,7 +75,7 @@ def is_json_encoded(encoding: str, value: Any) -> bool:


def is_json_encoding(encoding: str) -> bool:
"""Get whether this is a supported encoding.
"""Get whether the given encoding is supported.
Args:
encoding (str): Encoding.
Expand Down
27 changes: 24 additions & 3 deletions streaming/base/format/mds/encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,29 @@ def decode(self, data: bytes) -> Any:
"""
raise NotImplementedError

@staticmethod
def _validate(data: Any, expected_type: Any) -> None:
if not isinstance(data, expected_type):
raise AttributeError(
f'data should be of type {expected_type}, but instead, found as {type(data)}')


class Bytes(Encoding):
"""Store bytes (no-op encoding)."""

def encode(self, obj: Any) -> bytes:
def encode(self, obj: bytes) -> bytes:
self._validate(obj, bytes)
return obj

def decode(self, data: bytes) -> Any:
def decode(self, data: bytes) -> bytes:
return data


class Str(Encoding):
"""Store UTF-8."""

def encode(self, obj: str) -> bytes:
self._validate(obj, str)
return obj.encode('utf-8')

def decode(self, data: bytes) -> str:
Expand All @@ -73,6 +81,7 @@ class Int(Encoding):
size = 8

def encode(self, obj: int) -> bytes:
self._validate(obj, int)
return np.int64(obj).tobytes()

def decode(self, data: bytes) -> int:
Expand All @@ -86,6 +95,7 @@ class PIL(Encoding):
"""

def encode(self, obj: Image.Image) -> bytes:
self._validate(obj, Image.Image)
mode = obj.mode.encode('utf-8')
width, height = obj.size
raw = obj.tobytes()
Expand All @@ -106,6 +116,7 @@ class JPEG(Encoding):
"""Store PIL image as JPEG."""

def encode(self, obj: Image.Image) -> bytes:
self._validate(obj, Image.Image)
out = BytesIO()
obj.save(out, format='JPEG')
return out.getvalue()
Expand All @@ -119,6 +130,7 @@ class PNG(Encoding):
"""Store PIL image as PNG."""

def encode(self, obj: Image.Image) -> bytes:
self._validate(obj, Image.Image)
out = BytesIO()
obj.save(out, format='PNG')
return out.getvalue()
Expand All @@ -142,11 +154,20 @@ class JSON(Encoding):
"""Store arbitrary data as JSON."""

def encode(self, obj: Any) -> bytes:
return json.dumps(obj).encode('utf-8')
data = json.dumps(obj)
self._is_valid(obj, data)
return data.encode('utf-8')

def decode(self, data: bytes) -> Any:
return json.loads(data.decode('utf-8'))

def _is_valid(self, original: Any, converted: Any) -> None:
try:
json.loads(converted)
except json.decoder.JSONDecodeError as e:
e.msg = f'Invalid JSON data: {original}'
raise


# Encodings (name -> class).
_encodings = {
Expand Down
42 changes: 24 additions & 18 deletions streaming/base/format/xsv/encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,42 @@ class Encoding(ABC):
@classmethod
@abstractmethod
def encode(cls, obj: Any) -> str:
"""Encode the object.
"""Encode the given data from the original object to string.
Args:
obj (Any): The object.
obj (Any): Decoded object.
Returns:
str: String form.
str: Encoded data in string form.
"""
raise NotImplementedError

@classmethod
@abstractmethod
def decode(cls, obj: str) -> Any:
"""Decode the object.
"""Decode the given data from string to the original object.
Args:
obj (str): String form.
obj (str): Encoded data in string form.
Returns:
Any: The object.
Any: Decoded object.
"""
raise NotImplementedError

@staticmethod
def _validate(data: Any, expected_type: Any) -> None:
if not isinstance(data, expected_type):
raise AttributeError(
f'data should be of type {expected_type}, but instead, found as {type(data)}')


class Str(Encoding):
"""Store str."""

@classmethod
def encode(cls, obj: Any) -> str:
assert isinstance(obj, str)
cls._validate(obj, str)
return obj

@classmethod
Expand All @@ -57,7 +63,7 @@ class Int(Encoding):

@classmethod
def encode(cls, obj: Any) -> str:
assert isinstance(obj, int)
cls._validate(obj, int)
return str(obj)

@classmethod
Expand All @@ -70,7 +76,7 @@ class Float(Encoding):

@classmethod
def encode(cls, obj: Any) -> str:
assert isinstance(obj, float)
cls._validate(obj, float)
return str(obj)

@classmethod
Expand All @@ -82,7 +88,7 @@ def decode(cls, obj: str) -> Any:


def is_xsv_encoding(encoding: str) -> bool:
"""Get whether this is a supported encoding.
"""Get whether the given encoding is supported.
Args:
encoding (str): Encoding.
Expand All @@ -94,28 +100,28 @@ def is_xsv_encoding(encoding: str) -> bool:


def xsv_encode(encoding: str, value: Any) -> str:
"""Encode the object.
"""Encode the given data from the original object to string.
Args:
encoding (str): The encoding.
value (Any): The object.
encoding (str): Encoding name.
value (Any): Object to encode.
Returns:
str: String form.
str: Data in string form.
"""
cls = _encodings[encoding]
return cls.encode(value)


def xsv_decode(encoding: str, value: str) -> Any:
"""Encode the object.
"""Decode the given data from string to the original object.
Args:
encoding (str): The encoding.
value (str): String form.
encoding (str): Encoding name.
value (str): Object to decode.
Returns:
Any: The object.
Any: Decoded object.
"""
cls = _encodings[encoding]
return cls.decode(value)
41 changes: 26 additions & 15 deletions tests/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,36 @@
# SPDX-License-Identifier: Apache-2.0

import json
import pathlib
from typing import List, Optional, Tuple
import os
import tempfile
from typing import Any, List, Optional

import pytest


@pytest.fixture
def remote_local(tmp_path: pathlib.Path) -> Tuple[str, str]:
remote = tmp_path.joinpath('remote')
local = tmp_path.joinpath('local')
return str(remote), str(local)


@pytest.fixture
def compressed_remote_local(tmp_path: pathlib.Path) -> Tuple[str, str, str]:
compressed = tmp_path.joinpath('compressed')
remote = tmp_path.joinpath('remote')
local = tmp_path.joinpath('local')
return tuple(str(x) for x in [compressed, remote, local])
@pytest.fixture(scope='function')
def remote_local() -> Any:
"""Creates a temporary directory and then deletes it when the calling function is done."""
try:
mock_dir = tempfile.TemporaryDirectory()
mock_remote_dir = os.path.join(mock_dir.name, 'remote')
mock_local_dir = os.path.join(mock_dir.name, 'local')
yield mock_remote_dir, mock_local_dir
finally:
mock_dir.cleanup() # pyright: ignore


@pytest.fixture(scope='function')
def compressed_remote_local() -> Any:
"""Creates a temporary directory and then deletes it when the calling function is done."""
try:
mock_dir = tempfile.TemporaryDirectory()
mock_compressed_dir = os.path.join(mock_dir.name, 'compressed')
mock_remote_dir = os.path.join(mock_dir.name, 'remote')
mock_local_dir = os.path.join(mock_dir.name, 'local')
yield mock_compressed_dir, mock_remote_dir, mock_local_dir
finally:
mock_dir.cleanup() # pyright: ignore


def get_config_in_bytes(format: str,
Expand Down
17 changes: 17 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2022 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Any

import pytest
Expand All @@ -15,3 +16,19 @@ def pytest_runtest_call(item: Any):
dist_test_class = item.cls()
dist_test_class._run_test(item._request)
item.runtest = lambda: True # Dummy function so test is not run twice


@pytest.fixture(scope='session', autouse=True)
def aws_credentials():
"""Mocked AWS Credentials for moto."""
os.environ['AWS_ACCESS_KEY_ID'] = 'testing'
os.environ['AWS_SECRET_ACCESS_KEY'] = 'testing'
os.environ['AWS_SECURITY_TOKEN'] = 'testing'
os.environ['AWS_SESSION_TOKEN'] = 'testing'


@pytest.fixture(scope='session', autouse=True)
def gcs_credentials():
"""Mocked GCS Credentials for moto."""
os.environ['GCS_KEY'] = 'testing'
os.environ['GCS_SECRET'] = 'testing'
Loading

0 comments on commit 467fc43

Please sign in to comment.