diff --git a/roughness/helpers.py b/roughness/helpers.py index c9a0bd2..b41dd6b 100644 --- a/roughness/helpers.py +++ b/roughness/helpers.py @@ -13,13 +13,14 @@ def lookup2xarray(lookups, rms_coords=None, inc_coords=None, nodata=-999): """ Convert list of default lookups to xarray.DataSet. - Parameters - ---------- - lookups (array): Lookup tables (e.g. from make_los_table). + Parameters: + lookups (array): Lookup tables (e.g. from make_los_table). + rms_coords (array): RMS slope coordinates (default: None). + inc_coords (array): Incidence angle coordinates (default: None). + nodata (float): Value to replace NaNs with (default: -999). - Returns - ------- - xarray.DataArray: xarray.DataArray with labeled dims and coords. + Returns: + xarray.DataArray: xarray.DataArray with labeled dims and coords. """ names = cfg.LUT_NAMES longnames = cfg.LUT_LONGNAMES @@ -45,18 +46,16 @@ def np2xr(arr, dims=None, coords=None, name=None, cnames=None, cunits=None): """ Convert numpy array to xarray.DataArray. - Parameters - ---------- - array (np.array): Numpy array to convert to xarray.DataArray. - dims (list of str): Dimensions name of each param (default: index order). - coords (list of arr): Coordinate arrays of each dim. - name (str): Name of xarray.DataArray (default: None) - cnames (list of str): Coordinate names of each param. - cunits (list of str): Coordinate units of each param (default: deg). + Parameters: + array (np.array): Numpy array to convert to xarray.DataArray. + dims (list of str): Dimension names (default: cfg.LUT_DIMS). + coords (list of arr): Coordinate arrays of each dim. + name (str): Name of xarray.DataArray (default: None) + cnames (list of str): Coordinate names of each param. + cunits (list of str): Coordinate units of each param (default: deg). - Returns - ------- - xarray.DataArray + Returns: + xarray.DataArray """ ndims = len(arr.shape) if dims is None: @@ -80,7 +79,16 @@ def np2xr(arr, dims=None, coords=None, name=None, cnames=None, cunits=None): def wl2xr(arr, units="μm"): - """Return wavelength numpy array as xarray.""" + """ + Return wavelength numpy array as xarray. + + Parameters: + arr (array): Wavelength array. + units (str): Units of wavelengths (default: μm). + + Returns: + xarray.DataArray + """ da = xr.DataArray(arr, coords=[("wavelength", arr)]) da.coords["wavelength"].attrs["long_name"] = "Wavelength" da.coords["wavelength"].attrs["units"] = units @@ -88,7 +96,16 @@ def wl2xr(arr, units="μm"): def wn2xr(arr, units="cm^-1"): - """Return wavenumber numpy array as xarray.""" + """ + Return wavenumber numpy array as xarray. + + Parameters: + arr (array): Wavenumber array. + units (str): Units of wavenumbers (default: cm^-1). + + Returns: + xarray.DataArray + """ da = xr.DataArray(arr, coords=[("wavenumber", arr)]) da.coords["wavelength"].attrs["long_name"] = "Wavenumber" da.coords["wavelength"].attrs["units"] = units @@ -96,7 +113,18 @@ def wn2xr(arr, units="cm^-1"): def spec2xr(arr, wls, units="W/m²/sr/μm", wl_units="μm"): - """Return spectral numpy array as xarray.""" + """ + Return spectral numpy array as xarray. + + Parameters: + arr (array): Spectral array. + wls (array): Wavelength array. + units (str): Units of spectral array (default: W/m²/sr/μm). + wl_units (str): Units of wavelength array (default: μm). + + Returns: + xarray.DataArray + """ da = xr.DataArray(arr, coords=[wls], dims=["wavelength"], name="Radiance") da.attrs["units"] = units da.coords["wavelength"].attrs["long_name"] = "Wavelength" @@ -108,16 +136,14 @@ def get_lookup_coords(nrms=10, ninc=10, naz=36, ntheta=45): """ Return lookup table coordinate arrays - Parameters - ---------- - nrms (int): Number of RMS slopes in [0, 50] degrees. - ninc (int): Number of incidence angles in [0, 90] degrees. - naz (int): Number of facet azimuth bins in [0, 360] degrees. - ntheta (int): Number of facet slope bins in [0, 90] degrees. + Parameters: + nrms (int): Number of RMS slopes in [0, 50] degrees. + ninc (int): Number of incidence angles in [0, 90] degrees. + naz (int): Number of facet azimuth bins in [0, 360] degrees. + ntheta (int): Number of facet slope bins in [0, 90] degrees. - Returns - ------- - lookup_coords (list of array): Coordinate arrays (rms, cinc, az, theta) + Returns: + lookup_coords (list of array): Coord arrays (rms, cinc, az, theta) """ rmss, incs, azs, thetas = get_lookup_bins(nrms, ninc, naz, ntheta) # Get slope and az bin centers @@ -136,16 +162,15 @@ def get_lookup_bins(nrms=10, ninc=10, naz=36, ntheta=45): az: [0, 360] degrees theta: [0, 90] degrees - Parameters - ---------- - nrms (int): Number of RMS slopes in [0, 50] degrees. - ninc (int): Number of incidence angles in [0, 90] degrees. - naz (int): Number of facet azimuth bins in [0, 360] degrees. - ntheta (int): Number of facet slope bins in [0, 90] degrees. + Parameters: + nrms (int): Number of RMS slopes in [0, 50] degrees. + ninc (int): Number of incidence angles in [0, 90] degrees. + naz (int): Number of facet azimuth bins in [0, 360] degrees. + ntheta (int): Number of facet slope bins in [0, 90] degrees. Return ------ - lookup_coords (tuple of array): Coordinate arrays (rms, cinc, az, theta) + lookup_coords (tuple of array): Coordinate bins (rms, cinc, az, theta) """ rms_coords = np.linspace(0, 45, nrms) cinc_coords = np.linspace(1, 0, ninc) # cos([0, 90] degrees) @@ -163,14 +188,13 @@ def facet_grids(los_table, units="degrees"): az: [0, 360] degrees theta: [0, 90] degrees - Parameters - ---------- - los_table (arr): Line of sight table (dims: az, theta) - units (str): Return grids in specified units ('degrees' or 'radians') + Parameters: + los_table (arr): Line of sight table (dims: az, theta) + units (str): Return grids in specified units ('degrees' or 'radians') Return ------ - thetas, azs (tuple of 2D array): Coordinate grids of facet slope and az + thetas, azs (tuple of 2D array): Coord grids of facet slope and az """ if isinstance(los_table, xr.DataArray): az_arr = los_table.az.values @@ -208,14 +232,13 @@ def get_facet_bins(naz=36, ntheta=45): az: [0, 360] degrees theta: [0, 90] degrees - Parameters - ---------- - los_table (array): Line of sight table (dims: az, theta) - units (str): Return grids in specified units ('degrees' or 'radians') + Parameters: + los_table (array): Line of sight table (dims: az, theta) + units (str): Return grids in specified units ('degrees' or 'radians') Return ------ - thetas, azs (tuple of 2D array): Bin edges of facet slope and az + thetas, azs (tuple of 2D array): Bin edges of facet slope and az """ azim_coords = np.linspace(0, 360, naz + 1) slope_coords = np.linspace(0, 90, ntheta + 1) @@ -234,14 +257,12 @@ def fname_with_demsize(filename, demsize): """ Return filename with demsize appended to the end. - Parameters - ---------- - fname (str or Path): Filename to append to. - demsize (int): Length of dem in pixels. + Parameters: + fname (str or Path): Filename to append to. + demsize (int): Length of dem in pixels. - Returns - ------- - fname_with_demsize (str): New filename with new demsize appended. + Returns: + fname_with_demsize (str): New filename with new demsize appended. """ filename = Path(filename) return filename.with_name(f"{filename.stem}_s{demsize}{filename.suffix}") @@ -352,27 +373,23 @@ def build_jupyter_notebooks(nbpath=cfg.EXAMPLES_DIR): # Geometry helpers -def get_surf_geometry(ground_zen, ground_az, sun_zen, sun_az, sc_zen, sc_az): - """ - Return local i, e, g, azimuth give input viewing geometry assuming all - input zeniths and azimuths are in radians. - """ - ground = sph2cart(ground_zen, ground_az) - sun = sph2cart(sun_zen, sun_az) - sc = sph2cart(sc_zen, sc_az) - inc = get_angle_between(ground, sun) - em = get_angle_between(ground, sc) - phase = get_angle_between(sun, sc) - az = get_local_az(ground, sun, sc) - return (inc, em, phase, az) - - def get_ieg(ground_zen, ground_az, sun_zen, sun_az, sc_zen, sc_az): """ - Return local solar incidence, spacecraft emission and phase angle (i, e, g) - and azimuth given input viewing geometry. + Return local solar incidence, spacecraft emergence and phase angles + (i, e, g) and azimuth given input viewing geometry. Input zeniths and azimuths in radians. + + Parameters: + ground_zen (float): Ground zenith angle [rad]. + ground_az (float): Ground azimuth angle [rad]. + sun_zen (float): Sun zenith angle [rad]. + sun_az (float): Sun azimuth angle [rad]. + sc_zen (float): Spacecraft zenith angle [rad]. + sc_az (float): Spacecraft azimuth angle [rad]. + + Returns: + (i, e, g, az) (tuple of float): Geometry relative to ground. """ ground = sph2cart(ground_zen, ground_az) sun = sph2cart(sun_zen, sun_az) @@ -386,10 +403,21 @@ def get_ieg(ground_zen, ground_az, sun_zen, sun_az, sc_zen, sc_az): def get_ieg_xr(ground_zen, ground_az, sun_zen, sun_az, sc_zen, sc_az): """ - Return local solar incidence, spacecraft emission and phase angle (i, e, g) - and azimuth given input viewing geometry. + Return local solar incidence, spacecraft emergence and phase angles + (i, e, g) and azimuth given input viewing geometry as xarray.DataArray. Input zeniths and azimuths in radians. + + Parameters: + ground_zen (DataArray): Ground zenith angle [rad]. + ground_az (DataArray): Ground azimuth angle [rad]. + sun_zen (DataArray): Sun zenith angle [rad]. + sun_az (DataArray): Sun azimuth angle [rad]. + sc_zen (DataArray): Spacecraft zenith angle [rad]. + sc_az (DataArray): Spacecraft azimuth angle [rad]. + + Returns: + (i, e, g, az) (tuple of DataArray): Geometry relative to ground. """ ground = sph2cart_xr(ground_zen, ground_az) sun = sph2cart_xr(sun_zen, sun_az) @@ -406,6 +434,10 @@ def get_angle_between(vec1, vec2, safe_arccos=False): Return angle between cartesian vec1 and vec2 using dot product. If vec1 or vec2 are NxMx3, compute element-wise. + + Parameters: + vec1, vec2 (array): Cartesian vectors. + safe_arccos (bool): Clip dot product to [-1, 1] before arccos. """ if isinstance(vec1, xr.DataArray) or isinstance(vec2, xr.DataArray): return get_angle_between_xr(vec1, vec2, safe_arccos) @@ -417,7 +449,13 @@ def get_angle_between(vec1, vec2, safe_arccos=False): def get_angle_between_xr(vec1, vec2, safe_arccos=False): - """Return angle between cartesian vec1 and vec2 using dot product.""" + """ + Return angle between cartesian DataArrays using dot product. + + Parameters: + vec1, vec2 (DataArray): Cartesian vectors. + safe_arccos (bool): Clip dot product to [-1, 1] before arccos. + """ dot = vec1.dot(vec2, dims="xyz") if safe_arccos: # Restrict dot product to [-1, 1] to safely pass to arccos @@ -428,6 +466,10 @@ def get_angle_between_xr(vec1, vec2, safe_arccos=False): def get_azim(ground_sc_ground, ground_sun_ground): """ Return the azimuth arccos(dot product of the spacecraft and sun vectors) + + Parameters: + ground_sc_ground (array): Ground to spacecraft vector. + ground_sun_ground (array): Ground to sun vector. """ dot_azim = np.degrees( np.arccos(np.sum(ground_sc_ground * ground_sun_ground, axis=2)) @@ -445,6 +487,11 @@ def get_local_az(ground, sun, sc): Return azimuth angle of the spacecraft with respect to the sun and local slope. Assumes inputs are same size and shape and are in 3D Cartesian coords (ixjxk). + + Parameters: + ground (array): Ground vector. + sun (array): Sun vector. + sc (array): Spacecraft vector. """ sc_rel_ground = element_norm(element_triple_cross(ground, sc, ground)) sun_rel_ground = element_norm(element_triple_cross(ground, sun, ground)) @@ -454,9 +501,16 @@ def get_local_az(ground, sun, sc): def get_local_az_xr(ground, sun, sc): - """Return azimuth angle of the spacecraft with respect to the sun and local + """ + Return azimuth angle of the spacecraft with respect to the sun and local slope. Assumes inputs are same size and shape and are in 3D - Cartesian coords (ixjxk).""" + Cartesian coords (ixjxk). + + Parameters: + ground (DataArray): Ground vector. + sun (DataArray): Sun vector. + sc (DataArray): Spacecraft vector. + """ sc_rel_ground = inc_rel_ground_xr(sc, ground) sun_rel_ground = inc_rel_ground_xr(sun, ground) az = get_angle_between_xr(sc_rel_ground, sun_rel_ground, True) @@ -481,14 +535,10 @@ def inc_to_tloc(inc, az, lat): """ Convert solar incidence and az to decimal local time (in 6-18h). - Parameters - ---------- - inc: (float) - Solar incidence in degrees [0, 90) - az: (float) - Solar azimuth in degrees [0, 360) - lat: (float) - Latitude in degrees [-90, 90) + Parameters: + inc (float): Solar incidence in degrees [0, 90) + az (float): Solar azimuth in degrees [0, 360) + lat (float): Latitude in degrees [-90, 90) """ inc, lat = np.deg2rad(inc), np.deg2rad(lat) hr_angle = np.arccos(np.cos(inc) / np.cos(lat)) @@ -506,17 +556,14 @@ def tloc_to_inc(tloc, lat=0, az=False): """ Return the solar incidence angle given the local time in hrs. - Parameters - ---------- - tloc: (float) - Local time (0, 24) [hrs] - lat: (float) - Latitude (-90, 90) [deg] + Parameters: + tloc (float) Local time (0, 24) [hrs] + lat (float) Latitude (-90, 90) [deg] + az (bool) Return azimuth angle (default: False) - Return - ---------- - inc: (float) - Solar incidence (0, 90) [deg] + Returns: + inc (float) Solar incidence angle [deg] + az (float) Solar azimuth angle (if az=True) [deg] """ latr = np.deg2rad(lat) hr_angle = np.deg2rad(15 * (12 - tloc)) # (morning is +; afternoon is -) @@ -533,6 +580,15 @@ def get_inc_az_from_subsolar(lon, lat, sslon, sslat): Return inc and az angle at (lat, lon) given subsolar point (sslat, sslon). From calculator at https://the-moon.us/wiki/Sun_Angle + + Parameters: + lon (float): Longitude of point of interest [deg] + lat (float): Latitude of point of interest [deg] + sslon (float): Subsolar longitude [deg] + sslat (float): Subsolar latitude [deg] + + Returns: + inc, az (float): Solar incidence and azimuth angles [deg] """ lon, lat, sslon, sslat, dlon = ( np.deg2rad(lon), @@ -554,18 +610,6 @@ def get_inc_az_from_subsolar(lon, lat, sslon, sslat): # Linear algebra -# def element_az_elev(v1, v2): -# """ -# Return azimuth and elevation of v2 relative to v1. -# -# untested -# """ -# v = v2 - v1 -# az = np.degrees(np.arctan2(v[:, :, 0], v[:, :, 1])) -# elev = np.degrees(np.arctan2(v[:, :, 2], np.sqrt(v[:, :, 0]** 2 + v[:, :, 1]**2))) -# return az, elev - - def as_cart3D(vecs): """Return list of vecs as shape (N, M, 3) if they are not already.""" for i, vec in enumerate(vecs): @@ -577,6 +621,12 @@ def as_cart3D(vecs): def element_cross(A, B): """ Return element-wise cross product of two 3D arrays in cartesian coords. + + Parameters: + A, B (array): 3D arrays of vectors in cartesian coords. + + Returns: + out (array): 3D array of vector cross products in cartesian coords. """ A, B = as_cart3D([A, B]) out = np.zeros_like(A) @@ -587,20 +637,44 @@ def element_cross(A, B): def element_dot(A, B): - """Return element-wise dot product of two 3D arr in Cartesian coords.""" + """ + Return element-wise dot product of two 3D arrays in Cartesian coords. + + Parameters: + A, B (array): 3D arrays of vectors in cartesian coords. + + Returns: + out (array): 2D array of vector dot products. + """ A, B = as_cart3D([A, B]) return np.sum(A * B, axis=2) def element_norm(A): - """Return input array of vectors normalized to length 1.""" + """ + Return input array of vectors normalized to length 1. + + Parameters: + A (array): 3D array of vectors in cartesian coords. + + Returns: + out (array): 3D array of normalized vectors. + """ A = as_cart3D([A])[0] mag = np.sqrt(np.sum(A**2, axis=2)) return A / mag[:, :, np.newaxis] def element_triple_cross(A, B, C): - """Return element-wise triple cross product of three 3D arr in Cartesian""" + """ + Return element-wise triple cross product of three 3D arr in Cartesian + + Parameters: + A, B, C (array): 3D arrays of vectors in cartesian coords. + + Returns: + out (array): 3D array of triple cross products. + """ A, B, C = as_cart3D([A, B, C]) return ( B * (element_dot(A, C))[:, :, np.newaxis] @@ -613,15 +687,12 @@ def cart2pol(x, y): """ Convert ordered coordinate pairs from Cartesian (X,Y) to polar (r,theta). - Parameters - ---------- - X: X component of ordered Cartesian coordinate pair. - Y: Y component of ordered Cartesian coordinate pair. + Parameters: + x,y (array): Cartesian coordinates - Returns - ------- - r: Distance from origin. - theta: Angle, in radians. + Returns: + r (array): Distance from origin. + theta (array): Polar angle, [radians]. """ r = np.sqrt(x**2 + y**2) theta = np.arctan2(y, x) @@ -632,15 +703,12 @@ def pol2cart(rho, phi): """ Convert ordered coordinate pairs from polar (r,theta) to Cartesian (X,Y). - Parameters - ---------- - r: Distance from origin. - theta: Angle, in radians. + Parameters: + rho (array): Distance from origin. + phi (array): Polar angle, [radians]. - Returns - ------- - X: X component of ordered Cartesian coordinate pair. - Y: Y component of ordered Cartesian coordinate pair. + Returns: + x,y (array): Cartesian coordinates """ x = rho * np.cos(phi) y = rho * np.sin(phi) @@ -656,15 +724,13 @@ def sph2cart(theta, phi, radius=1): z-axis (e.g., if theta and phi are int/float return 1x1x3 vector; if theta and phi are MxN arrays return 3D array of vectors MxNx3). - Parameters - ---------- - theta (num or array): Polar angle [rad]. - phi (num or array): Azimuthal angle [rad]. - radius (num or array): Radius (default=1). + Parameters: + theta (num or array): Polar angle [rad]. + phi (num or array): Azimuthal angle [rad]. + radius (num or array): Radius (default=1). - Returns - ------- - cartesian (array): Cartesian vector(s), same shape as theta and phi. + Returns: + (array): Cartesian vector(s), same shape as theta and phi. """ if isinstance(theta, xr.DataArray): theta = theta.values @@ -680,7 +746,17 @@ def sph2cart(theta, phi, radius=1): def sph2cart_xr(theta, phi, radius=1): - """Convert spherical to cartesian coordinates using xarray.""" + """ + Convert spherical to cartesian coordinates using xarray. + + Parameters: + theta (DataArray): Polar angle [rad]. + phi (DataArray): Azimuthal angle [rad]. + radius (num or array): Radius (default=1). + + Returns: + (DataArray): Cartesian vector(s), same shape as theta and phi. + """ return xr.concat( [ radius * np.sin(theta) * np.cos(phi), @@ -695,17 +771,13 @@ def cart2sph(x, y, z): """ Convert from cartesian (x, y, z) to spherical (theta, phi, r). - Parameters - ---------- - x: X component of ordered Cartesian coordinate pair. - y: Y component of ordered Cartesian coordinate pair. - z: Z component of ordered Cartesian coordinate pair. + Parameters: + x, y, z (array): Cartesian coordinates. - Returns - ------- - theta: Polar angle [rad]. - phi: Azimuthal angle [rad]. - radius: Radius. + Returns: + theta (array): Polar angle [rad]. + phi (array): Azimuthal angle [rad]. + radius (array): Radius. """ r = np.sqrt(x**2 + y**2 + z**2) theta = np.arctan2(np.sqrt(x**2 + y**2), z) @@ -714,7 +786,16 @@ def cart2sph(x, y, z): def xy2lonlat_coords(x, y, extent): - """Convert x,y coordinates to lat,lon coordinates.""" + """ + Convert x,y coordinates to lat,lon coordinates. + + Parameters: + x, y (array): x,y coordinates. + extent (tuple): Lon/lat extent (lonmin, lonmax, latmin, latmax). + + Returns: + lon, lat (array): Lon/lat coordinates. + """ lon = np.linspace(extent[0], extent[1], len(x)) lat = np.linspace(extent[3], extent[2], len(y)) return lon, lat @@ -722,7 +803,16 @@ def xy2lonlat_coords(x, y, extent): # Image I/O def xarr2geotiff(xarr, savefile, crs=cfg.MOON2000_ESRI): - """Write 3D xarray (bands, y, x) to geotiff with rasterio.""" + """ + Write xarray image to geotiff with rasterio. + + xarr can must have dims (wavelength, lat, lon) or (lat, lon). + + Parameters: + xarr (xarray.DataArray): Image to save. + savefile (str): Filename to save to. + crs (str): Coordinate reference system (default: MOON2000_ESRI). + """ xarr = xarr.rio.write_crs(crs) if "wavelength" in xarr.dims: xarr = xarr.transpose("wavelength", "lat", "lon") diff --git a/roughness/plotting.py b/roughness/plotting.py index f01dc41..c4f1280 100755 --- a/roughness/plotting.py +++ b/roughness/plotting.py @@ -30,13 +30,13 @@ def plot_slope_az_table( """ Plot a 2D line-of-sight table with facet slope vs facet azimuth. - Args: + Parameters: table (ndarray): A 2D array of slope vs azimuth. cmap_r (bool, optional): Use a reverse colormap. Default is False. clabel (str, optional): Label for the colorbar. Default is an empty string. - ax (matplotlib.axes.Axes, optional): The axes to plot on. If not - provided, a new axes will be generated. + ax (matplotlib.axes.Axes, optional): Axes to plot on. If not provided, + a new Axes object will be generated. proj (str, optional): The projection to use. Valid values are 'polar' and None. Default is None. vmin (float, optional): The minimum value for the colorbar. Default is @@ -122,30 +122,14 @@ def m3_imshow( Use imshow to display an M3 image. Specify wavelength, wl, to show the closest channel to that wavelength. - Parameters - ---------- - img : ndarray - M3 image. - xmin : int - Minimum x index to plot. - xmax : int - Maximum x index to plot. - ymin : int - Minimum y index to plot. - ymax : int - Maximum y index to plot. - wl : int, optional - Wavelength to show. Default is 750. - ax : matplotlib.axes.Axes, optional - Axes to plot on. If not provided, a new axes will be generated. - title : str, optional - Plot title. Default is None. - cmap : str, optional - Colormap. Default is "gray". - - Returns - ------- - None + Parameters: + img (ndarray): M3 image. + xmin, xmax, ymin, ymax (int): The x and y index ranges to plot. + wl (int, optional): Wavelength to show, gets nearest (default: 750nm). + ax (matplotlib.axes.Axes, optional): Axes to plot on. If not provided, + a new Axes object will be generated. + title (str, optional): Plot title (default: '') + cmap (str, optional): Colormap (default: "gray"). """ band = 0 if len(img.shape) > 2 else None if ax is None: @@ -163,22 +147,13 @@ def m3_spec(img, fmt="-", wls=None, ax=None, title=None, **kwargs): """ Plot M3 spectrum. - Parameters - ---------- - img : ndarray - M3 image. - fmt : str, optional - Plot format. Default is '-' for line. - wls : ndarray, optional - Wavelengths to plot. Default is all. - ax : matplotlib.axes.Axes, optional - Axes to plot on. If not provided, a new axes will be generated. - title : str, optional - Plot title. Default is None. - - Returns - ------- - None + Parameters: + img (ndarray): M3 image. + fmt (str, optional): Plot format (default: '-' for line). + wls (ndarray, optional): Wavelengths to plot (default: all). + ax (matplotlib.axes.Axes, optional): Axes to plot on. If not provided, + a new Axes object will be generated. + title (str, optional): Plot title (default: '') """ if title is None: title = "M3 spectrum" @@ -204,38 +179,22 @@ def dem_imshow( ymax=None, band="slope", ax=None, - title=None, + title="", cmap="viridis", **kwargs, ): """ Use imshow to display a stacked DEM image (slope, azimuth, elevation). - Parameters - ---------- - dem : ndarray - A 3D stacked DEM (slope, azimuth, elevation). - xmin : int - Minimum x index to plot. - xmax : int - Maximum x index to plot. - ymin : int - Minimum y index to plot. - ymax : int - Maximum y index to plot. - band : str or tuple - The band to display. Valid values are 'slope', 'azim', 'elev', or a - tuple of (0, 1, 2) respectively. - ax : matplotlib.axes.Axes, optional - Axes to plot on. If not provided, a new axes will be generated. - title : str, optional - Plot title. Default is None. - cmap : str, optional - Colormap. Default is "gray". - - Returns - ------- - None + Parameters: + dem (ndarray): A 3D stacked DEM (slope, azimuth, elevation). + xmin, xmax, ymin, ymax (int): The x and y index ranges to plot. + band (str or tuple): The band to display. Valid values are 'slope', + 'azim', 'elev', or a tuple of (0, 1, 2) respectively. + ax (matplotlib.axes.Axes, optional): Axes to plot on. If not provided, + a new Axes object will be generated. + title (str, optional): Plot title (default: '') + cmap (str, optional): Colormap (default: "gray"). """ zmap = {"slope": 0, "azim": 1, "elev": 2, 0: "slope", 1: "azim", 2: "elev"} if len(dem.shape) > 2 and dem.shape[2] > 1: @@ -245,10 +204,10 @@ def dem_imshow( zind = band else: raise ValueError(f"Unknown band {band} specified.") + if title is None: + title = f"DEM {zmap[zind]}" else: zind = None - if title is None: - title = f"DEM {zmap[zind]}" if ax is None: _, ax = plt.subplots() p = ax.imshow(dem[xmin:xmax, ymin:ymax, zind].T, cmap=cmap, **kwargs)