Commit 72300aa: ruff rebase

samaloney committed Jun 21, 2024
1 parent 825dafe commit 72300aa
Showing 4 changed files with 86 additions and 101 deletions.
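
The changes are pure formatting in ruff's style: quotes are normalised to double quotes, binary operators gain surrounding spaces, and statements wrapped across several short lines are collapsed up to a longer line length (implicitly concatenated string literals such as "foo " "bar" are left in place). A plausible invocation is `ruff format .` followed by `ruff check --fix .` (an assumption; the commit records only the result).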
90 changes: 40 additions & 50 deletions sunkit_spex/extern/ndcube/meta.py
@@ -72,6 +72,7 @@ class Meta(dict):
     axis-awareness. If specific pieces of metadata have a known way to behave during
     rebinning, this can be handled by subclasses or mixins.
     """
+
     def __init__(self, header=None, comments=None, axes=None, data_shape=None):
         self.__ndcube_can_slice__ = True
         self.__ndcube_can_rebin__ = True
@@ -89,8 +90,7 @@ def __init__(self, header=None, comments=None, axes=None, data_shape=None):
         else:
             comments = dict(comments)
             if not set(comments.keys()).issubset(set(header_keys)):
-                raise ValueError(
-                    "All comments must correspond to a value in header under the same key.")
+                raise ValueError("All comments must correspond to a value in header under the same key.")
             self._comments = comments

         if data_shape is None:
@@ -101,40 +101,42 @@ def __init__(self, header=None, comments=None, axes=None, data_shape=None):
         if axes is None:
             self._axes = dict()
         else:
-            if not (isinstance(data_shape, collections.abc.Iterable) and
-                    all([isinstance(i, numbers.Integral) for i in data_shape])):
-                raise TypeError("If axes is set, data_shape must be an iterable giving "
-                                "the length of each axis of the associated cube.")
+            if not (
+                isinstance(data_shape, collections.abc.Iterable)
+                and all([isinstance(i, numbers.Integral) for i in data_shape])
+            ):
+                raise TypeError(
+                    "If axes is set, data_shape must be an iterable giving "
+                    "the length of each axis of the associated cube."
+                )
             axes = dict(axes)
             if not set(axes.keys()).issubset(set(header_keys)):
-                raise ValueError(
-                    "All axes must correspond to a value in header under the same key.")
-            self._axes = dict([(key, self._sanitize_axis_value(axis, header[key], key))
-                               for key, axis in axes.items()])
+                raise ValueError("All axes must correspond to a value in header under the same key.")
+            self._axes = dict([(key, self._sanitize_axis_value(axis, header[key], key)) for key, axis in axes.items()])

     def _sanitize_axis_value(self, axis, value, key):
-        axis_err_msg = ("Values in axes must be an integer or iterable of integers giving "
-                        f"the data axis/axes associated with the metadata. axis = {axis}.")
+        axis_err_msg = (
+            "Values in axes must be an integer or iterable of integers giving "
+            f"the data axis/axes associated with the metadata. axis = {axis}."
+        )
         if isinstance(axis, numbers.Integral):
             axis = (axis,)
         if len(axis) == 0:
             return ValueError(axis_err_msg)
         if self.shape is None:
-            raise TypeError("Meta instance does not have a shape so new metadata "
-                            "cannot be assigned to an axis.")
+            raise TypeError("Meta instance does not have a shape so new metadata " "cannot be assigned to an axis.")
         # Verify each entry in axes is an iterable of ints or a scalar.
-        if not (isinstance(axis, collections.abc.Iterable) and all([isinstance(i, numbers.Integral)
-                                                                    for i in axis])):
+        if not (isinstance(axis, collections.abc.Iterable) and all([isinstance(i, numbers.Integral) for i in axis])):
             return ValueError(axis_err_msg)
         axis = np.asarray(axis)
         if _not_scalar(value):
             axis_shape = tuple(self.shape[axis])
             if not _is_grid_aligned(value, axis_shape) and not _is_axis_aligned(value, axis_shape):
                 raise ValueError(
-                    f"{key} must have shape {tuple(self.shape[axis])} "
-                    f"as its associated axes {axis}, ",
+                    f"{key} must have shape {tuple(self.shape[axis])} " f"as its associated axes {axis}, ",
                     f"or same length as number of associated axes ({len(axis)}). "
-                    f"Has shape {value.shape if hasattr(value, 'shape') else len(value)}")
+                    f"Has shape {value.shape if hasattr(value, 'shape') else len(value)}",
+                )
         return axis

     @property
@@ -172,8 +174,7 @@ def add(self, name, value, comment=None, axis=None, overwrite=False):
             If True, overwrites the entry of the name name if already present.
         """
         if name in self.keys() and overwrite is not True:
-            raise KeyError(f"'{name}' already exists. "
-                           "To update an existing metadata entry set overwrite=True.")
+            raise KeyError(f"'{name}' already exists. " "To update an existing metadata entry set overwrite=True.")
         if comment is not None:
             self._comments[name] = comment
         if axis is not None:
@@ -202,7 +203,8 @@ def __setitem__(self, key, val):
                         f"must either have same length as number associated axes ({len(axis)}), "
                         f"or the same shape as associated data axes {tuple(self.shape[axis])}. "
                         f"val shape = {val.shape if hasattr(val, 'shape') else (len(val),)}\n"
-                        "We recommend using the 'add' method to set values.")
+                        "We recommend using the 'add' method to set values."
+                    )
         super().__setitem__(key, val)

     def __getitem__(self, item):
@@ -222,8 +224,7 @@ def __getitem__(self, item):
         if isinstance(item, (numbers.Integral, slice)):
             item = [item]
         naxes = len(self.shape)
-        item = np.array(list(item) + [slice(None)] * (naxes - len(item)),
-                        dtype=object)
+        item = np.array(list(item) + [slice(None)] * (naxes - len(item)), dtype=object)

         # Edit data shape and calculate which axis will be dropped.
         dropped_axes = np.zeros(naxes, dtype=bool)
@@ -244,8 +245,7 @@ def __getitem__(self, item):
                         stop = self.shape[i] - stop
                     new_shape[i] = stop - start
                 else:
-                    raise TypeError("Unrecognized slice type. "
-                                    "Must be an int, slice and tuple of the same.")
+                    raise TypeError("Unrecognized slice type. " "Must be an int, slice and tuple of the same.")
         kept_axes = np.invert(dropped_axes)
         new_meta._data_shape = new_shape[kept_axes]

@@ -255,9 +255,7 @@ def __getitem__(self, item):
             drop_key = False
             if axis is not None:
                 # Calculate new axis indices.
-                new_axis = np.asarray(list(
-                    set(axis).intersection(set(np.arange(naxes)[kept_axes]))
-                ))
+                new_axis = np.asarray(list(set(axis).intersection(set(np.arange(naxes)[kept_axes]))))
                 if len(new_axis) == 0:
                     new_axis = None
                 else:
@@ -280,11 +278,11 @@ def __getitem__(self, item):
                     # Slice metadata value.
                     try:
                         new_value = value[new_item]
-                    except:
+                    except:  # noqa: E722
                         # If value cannot be sliced by fancy slicing, convert it
                         # it to an array, slice it, and then if necessary, convert
                         # it back to its original type.
-                        new_value = (np.asanyarray(value)[new_item])
+                        new_value = np.asanyarray(value)[new_item]
                         if hasattr(new_value, "__len__"):
                             new_value = type(value)(new_value)
             # If axis-aligned metadata sliced down to length 1, convert to scalar.
@@ -294,8 +292,7 @@ def __getitem__(self, item):
             if drop_key:
                 new_meta.remove(key)
             else:
-                new_meta.add(key, new_value, self.comments.get(key, None), new_axis,
-                             overwrite=True)
+                new_meta.add(key, new_value, self.comments.get(key, None), new_axis, overwrite=True)

         return new_meta

@@ -318,41 +315,34 @@ def rebin(self, rebinned_axes, new_shape):
         # Sanitize input.
         data_shape = self.shape
         if not isinstance(rebinned_axes, set):
-            raise TypeError(
-                f"rebinned_axes must be a set. type of rebinned_axes is {type(rebinned_axes)}")
+            raise TypeError(f"rebinned_axes must be a set. type of rebinned_axes is {type(rebinned_axes)}")
         if not all([isinstance(dim, numbers.Integral) for dim in rebinned_axes]):
             raise ValueError("All elements of rebinned_axes must be ints.")
         list_axes = list(rebinned_axes)
         if min(list_axes) < 0 or max(list_axes) >= len(data_shape):
-            raise ValueError(
-                f"Elements in rebinned_axes must be in range 0--{len(data_shape)-1} inclusive.")
+            raise ValueError(f"Elements in rebinned_axes must be in range 0--{len(data_shape)-1} inclusive.")
         if len(new_shape) != len(data_shape):
-            raise ValueError(f"new_shape must be a tuple of same length as data shape: "
-                             f"{len(new_shape)} != {len(self.shape)}")
+            raise ValueError(
+                f"new_shape must be a tuple of same length as data shape: " f"{len(new_shape)} != {len(self.shape)}"
+            )
         if not all([isinstance(dim, numbers.Integral) for dim in new_shape]):
             raise TypeError("bin_shape must contain only integer types.")
         # Remove axis-awareness from grid-aligned metadata associated with rebinned axes.
         new_meta = copy.deepcopy(self)
         null_set = set()
         for name, axes in self.axes.items():
-            if (_is_grid_aligned(self[name], tuple(self.shape[axes]))
-                    and set(axes).intersection(rebinned_axes) != null_set):
+            if (
+                _is_grid_aligned(self[name], tuple(self.shape[axes]))
+                and set(axes).intersection(rebinned_axes) != null_set
+            ):
                 del new_meta._axes[name]
         # Update data shape.
         new_meta._data_shape = np.asarray(new_shape).astype(int)
         return new_meta


 def _not_scalar(value):
-    return (
-        (
-            hasattr(value, "shape")
-            or hasattr(value, "__len__")
-        )
-        and not
-        (
-            isinstance(value, str)
-        ))
+    return (hasattr(value, "shape") or hasattr(value, "__len__")) and not (isinstance(value, str))


 def _is_scalar(value):
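
A minimal usage sketch of the axis-aware behaviour validated above, assuming `Meta` is imported from `sunkit_spex.extern.ndcube.meta`; the header keys, values, and shapes are hypothetical:

    import numpy as np

    meta = Meta(
        header={"instrument": "STIX", "exposure": np.arange(4)},
        comments={"exposure": "Per-image exposure time."},
        axes={"exposure": 0},  # "exposure" is grid-aligned with data axis 0
        data_shape=(4, 512, 512),
    )
    sliced = meta[1:3]  # slicing the first axis also slices "exposure"
    assert list(sliced["exposure"]) == [1, 2]
    assert sliced["instrument"] == "STIX"  # axis-free metadata is untouched

Note that `rebin` takes the opposite approach: rather than resampling values, it simply drops the axis association of grid-aligned entries on rebinned axes.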
77 changes: 35 additions & 42 deletions sunkit_spex/spectrum/spectrum.py
@@ -8,7 +8,7 @@
 from astropy.modeling.tabular import Tabular1D
 from astropy.utils import lazyproperty

-__all__ = ['gwcs_from_array', 'SpectralAxis', 'Spectrum']
+__all__ = ["gwcs_from_array", "SpectralAxis", "Spectrum"]


 def gwcs_from_array(array):
@@ -18,38 +18,29 @@ def gwcs_from_array(array):
     """
     orig_array = u.Quantity(array)

-    coord_frame = cf.CoordinateFrame(naxes=1,
-                                     axes_type=('SPECTRAL',),
-                                     axes_order=(0,))
+    coord_frame = cf.CoordinateFrame(naxes=1, axes_type=("SPECTRAL",), axes_order=(0,))
     spec_frame = cf.SpectralFrame(unit=array.unit, axes_order=(0,))

     # In order for the world_to_pixel transformation to automatically convert
     # input units, the equivalencies in the lookup table have to be extended
     # with spectral unit information.
-    SpectralTabular1D = type("SpectralTabular1D", (Tabular1D,),
-                             {'input_units_equivalencies': {'x0': u.spectral()}})
+    SpectralTabular1D = type("SpectralTabular1D", (Tabular1D,), {"input_units_equivalencies": {"x0": u.spectral()}})

-    forward_transform = SpectralTabular1D(np.arange(len(array)),
-                                          lookup_table=array)
+    forward_transform = SpectralTabular1D(np.arange(len(array)), lookup_table=array)
     # If our spectral axis is in descending order, we have to flip the lookup
     # table to be ascending in order for world_to_pixel to work.
     if len(array) == 0 or array[-1] > array[0]:
-        forward_transform.inverse = SpectralTabular1D(
-            array, lookup_table=np.arange(len(array)))
+        forward_transform.inverse = SpectralTabular1D(array, lookup_table=np.arange(len(array)))
     else:
-        forward_transform.inverse = SpectralTabular1D(
-            array[::-1], lookup_table=np.arange(len(array))[::-1])
+        forward_transform.inverse = SpectralTabular1D(array[::-1], lookup_table=np.arange(len(array))[::-1])

     class SpectralGWCS(GWCS):
         def pixel_to_world(self, *args, **kwargs):
-            if orig_array.unit == '':
+            if orig_array.unit == "":
                 return u.Quantity(super().pixel_to_world_values(*args, **kwargs))
-            return super().pixel_to_world(*args, **kwargs).to(
-                orig_array.unit, equivalencies=u.spectral())
+            return super().pixel_to_world(*args, **kwargs).to(orig_array.unit, equivalencies=u.spectral())

-    tabular_gwcs = SpectralGWCS(forward_transform=forward_transform,
-                                input_frame=coord_frame,
-                                output_frame=spec_frame)
+    tabular_gwcs = SpectralGWCS(forward_transform=forward_transform, input_frame=coord_frame, output_frame=spec_frame)

     # Store the intended unit from the origin input array
     # tabular_gwcs._input_unit = orig_array.unit
@@ -73,12 +64,13 @@ class SpectralAxis(SpectralCoord):
     _equivalent_unit = SpectralCoord._equivalent_unit + (u.pixel,)

     def __new__(cls, value, *args, bin_specification="centers", **kwargs):
-
         # Enforce pixel axes are ascending
-        if ((type(value) is u.quantity.Quantity) and
-                (value.size > 1) and
-                (value.unit is u.pix) and
-                (value[-1] <= value[0])):
+        if (
+            (type(value) is u.quantity.Quantity)
+            and (value.size > 1)
+            and (value.unit is u.pix)
+            and (value[-1] <= value[0])
+        ):
             raise ValueError("u.pix spectral axes should always be ascending")

         # Convert to bin centers if bin edges were given, since SpectralCoord
@@ -101,10 +93,10 @@ def _edges_from_centers(centers, unit):
         centers, with the two outer edges based on extrapolated centers added
         to the beginning and end of the spectral axis.
         """
-        a = np.insert(centers, 0, 2*centers[0] - centers[1])
-        b = np.append(centers, 2*centers[-1] - centers[-2])
+        a = np.insert(centers, 0, 2 * centers[0] - centers[1])
+        b = np.append(centers, 2 * centers[-1] - centers[-2])
         edges = (a + b) / 2
-        return edges*unit
+        return edges * unit

     @staticmethod
     def _centers_from_edges(edges):
@@ -119,7 +111,7 @@ def bin_edges(self):
         Calculates bin edges if the spectral axis was created with centers
         specified.
         """
-        if hasattr(self, '_bin_edges'):
+        if hasattr(self, "_bin_edges"):
             return self._bin_edges
         else:
             return self._edges_from_centers(self.value, self.unit)
@@ -175,9 +167,9 @@ class Spectrum(NDCube):
     Data Type: float64
     """

-    def __init__(self, data, *, uncertainty=None, spectral_axis=None,
-                 spectral_dimension=0, mask=None, meta=None, **kwargs):
-
+    def __init__(
+        self, data, *, uncertainty=None, spectral_axis=None, spectral_dimension=0, mask=None, meta=None, **kwargs
+    ):
         # If the flux (data) argument is already a Spectrum (as it would
         # be for internal arithmetic operations), avoid setup entirely.
         if isinstance(data, Spectrum):
@@ -193,31 +185,29 @@ def __init__(self, data, *, uncertainty=None, spectral_axis=None,

         # Ensure that the unit information codified in the quantity object is
         # the One True Unit.
-        kwargs.setdefault('unit', data.unit if isinstance(data, u.Quantity)
-                          else kwargs.get('unit'))
+        kwargs.setdefault("unit", data.unit if isinstance(data, u.Quantity) else kwargs.get("unit"))

         # If flux and spectral axis are both specified, check that their lengths
         # match or are off by one (implying the spectral axis stores bin edges)
         if data is not None and spectral_axis is not None:
             if spectral_axis.shape[0] == data.shape[spectral_dimension]:
                 bin_specification = "centers"
-            elif spectral_axis.shape[0] == data.shape[spectral_dimension]+1:
+            elif spectral_axis.shape[0] == data.shape[spectral_dimension] + 1:
                 bin_specification = "edges"
             else:
                 raise ValueError(
-                    "Spectral axis length ({}) must be the same size or one "
+                    f"Spectral axis length ({spectral_axis.shape[0]}) must be the same size or one "
                     "greater (if specifying bin edges) than that of the spextral"
-                    "axis ({})".format(spectral_axis.shape[0],
-                                       data.shape[spectral_dimension]))
+                    f"axis ({data.shape[spectral_dimension]})"
+                )

         # Attempt to parse the spectral axis. If none is given, try instead to
         # parse a given wcs. This is put into a GWCS object to
         # then be used behind-the-scenes for all operations.
         if spectral_axis is not None:
             # Ensure that the spectral axis is an astropy Quantity
             if not isinstance(spectral_axis, u.Quantity):
-                raise ValueError("Spectral axis must be a `Quantity` or "
-                                 "`SpectralAxis` object.")
+                raise ValueError("Spectral axis must be a `Quantity` or " "`SpectralAxis` object.")

             # If a spectral axis is provided as an astropy Quantity, convert it
             # to a SpectralAxis object.
@@ -226,12 +216,15 @@ def __init__(self, data, *, uncertainty=None, spectral_axis=None,
                 bin_specification = "edges"
             else:
                 bin_specification = "centers"
-            self._spectral_axis = SpectralAxis(
-                spectral_axis,
-                bin_specification=bin_specification)
+            self._spectral_axis = SpectralAxis(spectral_axis, bin_specification=bin_specification)

             wcs = gwcs_from_array(self._spectral_axis)

         super().__init__(
             data=data.value if isinstance(data, u.Quantity) else data,
-            wcs=wcs, mask=mask, meta=meta, uncertainty=uncertainty, **kwargs)
+            wcs=wcs,
+            mask=mask,
+            meta=meta,
+            uncertainty=uncertainty,
+            **kwargs,
+        )
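
As a quick check of the `_edges_from_centers` arithmetic reformatted above: N centers become N + 1 edges by averaging neighbouring centers and extrapolating one edge past each end (the values here are hypothetical):

    import numpy as np
    import astropy.units as u

    centers = np.array([1.0, 2.0, 3.0])
    a = np.insert(centers, 0, 2 * centers[0] - centers[1])  # [0., 1., 2., 3.]
    b = np.append(centers, 2 * centers[-1] - centers[-2])   # [1., 2., 3., 4.]
    edges = (a + b) / 2 * u.keV                             # [0.5, 1.5, 2.5, 3.5] keV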
4 changes: 2 additions & 2 deletions sunkit_spex/spectrum/tests/test_spectrum.py
@@ -7,10 +7,10 @@


 def test_spectrum_bin_edges():
-    spec = Spectrum(np.arange(1, 11)*u.watt, spectral_axis=np.arange(1, 12)*u.keV)
+    spec = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=np.arange(1, 12) * u.keV)
     assert_array_equal(spec._spectral_axis, [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5] * u.keV)


 def test_spectrum_bin_centers():
-    spec = Spectrum(np.arange(1, 11)*u.watt, spectral_axis=(np.arange(1, 11) - 0.5) * u.keV)
+    spec = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=(np.arange(1, 11) - 0.5) * u.keV)
     assert_array_equal(spec._spectral_axis, [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5] * u.keV)
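
The two tests exercise both bin conventions. A sketch of the corresponding pixel-to-world lookup through the GWCS that `gwcs_from_array` builds (assuming the `gwcs` dependency is installed; the pixel indices are illustrative):

    import numpy as np
    import astropy.units as u

    spec = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=np.arange(1, 12) * u.keV)
    # The lookup table holds the derived bin centers, so pixel i maps to center i.
    spec.wcs.pixel_to_world(np.arange(3))  # 1.5, 2.5, 3.5 keV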