From 825dafee729647e8ac5326dd1574d726d9df30ea Mon Sep 17 00:00:00 2001
From: Shane Maloney
Date: Mon, 15 Apr 2024 14:43:53 +0100
Subject: [PATCH 1/2] Add Spectrum and Poisson uncertainty

---
 setup.cfg                                     |  21 +-
 sunkit_spex/extern/ndcube/__init__.py         |   0
 sunkit_spex/extern/ndcube/meta.py             | 373 ++++++++++++++++++
 sunkit_spex/legacy/tests/test_brem.py         |   4 +-
 sunkit_spex/legacy/tests/test_integrate.py    |   2 +-
 .../legacy/tests/test_photon_power_law.py     |   2 +-
 .../{test_thermal.py => test_thermal_.py}     |   2 +-
 sunkit_spex/spectrum/__init__.py              |   3 +
 sunkit_spex/spectrum/spectrum.py              | 237 +++++++++++
 sunkit_spex/spectrum/tests/__init__.py        |   0
 sunkit_spex/spectrum/tests/test_spectrum.py   |  16 +
 .../spectrum/tests/test_uncertaintiy.py       |  46 +++
 sunkit_spex/spectrum/uncertainty.py           | 134 +++++++
 13 files changed, 825 insertions(+), 15 deletions(-)
 create mode 100644 sunkit_spex/extern/ndcube/__init__.py
 create mode 100644 sunkit_spex/extern/ndcube/meta.py
 rename sunkit_spex/legacy/tests/{test_thermal.py => test_thermal_.py} (99%)
 create mode 100644 sunkit_spex/spectrum/__init__.py
 create mode 100644 sunkit_spex/spectrum/spectrum.py
 create mode 100644 sunkit_spex/spectrum/tests/__init__.py
 create mode 100644 sunkit_spex/spectrum/tests/test_spectrum.py
 create mode 100644 sunkit_spex/spectrum/tests/test_uncertaintiy.py
 create mode 100644 sunkit_spex/spectrum/uncertainty.py

diff --git a/setup.cfg b/setup.cfg
index 811e2078..714dd8f8 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -15,19 +15,19 @@ packages = find:
 python_requires = >=3.9
 setup_requires = setuptools_scm
 install_requires =
-    sunpy
-    parfive
-    scipy
-    xarray
-    quadpy
-    orthopy
-    ndim
-    matplotlib
-    emcee
     corner
+    emcee
+    matplotlib
+    ndcube
+    ndim
     nestle
     numdifftools
-
+    orthopy
+    parfive
+    quadpy
+    scipy
+    sunpy
+    xarray
 
 [options.extras_require]
 test =
@@ -35,6 +35,7 @@ test =
     pytest-astropy
    pytest-cov
    pytest-xdist
+
 docs =
     sphinx
     sphinx-automodapi
diff --git a/sunkit_spex/extern/ndcube/__init__.py b/sunkit_spex/extern/ndcube/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/sunkit_spex/extern/ndcube/meta.py b/sunkit_spex/extern/ndcube/meta.py
new file mode 100644
index 00000000..e7ef9573
--- /dev/null
+++ b/sunkit_spex/extern/ndcube/meta.py
@@ -0,0 +1,373 @@
+import copy
+import numbers
+import collections.abc
+
+import numpy as np
+
+__all__ = ["Meta"]
+
+
+class Meta(dict):
+    """
+    A sliceable object for storing metadata.
+
+    Metadata can be linked to a data axis, which causes it to be sliced when the
+    standard Python numeric slicing API is applied to the object.
+    Specific pieces of metadata can be obtained using the dict-like string slicing API.
+    Metadata associated with an axis/axes must have the same length/shape as those axes.
+
+    Parameters
+    ----------
+    header: dict-like
+        The names and values of metadata.
+
+    comments: dict-like, optional
+        Comments associated with any of the above pieces of metadata.
+
+    axes: dict-like, optional
+        The axis/axes associated with the metadata denoted by the keys.
+        Metadata not included are considered not to be associated with any axis.
+        Each axis value must be an iterable of `int`. An `int` itself is also
+        acceptable if the metadata is associated with a single axis.
+        The value of axis-assigned metadata in header must be same length as
+        number of associated axes (axis-aligned), or same shape as the associated
+        data array's axes (grid-aligned).
+
+    data_shape: iterable of `int`, optional
+        The shape of the data with which this metadata is associated.
+        Must be set if axes input is set.
+
+    Notes
+    -----
+    **Axis-aware Metadata**
+    There are two valid types of axis-aware metadata: axis-aligned and grid-aligned.
+    Axis-aligned metadata gives one value per associated axis, while grid-aligned
+    metadata gives a value for each data array element in the associated axes.
+    Consequently, axis-aligned metadata has the same length as the number of
+    associated axes, while grid-aligned metadata has the same shape as the associated
+    axes. To avoid confusion, axis-aligned metadata that is only associated with one
+    axis must be scalar or a string. Length-1 objects (excluding strings) are assumed
+    to be grid-aligned and associated with a length-1 axis.
+
+    **Slicing and Rebinning Axis-aware Metadata**
+    Axis-aligned metadata is only considered valid if the associated axes are present.
+    Therefore, axis-aligned metadata is only changed if an associated axis is dropped
+    by an operation, e.g. slicing. In such a case, the value associated with the
+    dropped axes is also dropped and hence lost. If the axis of a 1-axis-aligned
+    metadata value (scalar) is sliced away, the metadata key is entirely removed
+    from the Meta object.
+
+    Grid-aligned metadata mirrors the data array, so it is sliced following
+    the same rules with one exception. If an axis is dropped by slicing, the metadata
+    name is kept, but its value is set to the value at the row/column where the
+    axis/axes was sliced away, and the metadata axis-awareness is removed. This is
+    similar to how coordinate values are transferred to ``global_coords`` when their
+    associated axes are sliced away.
+
+    Note that because rebinning does not drop axes, axis-aligned metadata is unaltered
+    by rebinning. By contrast, grid-aligned metadata must necessarily be affected by
+    rebinning. However, how it is affected depends on the nature of the metadata and
+    there is no generalized solution. Therefore, this class does not alter the shape
+    or values of grid-aligned metadata during rebinning, but simply removes its
+    axis-awareness. If specific pieces of metadata have a known way to behave during
+    rebinning, this can be handled by subclasses or mixins.
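+
+    Examples
+    --------
+    An illustrative sketch of the slicing behaviour described above, with
+    hypothetical values (the grid-aligned ``"counts"`` entry is sliced with the
+    data, while axis-free entries pass through unchanged):
+
+    >>> from sunkit_spex.extern.ndcube.meta import Meta
+    >>> meta = Meta({"instrument": "X", "counts": [1, 2, 3]},
+    ...             axes={"counts": 0}, data_shape=(3,))
+    >>> sliced = meta[0:2]
+    >>> len(sliced["counts"])
+    2
+    >>> sliced["instrument"]
+    'X'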
+ """ + def __init__(self, header=None, comments=None, axes=None, data_shape=None): + self.__ndcube_can_slice__ = True + self.__ndcube_can_rebin__ = True + self.original_header = header + + if header is None: + header = {} + else: + header = dict(header) + super().__init__(header.items()) + header_keys = header.keys() + + if comments is None: + self._comments = dict() + else: + comments = dict(comments) + if not set(comments.keys()).issubset(set(header_keys)): + raise ValueError( + "All comments must correspond to a value in header under the same key.") + self._comments = comments + + if data_shape is None: + self._data_shape = data_shape + else: + self._data_shape = np.asarray(data_shape, dtype=int) + + if axes is None: + self._axes = dict() + else: + if not (isinstance(data_shape, collections.abc.Iterable) and + all([isinstance(i, numbers.Integral) for i in data_shape])): + raise TypeError("If axes is set, data_shape must be an iterable giving " + "the length of each axis of the associated cube.") + axes = dict(axes) + if not set(axes.keys()).issubset(set(header_keys)): + raise ValueError( + "All axes must correspond to a value in header under the same key.") + self._axes = dict([(key, self._sanitize_axis_value(axis, header[key], key)) + for key, axis in axes.items()]) + + def _sanitize_axis_value(self, axis, value, key): + axis_err_msg = ("Values in axes must be an integer or iterable of integers giving " + f"the data axis/axes associated with the metadata. axis = {axis}.") + if isinstance(axis, numbers.Integral): + axis = (axis,) + if len(axis) == 0: + return ValueError(axis_err_msg) + if self.shape is None: + raise TypeError("Meta instance does not have a shape so new metadata " + "cannot be assigned to an axis.") + # Verify each entry in axes is an iterable of ints or a scalar. + if not (isinstance(axis, collections.abc.Iterable) and all([isinstance(i, numbers.Integral) + for i in axis])): + return ValueError(axis_err_msg) + axis = np.asarray(axis) + if _not_scalar(value): + axis_shape = tuple(self.shape[axis]) + if not _is_grid_aligned(value, axis_shape) and not _is_axis_aligned(value, axis_shape): + raise ValueError( + f"{key} must have shape {tuple(self.shape[axis])} " + f"as its associated axes {axis}, ", + f"or same length as number of associated axes ({len(axis)}). " + f"Has shape {value.shape if hasattr(value, 'shape') else len(value)}") + return axis + + @property + def comments(self): + return self._comments + + @property + def axes(self): + return self._axes + + @property + def shape(self): + return self._data_shape + + def add(self, name, value, comment=None, axis=None, overwrite=False): + """Add a new piece of metadata to instance. + + Parameters + ---------- + name: `str` + The name/label of the metadata. + + value: Any + The value of the metadata. If axes input is not None, this must have the + same length/shape as those axes as defined by ``self.shape``. + + comment: `str` or `None` + Any comment associated with this metadata. Set to None if no comment desired. + + axis: `int`, iterable of `int`, or `None` + The axis/axes with which the metadata is linked. If not associated with any + axis, set this to None. + + overwrite: `bool`, optional + If True, overwrites the entry of the name name if already present. + """ + if name in self.keys() and overwrite is not True: + raise KeyError(f"'{name}' already exists. 
" + "To update an existing metadata entry set overwrite=True.") + if comment is not None: + self._comments[name] = comment + if axis is not None: + axis = self._sanitize_axis_value(axis, value, name) + self._axes[name] = axis + elif name in self._axes: + del self._axes[name] + # This must be done after updating self._axes otherwise it may error. + self.__setitem__(name, value) + + def remove(self, name): + if name in self._comments: + del self._comments[name] + if name in self._axes: + del self._axes[name] + del self[name] + + def __setitem__(self, key, val): + axis = self.axes.get(key, None) + if axis is not None: + if _not_scalar(val): + axis_shape = tuple(self.shape[axis]) + if not _is_grid_aligned(val, axis_shape) and not _is_axis_aligned(val, axis_shape): + raise TypeError( + f"{key} is already associated with axis/axes {axis}. val must therefore " + f"must either have same length as number associated axes ({len(axis)}), " + f"or the same shape as associated data axes {tuple(self.shape[axis])}. " + f"val shape = {val.shape if hasattr(val, 'shape') else (len(val),)}\n" + "We recommend using the 'add' method to set values.") + super().__setitem__(key, val) + + def __getitem__(self, item): + # There are two ways to slice: + # by key, or + # by typical python numeric slicing API, + # i.e. slice the each piece of metadata associated with an axes. + + if isinstance(item, str): + return super().__getitem__(item) + + elif self.shape is None: + raise TypeError("Meta object does not have a shape and so cannot be sliced.") + + else: + new_meta = copy.deepcopy(self) + if isinstance(item, (numbers.Integral, slice)): + item = [item] + naxes = len(self.shape) + item = np.array(list(item) + [slice(None)] * (naxes - len(item)), + dtype=object) + + # Edit data shape and calculate which axis will be dropped. + dropped_axes = np.zeros(naxes, dtype=bool) + new_shape = new_meta.shape + for i, axis_item in enumerate(item): + if isinstance(axis_item, numbers.Integral): + dropped_axes[i] = True + elif isinstance(axis_item, slice): + start = axis_item.start + if start is None: + start = 0 + if start < 0: + start = self.shape[i] - start + stop = axis_item.stop + if stop is None: + stop = self.shape[i] + if stop < 0: + stop = self.shape[i] - stop + new_shape[i] = stop - start + else: + raise TypeError("Unrecognized slice type. " + "Must be an int, slice and tuple of the same.") + kept_axes = np.invert(dropped_axes) + new_meta._data_shape = new_shape[kept_axes] + + # Slice all metadata associated with axes. + for key, value in self.items(): + axis = self.axes.get(key, None) + drop_key = False + if axis is not None: + # Calculate new axis indices. + new_axis = np.asarray(list( + set(axis).intersection(set(np.arange(naxes)[kept_axes])) + )) + if len(new_axis) == 0: + new_axis = None + else: + cumul_dropped_axes = np.cumsum(dropped_axes)[new_axis] + new_axis -= cumul_dropped_axes + + # Calculate sliced metadata values. + axis_shape = tuple(self.shape[axis]) + if _is_scalar(value): + new_value = value + # If scalar metadata's axes have been dropped, mark metadata to be dropped. + if new_axis is None: + drop_key = True + else: + value_is_axis_aligned = _is_axis_aligned(value, axis_shape) + if value_is_axis_aligned: + new_item = kept_axes[axis] + else: + new_item = tuple(item[axis]) + # Slice metadata value. + try: + new_value = value[new_item] + except: + # If value cannot be sliced by fancy slicing, convert it + # it to an array, slice it, and then if necessary, convert + # it back to its original type. 
+                            new_value = (np.asanyarray(value)[new_item])
+                            if hasattr(new_value, "__len__"):
+                                new_value = type(value)(new_value)
+                        # If axis-aligned metadata sliced down to length 1, convert to scalar.
+                        if value_is_axis_aligned and len(new_value) == 1:
+                            new_value = new_value[0]
+                    # Overwrite metadata value with newly sliced version.
+                    if drop_key:
+                        new_meta.remove(key)
+                    else:
+                        new_meta.add(key, new_value, self.comments.get(key, None), new_axis,
+                                     overwrite=True)
+
+            return new_meta
+
+    def rebin(self, rebinned_axes, new_shape):
+        """
+        Adjusts axis-aware metadata to stay consistent with a rebinned `~ndcube.NDCube`.
+
+        This is done by simply removing the axis-awareness of metadata associated with
+        rebinned axes. The metadata itself is not changed or removed. This operation
+        does not remove axis-awareness from metadata only associated with non-rebinned
+        axes, i.e. axes whose corresponding entries in ``bin_shape`` are 1.
+
+        Parameters
+        ----------
+        rebinned_axes: `set` of `int`
+            Set of array indices of axes that are rebinned.
+        new_shape: `tuple` of `int`
+            The new shape of the rebinned data.
+        """
+        # Sanitize input.
+        data_shape = self.shape
+        if not isinstance(rebinned_axes, set):
+            raise TypeError(
+                f"rebinned_axes must be a set. type of rebinned_axes is {type(rebinned_axes)}")
+        if not all([isinstance(dim, numbers.Integral) for dim in rebinned_axes]):
+            raise ValueError("All elements of rebinned_axes must be ints.")
+        list_axes = list(rebinned_axes)
+        if min(list_axes) < 0 or max(list_axes) >= len(data_shape):
+            raise ValueError(
+                f"Elements in rebinned_axes must be in range 0--{len(data_shape)-1} inclusive.")
+        if len(new_shape) != len(data_shape):
+            raise ValueError(f"new_shape must be a tuple of same length as data shape: "
+                             f"{len(new_shape)} != {len(self.shape)}")
+        if not all([isinstance(dim, numbers.Integral) for dim in new_shape]):
+            raise TypeError("new_shape must contain only integer types.")
+        # Remove axis-awareness from grid-aligned metadata associated with rebinned axes.
+        new_meta = copy.deepcopy(self)
+        null_set = set()
+        for name, axes in self.axes.items():
+            if (_is_grid_aligned(self[name], tuple(self.shape[axes]))
+                    and set(axes).intersection(rebinned_axes) != null_set):
+                del new_meta._axes[name]
+        # Update data shape.
+        new_meta._data_shape = np.asarray(new_shape).astype(int)
+        return new_meta
+
+
+def _not_scalar(value):
+    return (
+        (
+            hasattr(value, "shape")
+            or hasattr(value, "__len__")
+        )
+        and not
+        (
+            isinstance(value, str)
+        ))
+
+
+def _is_scalar(value):
+    return not _not_scalar(value)
+
+
+def _is_grid_aligned(value, axis_shape):
+    if _is_scalar(value):
+        return False
+    value_shape = value.shape if hasattr(value, "shape") else (len(value),)
+    if value_shape != axis_shape:
+        return False
+    return True
+
+
+def _is_axis_aligned(value, axis_shape):
+    len_value = len(value) if _not_scalar(value) else 1
+    return not _is_grid_aligned(value, axis_shape) and len_value == len(axis_shape)
diff --git a/sunkit_spex/legacy/tests/test_brem.py b/sunkit_spex/legacy/tests/test_brem.py
index 9b60ec43..b099ab1b 100644
--- a/sunkit_spex/legacy/tests/test_brem.py
+++ b/sunkit_spex/legacy/tests/test_brem.py
@@ -2,8 +2,8 @@
 import pytest
 from numpy.testing import assert_allclose, assert_array_equal
 
-from sunkit_spex import emission
-from sunkit_spex.integrate import fixed_quad, gauss_legendre
+from sunkit_spex.legacy import emission
+from sunkit_spex.legacy.integrate import fixed_quad, gauss_legendre
 
 
 def test_broken_power_law_electron_distribution():
diff --git a/sunkit_spex/legacy/tests/test_integrate.py b/sunkit_spex/legacy/tests/test_integrate.py
index cea9b2ea..9990bb4f 100644
--- a/sunkit_spex/legacy/tests/test_integrate.py
+++ b/sunkit_spex/legacy/tests/test_integrate.py
@@ -1,7 +1,7 @@
 import numpy as np
 from numpy.testing import assert_allclose
 
-from sunkit_spex.integrate import fixed_quad, gauss_legendre
+from sunkit_spex.legacy.integrate import fixed_quad, gauss_legendre
 
 
 def test_scalar():
diff --git a/sunkit_spex/legacy/tests/test_photon_power_law.py b/sunkit_spex/legacy/tests/test_photon_power_law.py
index f7ca17f7..b08fb4b9 100644
--- a/sunkit_spex/legacy/tests/test_photon_power_law.py
+++ b/sunkit_spex/legacy/tests/test_photon_power_law.py
@@ -3,7 +3,7 @@
 
 import astropy.units as u
 
-from sunkit_spex import photon_power_law as ppl
+from sunkit_spex.legacy import photon_power_law as ppl
 
 
 def test_different_bins():
diff --git a/sunkit_spex/legacy/tests/test_thermal.py b/sunkit_spex/legacy/tests/test_thermal_.py
similarity index 99%
rename from sunkit_spex/legacy/tests/test_thermal.py
rename to sunkit_spex/legacy/tests/test_thermal_.py
index 3de91857..f3dfc8d1 100644
--- a/sunkit_spex/legacy/tests/test_thermal.py
+++ b/sunkit_spex/legacy/tests/test_thermal_.py
@@ -5,7 +5,7 @@
 
 import astropy.units as u
 
-from sunkit_spex import thermal
+from sunkit_spex.legacy import thermal
 
 # Manually load file that was used to compile expected flux values.
 thermal.setup_continuum_parameters(
diff --git a/sunkit_spex/spectrum/__init__.py b/sunkit_spex/spectrum/__init__.py
new file mode 100644
index 00000000..70628198
--- /dev/null
+++ b/sunkit_spex/spectrum/__init__.py
@@ -0,0 +1,3 @@
+from sunkit_spex.spectrum.spectrum import Spectrum
+
+__all__ = ["Spectrum"]
diff --git a/sunkit_spex/spectrum/spectrum.py b/sunkit_spex/spectrum/spectrum.py
new file mode 100644
index 00000000..3d93c522
--- /dev/null
+++ b/sunkit_spex/spectrum/spectrum.py
@@ -0,0 +1,237 @@
+import numpy as np
+from gwcs import WCS as GWCS
+from gwcs import coordinate_frames as cf
+from ndcube import NDCube
+
+import astropy.units as u
+from astropy.coordinates import SpectralCoord
+from astropy.modeling.tabular import Tabular1D
+from astropy.utils import lazyproperty
+
+__all__ = ['gwcs_from_array', 'SpectralAxis', 'Spectrum']
+
+
+def gwcs_from_array(array):
+    """
+    Create a new WCS from provided tabular data. This defaults to being
+    a GWCS object.
+    """
+    orig_array = u.Quantity(array)
+
+    coord_frame = cf.CoordinateFrame(naxes=1,
+                                     axes_type=('SPECTRAL',),
+                                     axes_order=(0,))
+    spec_frame = cf.SpectralFrame(unit=array.unit, axes_order=(0,))
+
+    # In order for the world_to_pixel transformation to automatically convert
+    # input units, the equivalencies in the lookup table have to be extended
+    # with spectral unit information.
+    SpectralTabular1D = type("SpectralTabular1D", (Tabular1D,),
+                             {'input_units_equivalencies': {'x0': u.spectral()}})
+
+    forward_transform = SpectralTabular1D(np.arange(len(array)),
+                                          lookup_table=array)
+    # If our spectral axis is in descending order, we have to flip the lookup
+    # table to be ascending in order for world_to_pixel to work.
+    if len(array) == 0 or array[-1] > array[0]:
+        forward_transform.inverse = SpectralTabular1D(
+            array, lookup_table=np.arange(len(array)))
+    else:
+        forward_transform.inverse = SpectralTabular1D(
+            array[::-1], lookup_table=np.arange(len(array))[::-1])
+
+    class SpectralGWCS(GWCS):
+        def pixel_to_world(self, *args, **kwargs):
+            if orig_array.unit == '':
+                return u.Quantity(super().pixel_to_world_values(*args, **kwargs))
+            return super().pixel_to_world(*args, **kwargs).to(
+                orig_array.unit, equivalencies=u.spectral())
+
+    tabular_gwcs = SpectralGWCS(forward_transform=forward_transform,
+                                input_frame=coord_frame,
+                                output_frame=spec_frame)
+
+    # Store the intended unit from the original input array
+    # tabular_gwcs._input_unit = orig_array.unit
+
+    return tabular_gwcs
+
+
+class SpectralAxis(SpectralCoord):
+    """
+    Coordinate object representing spectral values corresponding to a specific
+    spectrum. Overloads SpectralCoord with additional information (currently
+    only bin edges).
+
+    Parameters
+    ----------
+    bin_specification: str, optional
+        Must be "edges" or "centers". Determines whether specified axis values
+        are interpreted as bin edges or bin centers. Defaults to "centers".
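+
+    Examples
+    --------
+    An illustrative sketch with hypothetical values, showing that an axis built
+    from bin edges is stored as centers while the original edges stay available:
+
+    >>> import numpy as np
+    >>> import astropy.units as u
+    >>> sa = SpectralAxis(np.arange(1, 12) * u.keV, bin_specification="edges")
+    >>> len(sa)
+    10
+    >>> sa.bin_edges.size
+    11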
+ """ + + _equivalent_unit = SpectralCoord._equivalent_unit + (u.pixel,) + + def __new__(cls, value, *args, bin_specification="centers", **kwargs): + + # Enforce pixel axes are ascending + if ((type(value) is u.quantity.Quantity) and + (value.size > 1) and + (value.unit is u.pix) and + (value[-1] <= value[0])): + raise ValueError("u.pix spectral axes should always be ascending") + + # Convert to bin centers if bin edges were given, since SpectralCoord + # only accepts centers + if bin_specification == "edges": + bin_edges = value + value = SpectralAxis._centers_from_edges(value) + + obj = super().__new__(cls, value, *args, **kwargs) + + if bin_specification == "edges": + obj._bin_edges = bin_edges + + return obj + + @staticmethod + def _edges_from_centers(centers, unit): + """ + Calculates interior bin edges based on the average of each pair of + centers, with the two outer edges based on extrapolated centers added + to the beginning and end of the spectral axis. + """ + a = np.insert(centers, 0, 2*centers[0] - centers[1]) + b = np.append(centers, 2*centers[-1] - centers[-2]) + edges = (a + b) / 2 + return edges*unit + + @staticmethod + def _centers_from_edges(edges): + """ + Calculates the bin centers as the average of each pair of edges + """ + return (edges[1:] + edges[:-1]) / 2 + + @lazyproperty + def bin_edges(self): + """ + Calculates bin edges if the spectral axis was created with centers + specified. + """ + if hasattr(self, '_bin_edges'): + return self._bin_edges + else: + return self._edges_from_centers(self.value, self.unit) + + +class Spectrum(NDCube): + r""" + Spectrum container for data with one spectral axis. + + Note that "1D" in this case refers to the fact that there is only one + spectral axis. `Spectrum` can contain "vector 1D spectra" by having the + ``flux`` have a shape with dimension greater than 1. + + Notes + ----- + A stripped down version of `Spectrum1D` from `specutils`. + + Parameters + ---------- + data : `~astropy.units.Quantity` + The data for this spectrum. This can be a simple `~astropy.units.Quantity`, + or an existing `~Spectrum1D` or `~ndcube.NDCube` object. + uncertainty : `~astropy.nddata.NDUncertainty` + Contains uncertainty information along with propagation rules for + spectrum arithmetic. Can take a unit, but if none is given, will use + the unit defined in the flux. + spectral_axis : `~astropy.units.Quantity` or `~specutils.SpectralAxis` + Dispersion information with the same shape as the dimension specified by spectral_dimension + of shape plus one if specifying bin edges. + spectral_dimension : `int` default 0 + The dimension of the data which represents the spectral information default to first dimension index 0. + mask : `~numpy.ndarray`-like + Array where values in the flux to be masked are those that + ``astype(bool)`` converts to True. (For example, integer arrays are not + masked where they are 0, and masked for any other value.) + meta : dict + Arbitrary container for any user-specific information to be carried + around with the spectrum container object. + + Examples + -------- + >>> import numpy as np + >>> import astropy.units as u + >>> from sunkit_spex.spectrum import Spectrum + >>> spec = Spectrum(np.arange(1, 11)*u.watt, spectral_axis=np.arange(1, 12)*u.keV) + >>> spec + >> from astropy.nddata import NDData + >>> from sunkit_spex.spectrum.uncertainty import PoissonUncertainty + >>> ndd = NDData([1,2,3], unit='m', + ... 
+    ...              uncertainty=PoissonUncertainty([0.1, 0.1, 0.1]))
+    >>> ndd.uncertainty  # doctest: +FLOAT_CMP
+    PoissonUncertainty([0.1, 0.1, 0.1])
+
+    or by setting it manually on the `NDData` instance::
+
+    >>> ndd.uncertainty = PoissonUncertainty([0.2], unit='m', copy=True)
+    >>> ndd.uncertainty  # doctest: +FLOAT_CMP
+    PoissonUncertainty([0.2])
+
+    the uncertainty ``array`` can also be set directly::
+
+    >>> ndd.uncertainty.array = 2
+    >>> ndd.uncertainty
+    PoissonUncertainty(2)
+
+    .. note::
+        The unit will not be displayed.
+    """
+
+    @property
+    def supports_correlated(self):
+        """`True` : `PoissonUncertainty` allows to propagate correlated \
+        uncertainties.
+
+        ``correlation`` must be given, this class does not implement computing
+        it by itself.
+        """
+        return True
+
+    @property
+    def uncertainty_type(self):
+        """``"poisson"`` : `PoissonUncertainty` implements Poisson uncertainty."""
+        return "poisson"
+
+    def _convert_uncertainty(self, other_uncert):
+        if isinstance(other_uncert, PoissonUncertainty):
+            return other_uncert
+        else:
+            raise IncompatibleUncertaintiesException
+
+    def _propagate_add(self, other_uncert, result_data, correlation):
+        return super()._propagate_add_sub(
+            other_uncert,
+            result_data,
+            correlation,
+            subtract=False,
+            to_variance=np.square,
+            from_variance=np.sqrt,
+        )
+
+    def _propagate_subtract(self, other_uncert, result_data, correlation):
+        return super()._propagate_add_sub(
+            other_uncert,
+            result_data,
+            correlation,
+            subtract=True,
+            to_variance=np.square,
+            from_variance=np.sqrt,
+        )
+
+    def _propagate_multiply(self, other_uncert, result_data, correlation):
+        return super()._propagate_multiply_divide(
+            other_uncert,
+            result_data,
+            correlation,
+            divide=False,
+            to_variance=np.square,
+            from_variance=np.sqrt,
+        )
+
+    def _propagate_divide(self, other_uncert, result_data, correlation):
+        return super()._propagate_multiply_divide(
+            other_uncert,
+            result_data,
+            correlation,
+            divide=True,
+            to_variance=np.square,
+            from_variance=np.sqrt,
+        )
+
+    def _propagate_collapse(self, numpy_operation, axis):
+        # defer to _VariancePropagationMixin
+        return super()._propagate_collapse(numpy_operation, axis=axis)
+
+    def _data_unit_to_uncertainty_unit(self, value):
+        return value
+
+    def _convert_to_variance(self):
+        new_array = None if self.array is None else self.array**2
+        new_unit = None if self.unit is None else self.unit**2
+        return VarianceUncertainty(new_array, unit=new_unit)
+
+    @classmethod
+    def _convert_from_variance(cls, var_uncert):
+        new_array = None if var_uncert.array is None else var_uncert.array ** (1 / 2)
+        new_unit = None if var_uncert.unit is None else var_uncert.unit ** (1 / 2)
+        return cls(new_array, unit=new_unit)

From 72300aa3b9fc154711676c73f7e9c6a7722f24e3 Mon Sep 17 00:00:00 2001
From: Shane Maloney
Date: Fri, 21 Jun 2024 09:32:32 +0100
Subject: [PATCH 2/2] ruff rebase

---
 sunkit_spex/extern/ndcube/meta.py           | 90 +++++++++----------
 sunkit_spex/spectrum/spectrum.py            | 77 ++++++++--------
 sunkit_spex/spectrum/tests/test_spectrum.py |  4 +-
 .../spectrum/tests/test_uncertaintiy.py     | 16 ++--
 4 files changed, 86 insertions(+), 101 deletions(-)

diff --git a/sunkit_spex/extern/ndcube/meta.py b/sunkit_spex/extern/ndcube/meta.py
index e7ef9573..b7ea6f12 100644
--- a/sunkit_spex/extern/ndcube/meta.py
+++ b/sunkit_spex/extern/ndcube/meta.py
@@ -72,6 +72,7 @@ class Meta(dict):
     axis-awareness. If specific pieces of metadata have a known way to behave during
     rebinning, this can be handled by subclasses or mixins.
""" + def __init__(self, header=None, comments=None, axes=None, data_shape=None): self.__ndcube_can_slice__ = True self.__ndcube_can_rebin__ = True @@ -89,8 +90,7 @@ def __init__(self, header=None, comments=None, axes=None, data_shape=None): else: comments = dict(comments) if not set(comments.keys()).issubset(set(header_keys)): - raise ValueError( - "All comments must correspond to a value in header under the same key.") + raise ValueError("All comments must correspond to a value in header under the same key.") self._comments = comments if data_shape is None: @@ -101,40 +101,42 @@ def __init__(self, header=None, comments=None, axes=None, data_shape=None): if axes is None: self._axes = dict() else: - if not (isinstance(data_shape, collections.abc.Iterable) and - all([isinstance(i, numbers.Integral) for i in data_shape])): - raise TypeError("If axes is set, data_shape must be an iterable giving " - "the length of each axis of the associated cube.") + if not ( + isinstance(data_shape, collections.abc.Iterable) + and all([isinstance(i, numbers.Integral) for i in data_shape]) + ): + raise TypeError( + "If axes is set, data_shape must be an iterable giving " + "the length of each axis of the associated cube." + ) axes = dict(axes) if not set(axes.keys()).issubset(set(header_keys)): - raise ValueError( - "All axes must correspond to a value in header under the same key.") - self._axes = dict([(key, self._sanitize_axis_value(axis, header[key], key)) - for key, axis in axes.items()]) + raise ValueError("All axes must correspond to a value in header under the same key.") + self._axes = dict([(key, self._sanitize_axis_value(axis, header[key], key)) for key, axis in axes.items()]) def _sanitize_axis_value(self, axis, value, key): - axis_err_msg = ("Values in axes must be an integer or iterable of integers giving " - f"the data axis/axes associated with the metadata. axis = {axis}.") + axis_err_msg = ( + "Values in axes must be an integer or iterable of integers giving " + f"the data axis/axes associated with the metadata. axis = {axis}." + ) if isinstance(axis, numbers.Integral): axis = (axis,) if len(axis) == 0: return ValueError(axis_err_msg) if self.shape is None: - raise TypeError("Meta instance does not have a shape so new metadata " - "cannot be assigned to an axis.") + raise TypeError("Meta instance does not have a shape so new metadata " "cannot be assigned to an axis.") # Verify each entry in axes is an iterable of ints or a scalar. - if not (isinstance(axis, collections.abc.Iterable) and all([isinstance(i, numbers.Integral) - for i in axis])): + if not (isinstance(axis, collections.abc.Iterable) and all([isinstance(i, numbers.Integral) for i in axis])): return ValueError(axis_err_msg) axis = np.asarray(axis) if _not_scalar(value): axis_shape = tuple(self.shape[axis]) if not _is_grid_aligned(value, axis_shape) and not _is_axis_aligned(value, axis_shape): raise ValueError( - f"{key} must have shape {tuple(self.shape[axis])} " - f"as its associated axes {axis}, ", + f"{key} must have shape {tuple(self.shape[axis])} " f"as its associated axes {axis}, ", f"or same length as number of associated axes ({len(axis)}). " - f"Has shape {value.shape if hasattr(value, 'shape') else len(value)}") + f"Has shape {value.shape if hasattr(value, 'shape') else len(value)}", + ) return axis @property @@ -172,8 +174,7 @@ def add(self, name, value, comment=None, axis=None, overwrite=False): If True, overwrites the entry of the name name if already present. 
""" if name in self.keys() and overwrite is not True: - raise KeyError(f"'{name}' already exists. " - "To update an existing metadata entry set overwrite=True.") + raise KeyError(f"'{name}' already exists. " "To update an existing metadata entry set overwrite=True.") if comment is not None: self._comments[name] = comment if axis is not None: @@ -202,7 +203,8 @@ def __setitem__(self, key, val): f"must either have same length as number associated axes ({len(axis)}), " f"or the same shape as associated data axes {tuple(self.shape[axis])}. " f"val shape = {val.shape if hasattr(val, 'shape') else (len(val),)}\n" - "We recommend using the 'add' method to set values.") + "We recommend using the 'add' method to set values." + ) super().__setitem__(key, val) def __getitem__(self, item): @@ -222,8 +224,7 @@ def __getitem__(self, item): if isinstance(item, (numbers.Integral, slice)): item = [item] naxes = len(self.shape) - item = np.array(list(item) + [slice(None)] * (naxes - len(item)), - dtype=object) + item = np.array(list(item) + [slice(None)] * (naxes - len(item)), dtype=object) # Edit data shape and calculate which axis will be dropped. dropped_axes = np.zeros(naxes, dtype=bool) @@ -244,8 +245,7 @@ def __getitem__(self, item): stop = self.shape[i] - stop new_shape[i] = stop - start else: - raise TypeError("Unrecognized slice type. " - "Must be an int, slice and tuple of the same.") + raise TypeError("Unrecognized slice type. " "Must be an int, slice and tuple of the same.") kept_axes = np.invert(dropped_axes) new_meta._data_shape = new_shape[kept_axes] @@ -255,9 +255,7 @@ def __getitem__(self, item): drop_key = False if axis is not None: # Calculate new axis indices. - new_axis = np.asarray(list( - set(axis).intersection(set(np.arange(naxes)[kept_axes])) - )) + new_axis = np.asarray(list(set(axis).intersection(set(np.arange(naxes)[kept_axes])))) if len(new_axis) == 0: new_axis = None else: @@ -280,11 +278,11 @@ def __getitem__(self, item): # Slice metadata value. try: new_value = value[new_item] - except: + except: # noqa: E722 # If value cannot be sliced by fancy slicing, convert it # it to an array, slice it, and then if necessary, convert # it back to its original type. - new_value = (np.asanyarray(value)[new_item]) + new_value = np.asanyarray(value)[new_item] if hasattr(new_value, "__len__"): new_value = type(value)(new_value) # If axis-aligned metadata sliced down to length 1, convert to scalar. @@ -294,8 +292,7 @@ def __getitem__(self, item): if drop_key: new_meta.remove(key) else: - new_meta.add(key, new_value, self.comments.get(key, None), new_axis, - overwrite=True) + new_meta.add(key, new_value, self.comments.get(key, None), new_axis, overwrite=True) return new_meta @@ -318,25 +315,26 @@ def rebin(self, rebinned_axes, new_shape): # Sanitize input. data_shape = self.shape if not isinstance(rebinned_axes, set): - raise TypeError( - f"rebinned_axes must be a set. type of rebinned_axes is {type(rebinned_axes)}") + raise TypeError(f"rebinned_axes must be a set. 
type of rebinned_axes is {type(rebinned_axes)}") if not all([isinstance(dim, numbers.Integral) for dim in rebinned_axes]): raise ValueError("All elements of rebinned_axes must be ints.") list_axes = list(rebinned_axes) if min(list_axes) < 0 or max(list_axes) >= len(data_shape): - raise ValueError( - f"Elements in rebinned_axes must be in range 0--{len(data_shape)-1} inclusive.") + raise ValueError(f"Elements in rebinned_axes must be in range 0--{len(data_shape)-1} inclusive.") if len(new_shape) != len(data_shape): - raise ValueError(f"new_shape must be a tuple of same length as data shape: " - f"{len(new_shape)} != {len(self.shape)}") + raise ValueError( + f"new_shape must be a tuple of same length as data shape: " f"{len(new_shape)} != {len(self.shape)}" + ) if not all([isinstance(dim, numbers.Integral) for dim in new_shape]): raise TypeError("bin_shape must contain only integer types.") # Remove axis-awareness from grid-aligned metadata associated with rebinned axes. new_meta = copy.deepcopy(self) null_set = set() for name, axes in self.axes.items(): - if (_is_grid_aligned(self[name], tuple(self.shape[axes])) - and set(axes).intersection(rebinned_axes) != null_set): + if ( + _is_grid_aligned(self[name], tuple(self.shape[axes])) + and set(axes).intersection(rebinned_axes) != null_set + ): del new_meta._axes[name] # Update data shape. new_meta._data_shape = np.asarray(new_shape).astype(int) @@ -344,15 +342,7 @@ def rebin(self, rebinned_axes, new_shape): def _not_scalar(value): - return ( - ( - hasattr(value, "shape") - or hasattr(value, "__len__") - ) - and not - ( - isinstance(value, str) - )) + return (hasattr(value, "shape") or hasattr(value, "__len__")) and not (isinstance(value, str)) def _is_scalar(value): diff --git a/sunkit_spex/spectrum/spectrum.py b/sunkit_spex/spectrum/spectrum.py index 3d93c522..305782d5 100644 --- a/sunkit_spex/spectrum/spectrum.py +++ b/sunkit_spex/spectrum/spectrum.py @@ -8,7 +8,7 @@ from astropy.modeling.tabular import Tabular1D from astropy.utils import lazyproperty -__all__ = ['gwcs_from_array', 'SpectralAxis', 'Spectrum'] +__all__ = ["gwcs_from_array", "SpectralAxis", "Spectrum"] def gwcs_from_array(array): @@ -18,38 +18,29 @@ def gwcs_from_array(array): """ orig_array = u.Quantity(array) - coord_frame = cf.CoordinateFrame(naxes=1, - axes_type=('SPECTRAL',), - axes_order=(0,)) + coord_frame = cf.CoordinateFrame(naxes=1, axes_type=("SPECTRAL",), axes_order=(0,)) spec_frame = cf.SpectralFrame(unit=array.unit, axes_order=(0,)) # In order for the world_to_pixel transformation to automatically convert # input units, the equivalencies in the lookup table have to be extended # with spectral unit information. - SpectralTabular1D = type("SpectralTabular1D", (Tabular1D,), - {'input_units_equivalencies': {'x0': u.spectral()}}) + SpectralTabular1D = type("SpectralTabular1D", (Tabular1D,), {"input_units_equivalencies": {"x0": u.spectral()}}) - forward_transform = SpectralTabular1D(np.arange(len(array)), - lookup_table=array) + forward_transform = SpectralTabular1D(np.arange(len(array)), lookup_table=array) # If our spectral axis is in descending order, we have to flip the lookup # table to be ascending in order for world_to_pixel to work. 
     if len(array) == 0 or array[-1] > array[0]:
-        forward_transform.inverse = SpectralTabular1D(
-            array, lookup_table=np.arange(len(array)))
+        forward_transform.inverse = SpectralTabular1D(array, lookup_table=np.arange(len(array)))
     else:
-        forward_transform.inverse = SpectralTabular1D(
-            array[::-1], lookup_table=np.arange(len(array))[::-1])
+        forward_transform.inverse = SpectralTabular1D(array[::-1], lookup_table=np.arange(len(array))[::-1])
 
     class SpectralGWCS(GWCS):
         def pixel_to_world(self, *args, **kwargs):
-            if orig_array.unit == '':
+            if orig_array.unit == "":
                 return u.Quantity(super().pixel_to_world_values(*args, **kwargs))
-            return super().pixel_to_world(*args, **kwargs).to(
-                orig_array.unit, equivalencies=u.spectral())
+            return super().pixel_to_world(*args, **kwargs).to(orig_array.unit, equivalencies=u.spectral())
 
-    tabular_gwcs = SpectralGWCS(forward_transform=forward_transform,
-                                input_frame=coord_frame,
-                                output_frame=spec_frame)
+    tabular_gwcs = SpectralGWCS(forward_transform=forward_transform, input_frame=coord_frame, output_frame=spec_frame)
 
     # Store the intended unit from the original input array
     # tabular_gwcs._input_unit = orig_array.unit
@@ -73,12 +64,13 @@ class SpectralAxis(SpectralCoord):
     _equivalent_unit = SpectralCoord._equivalent_unit + (u.pixel,)
 
     def __new__(cls, value, *args, bin_specification="centers", **kwargs):
-
         # Enforce pixel axes are ascending
-        if ((type(value) is u.quantity.Quantity) and
-                (value.size > 1) and
-                (value.unit is u.pix) and
-                (value[-1] <= value[0])):
+        if (
+            (type(value) is u.quantity.Quantity)
+            and (value.size > 1)
+            and (value.unit is u.pix)
+            and (value[-1] <= value[0])
+        ):
             raise ValueError("u.pix spectral axes should always be ascending")
 
         # Convert to bin centers if bin edges were given, since SpectralCoord
@@ -101,10 +93,10 @@ def _edges_from_centers(centers, unit):
         centers, with the two outer edges based on extrapolated centers added
         to the beginning and end of the spectral axis.
         """
-        a = np.insert(centers, 0, 2*centers[0] - centers[1])
-        b = np.append(centers, 2*centers[-1] - centers[-2])
+        a = np.insert(centers, 0, 2 * centers[0] - centers[1])
+        b = np.append(centers, 2 * centers[-1] - centers[-2])
         edges = (a + b) / 2
-        return edges*unit
+        return edges * unit
 
     @staticmethod
     def _centers_from_edges(edges):
@@ -119,7 +111,7 @@ def bin_edges(self):
         Calculates bin edges if the spectral axis was created with centers
         specified.
         """
-        if hasattr(self, '_bin_edges'):
+        if hasattr(self, "_bin_edges"):
             return self._bin_edges
         else:
             return self._edges_from_centers(self.value, self.unit)
@@ -175,9 +167,9 @@ class Spectrum(NDCube):
         Data Type: float64
     """
 
-    def __init__(self, data, *, uncertainty=None, spectral_axis=None,
-                 spectral_dimension=0, mask=None, meta=None, **kwargs):
-
+    def __init__(
+        self, data, *, uncertainty=None, spectral_axis=None, spectral_dimension=0, mask=None, meta=None, **kwargs
+    ):
         # If the flux (data) argument is already a Spectrum (as it would
         # be for internal arithmetic operations), avoid setup entirely.
         if isinstance(data, Spectrum):
@@ -193,22 +185,21 @@ def __init__(self, data, *, uncertainty=None, spectral_axis=None,
 
         # Ensure that the unit information codified in the quantity object is
         # the One True Unit.
-        kwargs.setdefault('unit', data.unit if isinstance(data, u.Quantity)
-                          else kwargs.get('unit'))
+        kwargs.setdefault("unit", data.unit if isinstance(data, u.Quantity) else kwargs.get("unit"))
 
         # If flux and spectral axis are both specified, check that their lengths
         # match or are off by one (implying the spectral axis stores bin edges)
         if data is not None and spectral_axis is not None:
             if spectral_axis.shape[0] == data.shape[spectral_dimension]:
                 bin_specification = "centers"
-            elif spectral_axis.shape[0] == data.shape[spectral_dimension]+1:
+            elif spectral_axis.shape[0] == data.shape[spectral_dimension] + 1:
                 bin_specification = "edges"
             else:
                 raise ValueError(
-                    "Spectral axis length ({}) must be the same size or one "
+                    f"Spectral axis length ({spectral_axis.shape[0]}) must be the same size or one "
                     "greater (if specifying bin edges) than that of the flux "
-                    "axis ({})".format(spectral_axis.shape[0],
-                                       data.shape[spectral_dimension]))
+                    f"axis ({data.shape[spectral_dimension]})"
+                )
 
         # Attempt to parse the spectral axis. If none is given, try instead to
         # parse a given wcs. This is put into a GWCS object to
         if spectral_axis is not None:
             # Ensure that the spectral axis is an astropy Quantity
             if not isinstance(spectral_axis, u.Quantity):
-                raise ValueError("Spectral axis must be a `Quantity` or "
-                                 "`SpectralAxis` object.")
+                raise ValueError("Spectral axis must be a `Quantity` or " "`SpectralAxis` object.")
 
             # If a spectral axis is provided as an astropy Quantity, convert it
             # to a SpectralAxis object.
@@ -226,12 +216,15 @@ def __init__(self, data, *, uncertainty=None, spectral_axis=None,
                 bin_specification = "edges"
             else:
                 bin_specification = "centers"
-            self._spectral_axis = SpectralAxis(
-                spectral_axis,
-                bin_specification=bin_specification)
+            self._spectral_axis = SpectralAxis(spectral_axis, bin_specification=bin_specification)
 
             wcs = gwcs_from_array(self._spectral_axis)
 
         super().__init__(
             data=data.value if isinstance(data, u.Quantity) else data,
-            wcs=wcs, mask=mask, meta=meta, uncertainty=uncertainty, **kwargs)
+            wcs=wcs,
+            mask=mask,
+            meta=meta,
+            uncertainty=uncertainty,
+            **kwargs,
+        )
diff --git a/sunkit_spex/spectrum/tests/test_spectrum.py b/sunkit_spex/spectrum/tests/test_spectrum.py
index 837208a3..1763bb05 100644
--- a/sunkit_spex/spectrum/tests/test_spectrum.py
+++ b/sunkit_spex/spectrum/tests/test_spectrum.py
@@ -7,10 +7,10 @@
 
 
 def test_spectrum_bin_edges():
-    spec = Spectrum(np.arange(1, 11)*u.watt, spectral_axis=np.arange(1, 12)*u.keV)
+    spec = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=np.arange(1, 12) * u.keV)
     assert_array_equal(spec._spectral_axis, [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5] * u.keV)
 
 
 def test_spectrum_bin_centers():
-    spec = Spectrum(np.arange(1, 11)*u.watt, spectral_axis=(np.arange(1, 11) - 0.5) * u.keV)
+    spec = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=(np.arange(1, 11) - 0.5) * u.keV)
     assert_array_equal(spec._spectral_axis, [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5] * u.keV)
diff --git a/sunkit_spex/spectrum/tests/test_uncertaintiy.py b/sunkit_spex/spectrum/tests/test_uncertaintiy.py
index 714c1ca6..b6c19270 100644
--- a/sunkit_spex/spectrum/tests/test_uncertaintiy.py
+++ b/sunkit_spex/spectrum/tests/test_uncertaintiy.py
@@ -12,8 +12,8 @@ def test_add():
     a = NDDataRef(data, uncertainty=PoissonUncertainty(uncert))
     b = NDDataRef(data.copy(), uncertainty=PoissonUncertainty(uncert.copy()))
     aplusb = a.add(b)
-    assert_array_equal(aplusb.data, 2*data)
-    assert_array_equal(aplusb.uncertainty.array, np.sqrt(2*uncert ** 2))
+    assert_array_equal(aplusb.data, 2 * data)
+    assert_array_equal(aplusb.uncertainty.array, np.sqrt(2 * uncert**2))
 
 
 def test_subtract():
@@ -22,8 +22,8 @@ def test_subtract():
     a = NDDataRef(data, uncertainty=PoissonUncertainty(uncert))
     b = NDDataRef(data.copy(), uncertainty=PoissonUncertainty(uncert.copy()))
     aminusb = a.subtract(b)
-    assert_array_equal(aminusb.data, data-data)
-    assert_array_equal(aminusb.uncertainty.array, np.sqrt(2*uncert ** 2))
+    assert_array_equal(aminusb.data, data - data)
+    assert_array_equal(aminusb.uncertainty.array, np.sqrt(2 * uncert**2))
 
 
 def test_multiply():
@@ -33,7 +33,7 @@ def test_multiply():
     b = NDDataRef(data.copy(), uncertainty=PoissonUncertainty(uncert.copy()))
     atimesb = a.multiply(b)
     assert_array_equal(atimesb.data, data**2)
-    assert_array_equal(atimesb.uncertainty.array, np.sqrt((2*data**2*uncert**2)))  # (b**2*da**2 + a**2*db**2)**0.5
+    assert_array_equal(atimesb.uncertainty.array, np.sqrt(2 * data**2 * uncert**2))  # (b**2*da**2 + a**2*db**2)**0.5
 
 
 def test_divide():
@@ -42,5 +42,7 @@ def test_divide():
     a = NDDataRef(data, uncertainty=PoissonUncertainty(uncert))
     b = NDDataRef(data.copy(), uncertainty=PoissonUncertainty(uncert.copy()))
     adivb = a.divide(b)
-    assert_array_equal(adivb.data, data/data)
-    assert_array_equal(adivb.uncertainty.array, np.sqrt(((1/data)**2 * uncert**2)*2))  # ((1/b)**2*da**2 + (a/b**2)**2*db**2)**0.5
+    assert_array_equal(adivb.data, data / data)
+    assert_array_equal(
+        adivb.uncertainty.array, np.sqrt(((1 / data) ** 2 * uncert**2) * 2)
+    )  # ((1/b)**2*da**2 + (a/b**2)**2*db**2)**0.5
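
A minimal usage sketch of the propagation rules exercised by the tests above
(illustrative only; it assumes both patches are applied so that
sunkit_spex.spectrum.uncertainty.PoissonUncertainty is importable):

    import numpy as np
    from astropy.nddata import NDDataRef
    from sunkit_spex.spectrum.uncertainty import PoissonUncertainty

    # For Poisson-distributed counts the one-sigma uncertainty is sqrt(counts).
    counts = np.array([100.0, 400.0, 900.0])
    spec = NDDataRef(counts, uncertainty=PoissonUncertainty(np.sqrt(counts)))

    # NDDataRef arithmetic delegates to PoissonUncertainty._propagate_add;
    # with the default zero correlation, errors add in quadrature:
    # d(a + b) = sqrt(da**2 + db**2) = sqrt(2 * counts) here.
    total = spec.add(spec)
    assert np.allclose(total.uncertainty.array, np.sqrt(2 * counts))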