Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add enum support #60

Merged
merged 2 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tests/data/testEnums.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<?xml version="1.0" encoding="UTF-8"?>
<netcdf xmlns="http://www.unidata.ucar.edu/namespaces/netcdf/ncml-2.2">
<enumTypedef name="boolean" type="enum1">
<enum key="0">false</enum>
<enum key="1">true</enum>
</enumTypedef>
<variable name="be_or_not_to_be" shape="" type="enum1" typedef="boolean">
</variable>
</netcdf>
6 changes: 6 additions & 0 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,12 @@ def test_read_meta_data():
assert ds.variables['T'].attrs['units'] == 'degC'


def test_read_enum():
    """Check that an enum typedef shows up as flag attributes on its variable."""
    ds = xncml.open_ncml(data / 'testEnums.xml')
    # Both assertions inspect the same variable, so fetch its attrs once.
    attrs = ds['be_or_not_to_be'].attrs
    assert attrs['flag_values'] == [0, 1]
    assert attrs['flag_meanings'] == ['false', 'true']


# --- #
def check_dimension(ds):
assert len(ds['lat']) == 3
Expand Down
54 changes: 47 additions & 7 deletions xncml/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,17 +263,18 @@ def read_group(target: xr.Dataset, ref: xr.Dataset, obj: Group | Netcdf) -> xr.D
Dataset holding variables and attributes defined in <netcdf> element.
"""
dims = {}
enums = {}
for item in obj.choice:
if isinstance(item, Dimension):
dims[item.name] = read_dimension(item)
elif isinstance(item, Variable):
target = read_variable(target, ref, item, dims)
target = read_variable(target, ref, item, dims, enums)
elif isinstance(item, Attribute):
read_attribute(target, item, ref)
elif isinstance(item, Remove):
target = read_remove(target, item)
elif isinstance(item, EnumTypedef):
raise NotImplementedError
enums[item.name] = read_enum(item)
elif isinstance(item, Group):
target = read_group(target, ref, item)
elif isinstance(item, Aggregation):
Expand Down Expand Up @@ -376,7 +377,37 @@ def read_coord_value(nc: Netcdf, agg: Aggregation, dtypes: list = ()):
return typ(coord)


def read_variable(target: xr.Dataset, ref: xr.Dataset, obj: Variable, dimensions: dict):
def read_enum(obj: EnumTypedef) -> dict[str, list]:
    """
    Parse <enumTypedef> element.

    Example
    -------
    <enumTypedef name="trilean" type="enum1">
      <enum key="0">false</enum>
      <enum key="1">true</enum>
      <enum key="2">undefined</enum>
    </enumTypedef>

    Parameters
    ----------
    obj : EnumTypedef
      <enumTypedef> object.

    Returns
    -------
    dict
      A dictionary with CF flag_values and flag_meanings that describe the Enum.
      `flag_values` lists the `key` of every <enum> entry and `flag_meanings`
      lists the first content item of every entry, both in document order.
    """
    return {
        'flag_values': [entry.key for entry in obj.enum],
        # The label of an <enum> entry is the first item of its content.
        'flag_meanings': [entry.content[0] for entry in obj.enum],
    }


def read_variable(
target: xr.Dataset, ref: xr.Dataset, obj: Variable, dimensions: dict, enums: dict
):
"""
Parse <variable> element.

Expand All @@ -390,6 +421,7 @@ def read_variable(target: xr.Dataset, ref: xr.Dataset, obj: Variable, dimensions
<variable> object description.
dimensions : dict
Dimension attributes keyed by name.
enums: dict[str, dict]

Returns
-------
Expand Down Expand Up @@ -423,6 +455,9 @@ def read_variable(target: xr.Dataset, ref: xr.Dataset, obj: Variable, dimensions
dims = obj.shape.split(' ')
shape = [dimensions[dim].length for dim in dims]
out = xr.Variable(data=np.empty(shape, dtype=nctype(obj.type)), dims=dims)
elif obj.shape == '':
# scalar variable
out = xr.Variable(data=None, dims=())
else:
raise ValueError

Expand All @@ -447,7 +482,12 @@ def read_variable(target: xr.Dataset, ref: xr.Dataset, obj: Variable, dimensions
if obj.logical_reduce:
raise NotImplementedError

if obj.typedef:
if obj.typedef in enums.keys():
# TODO: Also update encoding when https://github.com/pydata/xarray/pull/8147
# is merged in xarray.
out.attrs['flag_values'] = enums[obj.typedef]['flag_values']
out.attrs['flag_meanings'] = enums[obj.typedef]['flag_meanings']
elif obj.typedef is not None:
raise NotImplementedError

target[obj.name] = out
Expand Down Expand Up @@ -551,11 +591,11 @@ def nctype(typ: DataType) -> type:

if typ in [DataType.STRING, DataType.STRING_1]:
return str
elif typ == DataType.BYTE:
elif typ in [DataType.BYTE, DataType.ENUM1]:
return np.int8
elif typ == DataType.SHORT:
elif typ in [DataType.SHORT, DataType.ENUM2]:
return np.int16
elif typ == DataType.INT:
elif typ in [DataType.INT, DataType.ENUM4]:
return np.int32
elif typ == DataType.LONG:
return int
Expand Down