Skip to content

Commit

Permalink
merged
Browse files Browse the repository at this point in the history
  • Loading branch information
rabernat committed Oct 1, 2024
2 parents 2a1e2e3 + 1d3d7a5 commit cd40b08
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 35 deletions.
47 changes: 18 additions & 29 deletions src/zarr/core/metadata/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
if TYPE_CHECKING:
from typing import Self

import numpy.typing as npt

from zarr.core.buffer import Buffer, BufferPrototype
from zarr.core.chunk_grids import ChunkGrid
from zarr.core.common import JSON, ChunkCoords
Expand All @@ -20,6 +18,7 @@

import numcodecs.abc
import numpy as np
import numpy.typing as npt

from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec
from zarr.core.array_spec import ArraySpec
Expand Down Expand Up @@ -167,7 +166,7 @@ def __init__(
self,
*,
shape: Iterable[int],
data_type: str | np.dtype[Any] | DataType,
data_type: npt.DTypeLike | DataType,
chunk_grid: dict[str, JSON] | ChunkGrid,
chunk_key_encoding: dict[str, JSON] | ChunkKeyEncoding,
fill_value: Any,
Expand Down Expand Up @@ -269,13 +268,13 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
_ = parse_node_type_array(_data.pop("node_type"))

# check that the data_type attribute is valid
_data["data_type"] = DataType.parse(_data.pop("data_type"))
data_type = DataType.parse(_data.pop("data_type"))

# dimension_names key is optional, normalize missing to `None`
_data["dimension_names"] = _data.pop("dimension_names", None)
# attributes key is optional, normalize missing to `None`
_data["attributes"] = _data.pop("attributes", None)
return cls(**_data) # type: ignore[arg-type]
return cls(**_data, data_type=data_type) # type: ignore[arg-type]

def to_dict(self) -> dict[str, JSON]:
out_dict = super().to_dict()
Expand Down Expand Up @@ -525,30 +524,20 @@ def from_numpy_dtype(cls, dtype: np.dtype[Any]) -> DataType:
return DataType[dtype_to_data_type[dtype.str]]

@classmethod
def parse(cls, dtype: str | np.dtype[Any] | DataType) -> DataType:
def parse(cls, dtype: None | DataType | Any) -> DataType:
if dtype is None:
# the default dtype
return DataType.float64
if isinstance(dtype, DataType):
return dtype
elif isinstance(dtype, np.dtype):
return cls.from_numpy_dtype(dtype)
elif isinstance(dtype, str):
try:
return cls(dtype)
except ValueError as e:
raise TypeError(f"Invalid V3 data_type: {dtype}") from e
else:
raise TypeError(f"Invalid V3 data_type: {dtype}")


def numpy_dtype_to_zarr_data_type(data: npt.DTypeLike) -> DataType:
try:
dtype = np.dtype(data)
except (ValueError, TypeError) as e:
raise ValueError(f"Invalid V3 data_type: {data}") from e
# check that this is a valid v3 data_type
try:
# dtype = DataType.from_dtype(dtype)
_ = DataType.from_numpy_dtype(dtype)
except KeyError as e:
raise ValueError(f"Invalid V3 data_type: {dtype}") from e

return dtype
try:
dtype = np.dtype(dtype)
except (ValueError, TypeError) as e:
raise ValueError(f"Invalid V3 data_type: {dtype}") from e
# check that this is a valid v3 data_type
try:
data_type = DataType.from_numpy_dtype(dtype)
except KeyError as e:
raise ValueError(f"Invalid V3 data_type: {dtype}") from e
return data_type
11 changes: 5 additions & 6 deletions tests/v3/test_metadata/test_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from zarr.codecs.bytes import BytesCodec
from zarr.core.buffer import default_buffer_prototype
from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding
from zarr.core.metadata.v3 import ArrayV3Metadata
from zarr.core.metadata.v3 import ArrayV3Metadata, DataType

if TYPE_CHECKING:
from collections.abc import Sequence
Expand All @@ -22,7 +22,6 @@

from zarr.core.metadata.v3 import (
parse_dimension_names,
parse_dtype,
parse_fill_value,
parse_zarr_format,
)
Expand Down Expand Up @@ -209,7 +208,7 @@ def test_metadata_to_dict(
storage_transformers: None | tuple[dict[str, JSON]],
) -> None:
shape = (1, 2, 3)
data_type = "uint8"
data_type = DataType.uint8
if chunk_grid == "regular":
cgrid = {"name": "regular", "configuration": {"chunk_shape": (1, 1, 1)}}

Expand Down Expand Up @@ -290,7 +289,7 @@ def test_metadata_to_dict(
# assert result["fill_value"] == fill_value


async def test_invalid_dtype_raises() -> None:
def test_invalid_dtype_raises() -> None:
metadata_dict = {
"zarr_format": 3,
"node_type": "array",
Expand All @@ -301,14 +300,14 @@ async def test_invalid_dtype_raises() -> None:
"codecs": (),
"fill_value": np.datetime64(0, "ns"),
}
with pytest.raises(ValueError, match=r".* is not a valid DataType"):
with pytest.raises(ValueError, match=r"Invalid V3 data_type: .*"):
ArrayV3Metadata.from_dict(metadata_dict)


@pytest.mark.parametrize("data", ["datetime64[s]", "foo", object()])
def test_parse_invalid_dtype_raises(data):
with pytest.raises(ValueError, match=r"Invalid V3 data_type: .*"):
parse_dtype(data)
DataType.parse(data)


@pytest.mark.parametrize(
Expand Down

0 comments on commit cd40b08

Please sign in to comment.