From 72300aa3b9fc154711676c73f7e9c6a7722f24e3 Mon Sep 17 00:00:00 2001
From: Shane Maloney
Date: Fri, 21 Jun 2024 09:32:32 +0100
Subject: [PATCH] ruff rebase

---
 sunkit_spex/extern/ndcube/meta.py           | 90 +++++++++----------
 sunkit_spex/spectrum/spectrum.py            | 77 ++++++++--------
 sunkit_spex/spectrum/tests/test_spectrum.py |  4 +-
 .../spectrum/tests/test_uncertaintiy.py     | 16 ++--
 4 files changed, 86 insertions(+), 101 deletions(-)

diff --git a/sunkit_spex/extern/ndcube/meta.py b/sunkit_spex/extern/ndcube/meta.py
index e7ef957..b7ea6f1 100644
--- a/sunkit_spex/extern/ndcube/meta.py
+++ b/sunkit_spex/extern/ndcube/meta.py
@@ -72,6 +72,7 @@ class Meta(dict):
     axis-awareness. If specific pieces of metadata have a known way to behave
     during rebinning, this can be handled by subclasses or mixins.
     """
+
     def __init__(self, header=None, comments=None, axes=None, data_shape=None):
         self.__ndcube_can_slice__ = True
         self.__ndcube_can_rebin__ = True
@@ -89,8 +90,7 @@ def __init__(self, header=None, comments=None, axes=None, data_shape=None):
         else:
             comments = dict(comments)
             if not set(comments.keys()).issubset(set(header_keys)):
-                raise ValueError(
-                    "All comments must correspond to a value in header under the same key.")
+                raise ValueError("All comments must correspond to a value in header under the same key.")
         self._comments = comments
 
         if data_shape is None:
@@ -101,40 +101,42 @@ def __init__(self, header=None, comments=None, axes=None, data_shape=None):
         if axes is None:
             self._axes = dict()
         else:
-            if not (isinstance(data_shape, collections.abc.Iterable) and
-                    all([isinstance(i, numbers.Integral) for i in data_shape])):
-                raise TypeError("If axes is set, data_shape must be an iterable giving "
-                                "the length of each axis of the associated cube.")
+            if not (
+                isinstance(data_shape, collections.abc.Iterable)
+                and all([isinstance(i, numbers.Integral) for i in data_shape])
+            ):
+                raise TypeError(
+                    "If axes is set, data_shape must be an iterable giving "
+                    "the length of each axis of the associated cube."
+                )
             axes = dict(axes)
             if not set(axes.keys()).issubset(set(header_keys)):
-                raise ValueError(
-                    "All axes must correspond to a value in header under the same key.")
-            self._axes = dict([(key, self._sanitize_axis_value(axis, header[key], key))
-                               for key, axis in axes.items()])
+                raise ValueError("All axes must correspond to a value in header under the same key.")
+            self._axes = dict([(key, self._sanitize_axis_value(axis, header[key], key)) for key, axis in axes.items()])
 
     def _sanitize_axis_value(self, axis, value, key):
-        axis_err_msg = ("Values in axes must be an integer or iterable of integers giving "
-                        f"the data axis/axes associated with the metadata. axis = {axis}.")
+        axis_err_msg = (
+            "Values in axes must be an integer or iterable of integers giving "
+            f"the data axis/axes associated with the metadata. axis = {axis}."
+        )
         if isinstance(axis, numbers.Integral):
             axis = (axis,)
         if len(axis) == 0:
             return ValueError(axis_err_msg)
         if self.shape is None:
-            raise TypeError("Meta instance does not have a shape so new metadata "
-                            "cannot be assigned to an axis.")
+            raise TypeError("Meta instance does not have a shape so new metadata " "cannot be assigned to an axis.")
         # Verify each entry in axes is an iterable of ints or a scalar.
-        if not (isinstance(axis, collections.abc.Iterable) and all([isinstance(i, numbers.Integral)
-                                                                    for i in axis])):
+        if not (isinstance(axis, collections.abc.Iterable) and all([isinstance(i, numbers.Integral) for i in axis])):
            return ValueError(axis_err_msg)
         axis = np.asarray(axis)
         if _not_scalar(value):
             axis_shape = tuple(self.shape[axis])
             if not _is_grid_aligned(value, axis_shape) and not _is_axis_aligned(value, axis_shape):
                 raise ValueError(
-                    f"{key} must have shape {tuple(self.shape[axis])} "
-                    f"as its associated axes {axis}, ",
+                    f"{key} must have shape {tuple(self.shape[axis])} " f"as its associated axes {axis}, ",
                     f"or same length as number of associated axes ({len(axis)}). "
-                    f"Has shape {value.shape if hasattr(value, 'shape') else len(value)}")
+                    f"Has shape {value.shape if hasattr(value, 'shape') else len(value)}",
+                )
         return axis
 
     @property
@@ -172,8 +174,7 @@ def add(self, name, value, comment=None, axis=None, overwrite=False):
         If True, overwrites the entry of the name name if already present.
         """
         if name in self.keys() and overwrite is not True:
-            raise KeyError(f"'{name}' already exists. "
-                           "To update an existing metadata entry set overwrite=True.")
+            raise KeyError(f"'{name}' already exists. " "To update an existing metadata entry set overwrite=True.")
         if comment is not None:
             self._comments[name] = comment
         if axis is not None:
@@ -202,7 +203,8 @@ def __setitem__(self, key, val):
                     f"must either have same length as number associated axes ({len(axis)}), "
                     f"or the same shape as associated data axes {tuple(self.shape[axis])}. "
                     f"val shape = {val.shape if hasattr(val, 'shape') else (len(val),)}\n"
-                    "We recommend using the 'add' method to set values.")
+                    "We recommend using the 'add' method to set values."
+                )
         super().__setitem__(key, val)
 
     def __getitem__(self, item):
@@ -222,8 +224,7 @@ def __getitem__(self, item):
         if isinstance(item, (numbers.Integral, slice)):
             item = [item]
         naxes = len(self.shape)
-        item = np.array(list(item) + [slice(None)] * (naxes - len(item)),
-                        dtype=object)
+        item = np.array(list(item) + [slice(None)] * (naxes - len(item)), dtype=object)
 
         # Edit data shape and calculate which axis will be dropped.
         dropped_axes = np.zeros(naxes, dtype=bool)
@@ -244,8 +245,7 @@ def __getitem__(self, item):
                     stop = self.shape[i] - stop
                 new_shape[i] = stop - start
             else:
-                raise TypeError("Unrecognized slice type. "
-                                "Must be an int, slice and tuple of the same.")
+                raise TypeError("Unrecognized slice type. " "Must be an int, slice and tuple of the same.")
         kept_axes = np.invert(dropped_axes)
         new_meta._data_shape = new_shape[kept_axes]
 
@@ -255,9 +255,7 @@ def __getitem__(self, item):
             drop_key = False
             if axis is not None:
                 # Calculate new axis indices.
-                new_axis = np.asarray(list(
-                    set(axis).intersection(set(np.arange(naxes)[kept_axes]))
-                ))
+                new_axis = np.asarray(list(set(axis).intersection(set(np.arange(naxes)[kept_axes]))))
                 if len(new_axis) == 0:
                     new_axis = None
                 else:
@@ -280,11 +278,11 @@ def __getitem__(self, item):
                 # Slice metadata value.
                 try:
                     new_value = value[new_item]
-                except:
+                except:  # noqa: E722
                     # If value cannot be sliced by fancy slicing, convert it
                     # it to an array, slice it, and then if necessary, convert
                     # it back to its original type.
-                    new_value = (np.asanyarray(value)[new_item])
+                    new_value = np.asanyarray(value)[new_item]
                     if hasattr(new_value, "__len__"):
                         new_value = type(value)(new_value)
                 # If axis-aligned metadata sliced down to length 1, convert to scalar.
@@ -294,8 +292,7 @@ def __getitem__(self, item):
                 if drop_key:
                     new_meta.remove(key)
                 else:
-                    new_meta.add(key, new_value, self.comments.get(key, None), new_axis,
-                                 overwrite=True)
+                    new_meta.add(key, new_value, self.comments.get(key, None), new_axis, overwrite=True)
 
         return new_meta
 
@@ -318,25 +315,26 @@ def rebin(self, rebinned_axes, new_shape):
         # Sanitize input.
         data_shape = self.shape
         if not isinstance(rebinned_axes, set):
-            raise TypeError(
-                f"rebinned_axes must be a set. type of rebinned_axes is {type(rebinned_axes)}")
+            raise TypeError(f"rebinned_axes must be a set. type of rebinned_axes is {type(rebinned_axes)}")
         if not all([isinstance(dim, numbers.Integral) for dim in rebinned_axes]):
             raise ValueError("All elements of rebinned_axes must be ints.")
         list_axes = list(rebinned_axes)
         if min(list_axes) < 0 or max(list_axes) >= len(data_shape):
-            raise ValueError(
-                f"Elements in rebinned_axes must be in range 0--{len(data_shape)-1} inclusive.")
+            raise ValueError(f"Elements in rebinned_axes must be in range 0--{len(data_shape)-1} inclusive.")
         if len(new_shape) != len(data_shape):
-            raise ValueError(f"new_shape must be a tuple of same length as data shape: "
-                             f"{len(new_shape)} != {len(self.shape)}")
+            raise ValueError(
+                f"new_shape must be a tuple of same length as data shape: " f"{len(new_shape)} != {len(self.shape)}"
+            )
         if not all([isinstance(dim, numbers.Integral) for dim in new_shape]):
             raise TypeError("bin_shape must contain only integer types.")
         # Remove axis-awareness from grid-aligned metadata associated with rebinned axes.
         new_meta = copy.deepcopy(self)
         null_set = set()
         for name, axes in self.axes.items():
-            if (_is_grid_aligned(self[name], tuple(self.shape[axes]))
-                    and set(axes).intersection(rebinned_axes) != null_set):
+            if (
+                _is_grid_aligned(self[name], tuple(self.shape[axes]))
+                and set(axes).intersection(rebinned_axes) != null_set
+            ):
                 del new_meta._axes[name]
         # Update data shape.
         new_meta._data_shape = np.asarray(new_shape).astype(int)
@@ -344,15 +342,7 @@
 
 
 def _not_scalar(value):
-    return (
-        (
-            hasattr(value, "shape")
-            or hasattr(value, "__len__")
-        )
-        and not
-        (
-            isinstance(value, str)
-        ))
+    return (hasattr(value, "shape") or hasattr(value, "__len__")) and not (isinstance(value, str))
 
 
 def _is_scalar(value):
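For reviewers unfamiliar with meta.py: the class stores plain metadata alongside axis-aware entries, and slicing a Meta slices the axis-aware values along with the cube. The sketch below is illustrative only, not part of the patch; the header values and shapes are hypothetical, and it assumes the module is importable as sunkit_spex.extern.ndcube.meta as laid out in the diffstat above.

    import numpy as np

    from sunkit_spex.extern.ndcube.meta import Meta

    # Metadata for a hypothetical (3, 4)-shaped cube. "exposure" is tied to
    # axis 0 and must therefore have length 3; "detector" has no axis.
    meta = Meta(
        header={"exposure": np.array([1.0, 2.0, 4.0]), "detector": "det-A"},
        axes={"exposure": 0},
        data_shape=(3, 4),
    )

    # Slicing the Meta mirrors slicing the parent cube: axis-aware entries
    # are sliced along their axes, axis-free entries pass through unchanged.
    sliced = meta[0:2]
    print(sliced["exposure"])  # [1. 2.]
    print(sliced["detector"])  # det-A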
diff --git a/sunkit_spex/spectrum/spectrum.py b/sunkit_spex/spectrum/spectrum.py
index 3d93c52..305782d 100644
--- a/sunkit_spex/spectrum/spectrum.py
+++ b/sunkit_spex/spectrum/spectrum.py
@@ -8,7 +8,7 @@ from astropy.modeling.tabular import Tabular1D
 from astropy.utils import lazyproperty
 
-__all__ = ['gwcs_from_array', 'SpectralAxis', 'Spectrum']
+__all__ = ["gwcs_from_array", "SpectralAxis", "Spectrum"]
 
 
 def gwcs_from_array(array):
@@ -18,38 +18,29 @@ def gwcs_from_array(array):
     """
     orig_array = u.Quantity(array)
 
-    coord_frame = cf.CoordinateFrame(naxes=1,
-                                     axes_type=('SPECTRAL',),
-                                     axes_order=(0,))
+    coord_frame = cf.CoordinateFrame(naxes=1, axes_type=("SPECTRAL",), axes_order=(0,))
     spec_frame = cf.SpectralFrame(unit=array.unit, axes_order=(0,))
 
     # In order for the world_to_pixel transformation to automatically convert
     # input units, the equivalencies in the lookup table have to be extended
     # with spectral unit information.
-    SpectralTabular1D = type("SpectralTabular1D", (Tabular1D,),
-                             {'input_units_equivalencies': {'x0': u.spectral()}})
+    SpectralTabular1D = type("SpectralTabular1D", (Tabular1D,), {"input_units_equivalencies": {"x0": u.spectral()}})
 
-    forward_transform = SpectralTabular1D(np.arange(len(array)),
-                                          lookup_table=array)
+    forward_transform = SpectralTabular1D(np.arange(len(array)), lookup_table=array)
     # If our spectral axis is in descending order, we have to flip the lookup
     # table to be ascending in order for world_to_pixel to work.
     if len(array) == 0 or array[-1] > array[0]:
-        forward_transform.inverse = SpectralTabular1D(
-            array, lookup_table=np.arange(len(array)))
+        forward_transform.inverse = SpectralTabular1D(array, lookup_table=np.arange(len(array)))
     else:
-        forward_transform.inverse = SpectralTabular1D(
-            array[::-1], lookup_table=np.arange(len(array))[::-1])
+        forward_transform.inverse = SpectralTabular1D(array[::-1], lookup_table=np.arange(len(array))[::-1])
 
     class SpectralGWCS(GWCS):
         def pixel_to_world(self, *args, **kwargs):
-            if orig_array.unit == '':
+            if orig_array.unit == "":
                 return u.Quantity(super().pixel_to_world_values(*args, **kwargs))
-            return super().pixel_to_world(*args, **kwargs).to(
-                orig_array.unit, equivalencies=u.spectral())
+            return super().pixel_to_world(*args, **kwargs).to(orig_array.unit, equivalencies=u.spectral())
 
-    tabular_gwcs = SpectralGWCS(forward_transform=forward_transform,
-                                input_frame=coord_frame,
-                                output_frame=spec_frame)
+    tabular_gwcs = SpectralGWCS(forward_transform=forward_transform, input_frame=coord_frame, output_frame=spec_frame)
 
     # Store the intended unit from the origin input array
     # tabular_gwcs._input_unit = orig_array.unit
@@ -73,12 +64,13 @@ class SpectralAxis(SpectralCoord):
     _equivalent_unit = SpectralCoord._equivalent_unit + (u.pixel,)
 
     def __new__(cls, value, *args, bin_specification="centers", **kwargs):
-
         # Enforce pixel axes are ascending
-        if ((type(value) is u.quantity.Quantity) and
-                (value.size > 1) and
-                (value.unit is u.pix) and
-                (value[-1] <= value[0])):
+        if (
+            (type(value) is u.quantity.Quantity)
+            and (value.size > 1)
+            and (value.unit is u.pix)
+            and (value[-1] <= value[0])
+        ):
             raise ValueError("u.pix spectral axes should always be ascending")
 
         # Convert to bin centers if bin edges were given, since SpectralCoord
@@ -101,10 +93,10 @@ def _edges_from_centers(centers, unit):
         centers, with the two outer edges based on extrapolated centers added
         to the beginning and end of the spectral axis.
         """
-        a = np.insert(centers, 0, 2*centers[0] - centers[1])
-        b = np.append(centers, 2*centers[-1] - centers[-2])
+        a = np.insert(centers, 0, 2 * centers[0] - centers[1])
+        b = np.append(centers, 2 * centers[-1] - centers[-2])
         edges = (a + b) / 2
-        return edges*unit
+        return edges * unit
 
     @staticmethod
     def _centers_from_edges(edges):
@@ -119,7 +111,7 @@ def bin_edges(self):
         Calculates bin edges if the spectral axis was created with centers
         specified.
         """
-        if hasattr(self, '_bin_edges'):
+        if hasattr(self, "_bin_edges"):
             return self._bin_edges
         else:
             return self._edges_from_centers(self.value, self.unit)
@@ -175,9 +167,9 @@ class Spectrum(NDCube):
     Data Type: float64
     """
 
-    def __init__(self, data, *, uncertainty=None, spectral_axis=None,
-                 spectral_dimension=0, mask=None, meta=None, **kwargs):
-
+    def __init__(
+        self, data, *, uncertainty=None, spectral_axis=None, spectral_dimension=0, mask=None, meta=None, **kwargs
+    ):
         # If the flux (data) argument is already a Spectrum (as it would
         # be for internal arithmetic operations), avoid setup entirely.
         if isinstance(data, Spectrum):
@@ -193,22 +185,21 @@ def __init__(self, data, *, uncertainty=None, spectral_axis=None,
 
         # Ensure that the unit information codified in the quantity object is
         # the One True Unit.
-        kwargs.setdefault('unit', data.unit if isinstance(data, u.Quantity)
-                          else kwargs.get('unit'))
+        kwargs.setdefault("unit", data.unit if isinstance(data, u.Quantity) else kwargs.get("unit"))
 
         # If flux and spectral axis are both specified, check that their lengths
         # match or are off by one (implying the spectral axis stores bin edges)
         if data is not None and spectral_axis is not None:
             if spectral_axis.shape[0] == data.shape[spectral_dimension]:
                 bin_specification = "centers"
-            elif spectral_axis.shape[0] == data.shape[spectral_dimension]+1:
+            elif spectral_axis.shape[0] == data.shape[spectral_dimension] + 1:
                 bin_specification = "edges"
             else:
                 raise ValueError(
-                    "Spectral axis length ({}) must be the same size or one "
+                    f"Spectral axis length ({spectral_axis.shape[0]}) must be the same size or one "
                     "greater (if specifying bin edges) than that of the spextral"
-                    "axis ({})".format(spectral_axis.shape[0],
-                                       data.shape[spectral_dimension]))
+                    f"axis ({data.shape[spectral_dimension]})"
+                )
 
         # Attempt to parse the spectral axis. If none is given, try instead to
         # parse a given wcs. This is put into a GWCS object to
         if spectral_axis is not None:
             # Ensure that the spectral axis is an astropy Quantity
             if not isinstance(spectral_axis, u.Quantity):
-                raise ValueError("Spectral axis must be a `Quantity` or "
-                                 "`SpectralAxis` object.")
+                raise ValueError("Spectral axis must be a `Quantity` or " "`SpectralAxis` object.")
 
             # If a spectral axis is provided as an astropy Quantity, convert it
             # to a SpectralAxis object.
@@ -226,12 +216,15 @@ def __init__(self, data, *, uncertainty=None, spectral_axis=None,
                 bin_specification = "edges"
             else:
                 bin_specification = "centers"
-            self._spectral_axis = SpectralAxis(
-                spectral_axis,
-                bin_specification=bin_specification)
+            self._spectral_axis = SpectralAxis(spectral_axis, bin_specification=bin_specification)
 
             wcs = gwcs_from_array(self._spectral_axis)
 
         super().__init__(
             data=data.value if isinstance(data, u.Quantity) else data,
-            wcs=wcs, mask=mask, meta=meta, uncertainty=uncertainty, **kwargs)
+            wcs=wcs,
+            mask=mask,
+            meta=meta,
+            uncertainty=uncertainty,
+            **kwargs,
+        )
diff --git a/sunkit_spex/spectrum/tests/test_spectrum.py b/sunkit_spex/spectrum/tests/test_spectrum.py
index 837208a..1763bb0 100644
--- a/sunkit_spex/spectrum/tests/test_spectrum.py
+++ b/sunkit_spex/spectrum/tests/test_spectrum.py
@@ -7,10 +7,10 @@
 
 
 def test_spectrum_bin_edges():
-    spec = Spectrum(np.arange(1, 11)*u.watt, spectral_axis=np.arange(1, 12)*u.keV)
+    spec = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=np.arange(1, 12) * u.keV)
     assert_array_equal(spec._spectral_axis, [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5] * u.keV)
 
 
 def test_spectrum_bin_centers():
-    spec = Spectrum(np.arange(1, 11)*u.watt, spectral_axis=(np.arange(1, 11) - 0.5) * u.keV)
+    spec = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=(np.arange(1, 11) - 0.5) * u.keV)
     assert_array_equal(spec._spectral_axis, [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5] * u.keV)
diff --git a/sunkit_spex/spectrum/tests/test_uncertaintiy.py b/sunkit_spex/spectrum/tests/test_uncertaintiy.py
index 714c1ca..b6c1927 100644
--- a/sunkit_spex/spectrum/tests/test_uncertaintiy.py
+++ b/sunkit_spex/spectrum/tests/test_uncertaintiy.py
@@ -12,8 +12,8 @@ def test_add():
     a = NDDataRef(data, uncertainty=PoissonUncertainty(uncert))
     b = NDDataRef(data.copy(), uncertainty=PoissonUncertainty(uncert.copy()))
     aplusb = a.add(b)
-    assert_array_equal(aplusb.data, 2*data)
-    assert_array_equal(aplusb.uncertainty.array, np.sqrt(2*uncert ** 2))
+    assert_array_equal(aplusb.data, 2 * data)
+    assert_array_equal(aplusb.uncertainty.array, np.sqrt(2 * uncert**2))
 
 
 def test_subtract():
@@ -22,8 +22,8 @@ def test_subtract():
     a = NDDataRef(data, uncertainty=PoissonUncertainty(uncert))
     b = NDDataRef(data.copy(), uncertainty=PoissonUncertainty(uncert.copy()))
     aminusb = a.subtract(b)
-    assert_array_equal(aminusb.data, data-data)
-    assert_array_equal(aminusb.uncertainty.array, np.sqrt(2*uncert ** 2))
+    assert_array_equal(aminusb.data, data - data)
+    assert_array_equal(aminusb.uncertainty.array, np.sqrt(2 * uncert**2))
 
 
 def test_multiply():
@@ -33,7 +33,7 @@ def test_multiply():
     b = NDDataRef(data.copy(), uncertainty=PoissonUncertainty(uncert.copy()))
     atimesb = a.multiply(b)
     assert_array_equal(atimesb.data, data**2)
-    assert_array_equal(atimesb.uncertainty.array, np.sqrt((2*data**2*uncert**2)))  # (b**2*da**2 + a**2db**2)**0.5
+    assert_array_equal(atimesb.uncertainty.array, np.sqrt(2 * data**2 * uncert**2))  # (b**2*da**2 + a**2*db**2)**0.5
 
 
 def test_divide():
@@ -42,5 +42,7 @@ def test_divide():
     a = NDDataRef(data, uncertainty=PoissonUncertainty(uncert))
     b = NDDataRef(data.copy(), uncertainty=PoissonUncertainty(uncert.copy()))
     adivb = a.divide(b)
-    assert_array_equal(adivb.data, data/data)
-    assert_array_equal(adivb.uncertainty.array, np.sqrt(((1/data)**2 * uncert**2)*2))  # ((1/b)**2*da**2 + (a/b**2)**2db**2)**0.5
+    assert_array_equal(adivb.data, data / data)
+    assert_array_equal(
+        adivb.uncertainty.array, np.sqrt(((1 / data) ** 2 * uncert**2) * 2)
+    )  # ((1/b)**2*da**2 + (a/b**2)**2*db**2)**0.5
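The four tests above all rely on standard Gaussian error propagation for independent operands: for a ± da and b ± db the propagated uncertainty is sqrt(da**2 + db**2) for a ± b, sqrt(b**2*da**2 + a**2*db**2) for a*b, and sqrt((1/b)**2*da**2 + (a/b**2)**2*db**2) for a/b, matching the inline comments. PoissonUncertainty is sunkit-spex's own class and its import path is not visible in these hunks, so the standalone sketch below uses astropy's built-in StdDevUncertainty, which propagates by the same formulas:

    import numpy as np
    from astropy.nddata import NDDataRef, StdDevUncertainty
    from numpy.testing import assert_allclose

    data = np.full(5, 10.0)
    uncert = np.sqrt(data)  # Poisson-style sigma = sqrt(counts)

    a = NDDataRef(data, uncertainty=StdDevUncertainty(uncert))
    b = NDDataRef(data.copy(), uncertainty=StdDevUncertainty(uncert.copy()))

    # Addition: sigma = sqrt(da**2 + db**2) = sqrt(2) * uncert
    assert_allclose(a.add(b).uncertainty.array, np.sqrt(2 * uncert**2))

    # Multiplication: sigma = sqrt(b**2*da**2 + a**2*db**2)
    assert_allclose(a.multiply(b).uncertainty.array, np.sqrt(2 * data**2 * uncert**2))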