Skip to content

Commit

Permalink
Merge pull request #53 from coecms/fix_issue_49
Browse files Browse the repository at this point in the history
Fix issue 49
  • Loading branch information
paolap committed Dec 6, 2022
2 parents 4c33dff + 91b4ff5 commit 4a7b67b
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 67 deletions.
2 changes: 1 addition & 1 deletion conda/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{% set version = "0.8.0" %}
{% set version = "0.9.0" %}
package:
name: xmhw
version: {{ version }}
Expand Down
12 changes: 6 additions & 6 deletions test/test_identify.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ def test_define_events(define_data, mhw_data, inter_data):
mhwds = mhw_data
interds = inter_data
res = define_events(
ts.isel(cell=0),
th.isel(cell=0),
se.isel(cell=0),
ts,
th,
se,
idxarr,
5,
True,
Expand All @@ -176,9 +176,9 @@ def test_define_events(define_data, mhw_data, inter_data):

# test define_events returns one dataset only if intermediate is False, as default
res = define_events(
ts.isel(cell=0),
th.isel(cell=0),
se.isel(cell=0),
ts,
th,
se,
idxarr,
5,
True,
Expand Down
24 changes: 14 additions & 10 deletions test/xmhw_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,28 +189,28 @@ def define_data():
doy = xr.DataArray(
data=[1, 2, 3, 4, 5, 6, 7, 8, 9], dims=["time"], coords={"time": time}
)
lat = 45.5
lon = 123.4
ts = xr.DataArray(
data=[15.6, 17.3, 18.2, 19.5, 19.4, 19.6, 18.1, 17.0, 15.2],
dims=["time"],
coords={"time": time, "doy": doy},
coords={"time": time, "doy": doy, "lat": lat, "lon": lon},
)
# dims=['doy'], coords={'doy': doy, 'quantile':0.9})
se = xr.DataArray(
data=[15.8, 16.0, 16.2, 16.5, 16.6, 16.4, 16.6, 16.7, 16.4],
dims=["doy"],
coords={"doy": ts["doy"].values},
coords={"doy": ts["doy"].values,
"lat": lat, "lon": lon},
)
th = xr.DataArray(
[16.0, 16.7, 17.6, 17.9, 18.1, 18.2, 17.3, 17.2, 17.0],
dims=["doy"],
coords={"doy": ts["doy"].values, "quantile": 0.9},
coords={"doy": ts["doy"].values, "lat":lat, "lon": lon,
"quantile": 0.9},
)
ts = ts.expand_dims(["lat", "lon"])
ts = ts.stack(cell=(["lat", "lon"]))
se = se.expand_dims(["lat", "lon"])
th = th.expand_dims(["lat", "lon"])
se = se.stack(cell=(["lat", "lon"]))
th = th.stack(cell=(["lat", "lon"]))
ts = ts.stack(cell=(["lat", "lon"]), create_index=False)
se = se.stack(cell=(["lat", "lon"]), create_index=False)
th = th.stack(cell=(["lat", "lon"]), create_index=False)
# Build a pandas series with the positional indexes as values
# [0,1,2,3,4,5,6,7,8,9,10,..]
idxarr = pd.Series(data=np.arange(9), index=ts.time.values)
Expand Down Expand Up @@ -253,6 +253,8 @@ def mhw_data():
"duration": [6.0],
"rate_onset": [0.5888889],
"rate_decline": [1.5333333],
"lat": [45.5],
"lon": [123.4],
}
for k, v in vars_dict.items():
mhwds[k] = xr.DataArray(
Expand All @@ -267,6 +269,8 @@ def inter_data():
index = pd.date_range("2001-01-01", periods=9)
ids = xr.Dataset(coords={"index": index})
vars_dict = {
"lat": [45.5 for i in range(9)],
"lon": [123.4 for i in range(9)],
"ts": [15.6, 17.3, 18.2, 19.5, 19.4, 19.6, 18.1, 17.0, 15.2],
"seas": [np.nan, 16.0, 16.2, 16.5, 16.6, 16.4, 16.6, np.nan, np.nan],
"thresh": [np.nan, 16.7, 17.6, 17.9, 18.1, 18.2, 17.3, np.nan, np.nan],
Expand Down
45 changes: 26 additions & 19 deletions xmhw/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def mhw_df(df):
return df


def mhw_features(dftime, last):
def mhw_features(dftime, last, tdim, dims):
"""Calculate mhw properties, grouping by each event.
Parameters
Expand All @@ -86,15 +86,15 @@ def mhw_features(dftime, last):
"""

# calculate some of the mhw properties aggregating by events
df = agg_df(dftime)
df = agg_df(dftime, tdim, dims)
# calculate the rest of the mhw properties
df = properties(df, dftime.relThresh, dftime.mabs)
df = properties(df, dftime.loc[:,'relThresh'], dftime.loc[:,'mabs'])
# calculate onset decline rates
df = onset_decline(df, last)
return df


def agg_df(df):
def agg_df(df, tdim, dims):
"""Groupby events and apply different functions depending on attribute.
Parameters
Expand All @@ -111,12 +111,12 @@ def agg_df(df):
"""

# using an aggregation dictionary to avoid apply.
dfout = df.groupby("events").agg(
dfout = df.groupby("events", group_keys=True).agg(
event=("events", "first"),
index_start=("start", "first"),
index_end=("end", "first"),
time_start=("time", "first"),
time_end=("time", "last"),
time_start=(tdim, "first"),
time_end=(tdim, "last"),
relS_imax=("relSeas", np.argmax),
# time as dataframe index, instead
# of the timeseries index
Expand Down Expand Up @@ -149,7 +149,12 @@ def agg_df(df):
duration_strong=("duration_strong", "sum"),
duration_severe=("duration_severe", "sum"),
duration_extreme=("duration_extreme", "sum"),
)
)
# adding dimensions used in stacked cell to recreate cell later
# sending values to list to avoid warnings
for d in dims:
val = df[d].to_list()
dfout.loc[:,d] = val[0]
return dfout


Expand All @@ -172,18 +177,20 @@ def properties(df, relT, mabs):
As input but with more MHW properties added
"""

df["index_peak"] = df.event + df.relS_imax
df["intensity_var"] = np.sqrt(df.relS_var)
df["severity_var"] = np.sqrt(df.severity_var)
df["intensity_max_relThresh"] = relT[df.time_peak].values
df["intensity_max_abs"] = mabs[df.time_peak].values
df["intensity_var_relThresh"] = np.sqrt(df.relT_var)
df["intensity_var_abs"] = np.sqrt(df.mabs_var)
df["category"] = np.minimum(df.cats_max, 4)
df["duration"] = df.index_end - df.index_start + 1
df = df.drop(["relS_imax", "relS_var", "relT_var", "cats_max", "mabs_var"],
df2 = df.copy()
df2['index_peak'] = df.event + df.relS_imax
df2['intensity_var'] = np.sqrt(df.relS_var)
df2['severity_var'] = np.sqrt(df.severity_var)
df2['intensity_max_relThresh'] = relT[df.time_peak].values
df2['intensity_max_abs'] = mabs[df.time_peak].values
df2['intensity_var_relThresh'] = np.sqrt(df.relT_var)
df2['intensity_var_abs'] = np.sqrt(df.mabs_var)
df2['category'] = np.minimum(df.cats_max, 4)
df2['duration'] = df.index_end - df.index_start + 1
del df
df2 = df2.drop(['relS_imax', 'relS_var', 'relT_var', 'cats_max', 'mabs_var'],
axis=1)
return df
return df2


def get_rate(relSeas_peak, relSeas_edge, period):
Expand Down
37 changes: 21 additions & 16 deletions xmhw/identify.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def join_gaps(st, end, events, maxGap):

@dask.delayed(nout=2)
def define_events(ts, th, se, idxarr, minDuration, joinGaps, maxGap,
intermediate):
intermediate, tdim="time"):
"""Finds all MHW events of duration >= minDuration and calculate
their properties.
Expand Down Expand Up @@ -373,10 +373,20 @@ def define_events(ts, th, se, idxarr, minDuration, joinGaps, maxGap,
bthresh = ts > thresh
ds = xr.Dataset({"ts": ts, "seas": seas, "thresh": thresh,
"bthresh": bthresh})

# Convert xarray dataset to pandas dataframe, as groupby operations
# are faster in pandas
df = ds.to_dataframe()
dims = list(ds.coords)
# create a list of potential fields to remove from dims now and df later
fields_toremove = ["doy",
"time",
"start",
"end",
"anom_plus",
"anom_minus",
"quantile",
]
[dims.remove(x) for x in fields_toremove if x in dims]
del ds

# detect events
Expand All @@ -387,25 +397,17 @@ def define_events(ts, th, se, idxarr, minDuration, joinGaps, maxGap,
del dfev

# Calculate mhw properties, for each event using groupby
dfmhw = mhw_features(df, len(idxarr) - 1)
dfmhw = mhw_features(df, len(idxarr) - 1, tdim, dims)

# Convert back to xarray dataset
mhw = xr.Dataset.from_dataframe(dfmhw, sparse=False)
del dfmhw
mhw_inter = None
current_columns = df.columns.to_list()
toremove = [x for x in fields_toremove if x in current_columns]

if intermediate:
df = df.drop(
columns=[
"doy",
"cell",
"time",
"start",
"end",
"anom_plus",
"anom_minus",
"quantile",
]
)
df = df.drop(columns=toremove)
mhw_inter = xr.Dataset.from_dataframe(df, sparse=False)
del df
return mhw, mhw_inter
Expand Down Expand Up @@ -510,7 +512,10 @@ def land_check(temp, tdim="time", anynans=False):
for d in dims:
if len(temp[d]) == 0:
raise XmhwException(f"Dimension {d} has 0 lenght, exiting")
ts = temp.stack(cell=(dims))
# removing multi-index creation, as this was disappearing during percentile and mean operation anyway,
# and potentially slows down calculation
# adding sorted to be consistent if applying function to different arrays with same dimensions
ts = temp.stack(cell=(sorted(dims)), create_index=False)
# drop cells that have all/any nan values along time
how = "all"
if anynans:
Expand Down
7 changes: 5 additions & 2 deletions xmhw/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,13 @@ def check_coordinates(dstime):
# find name of time dimension
# If there is already a stacked dimension skip land_check
check = True
ds_coords = list(dstime.coords)
# now that we use a stacked array without creating an index
# the stacked coord is a dimension without coordinates
# and its type is int64, no longer object
ds_coords = list(dstime.dims)
for x in ds_coords:
dtype = str(dstime[x].dtype)
if dtype == "object":
if dtype == "int64":
stack_coord = x
check = False
elif "datetime" in dtype:
Expand Down
39 changes: 26 additions & 13 deletions xmhw/xmhw.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,20 @@ def threshold(

# Concatenate results and save as dataset
ds = xr.Dataset()
thresh_results = [r[0] for r in results[0]]
ds["thresh"] = xr.concat(thresh_results, dim=ts.cell)
#thresh_results = [r[0] for r in results[0]]
# apply temporary fix suggested by @bjnmr issue #49
# as newer versions of xarray remove coords when calculating quantile but not for mean
# as I removed the multiindex I'm passing directly r[1].coords and not r[1]['cell'].coords
# this causes issues when trying to concatenate
thresh_results = [r[0].assign_coords(r[1].coords) for r in results[0]]
ds["thresh"] = xr.concat(thresh_results, dim='cell')
ds.thresh.name = "threshold"
seas_results = [r[1] for r in results[0]]
ds["seas"] = xr.concat(seas_results, dim=ts.cell)
ds["seas"] = xr.concat(seas_results, dim='cell')
ds.seas.name = "seasonal"
ds = ds.unstack("cell")
dims = [k for k in ts.cell.coords.keys()]
ds = ds.set_xindex(dims)
ds = ds.unstack(dim='cell')

# add previously saved attributes to ds
ds = annotate_ds(ds, ds_attrs, "clim")
Expand Down Expand Up @@ -338,7 +345,7 @@ def detect(
+ " be smaller than event minimum duration"
)
# if time dimension different from time, rename it
temp = temp.rename({tdim: "time"})
#temp = temp.rename({tdim: "time"})
# save original attributes in a dictionary to assign to final dataset
ds_attrs = {}
ds_attrs["ts"] = temp.attrs
Expand All @@ -348,7 +355,7 @@ def detect(

# Returns an array stacked on all dimensions excluded time, doy
# Land cells are removed and new dimensions are (time,cell)
ts = land_check(temp, anynans=anynans)
ts = land_check(temp, tdim=tdim, anynans=anynans)
del temp
th = land_check(th, tdim="doy", anynans=anynans)
se = land_check(se, tdim="doy", anynans=anynans)
Expand All @@ -366,7 +373,7 @@ def detect(

# Build a pandas series with the positional indexes as values
# [0,1,2,3,4,5,6,7,8,9,10,..]
idxarr = pd.Series(data=np.arange(len(ts.time)), index=ts.time.values)
idxarr = pd.Series(data=np.arange(len(ts[tdim])), index=ts[tdim].values)

# Loop over each cell to detect MHW events, define_events()
# is delayed, so loop is automatically run in parallel
Expand All @@ -382,18 +389,24 @@ def detect(
joinGaps,
maxGap,
intermediate,
tdim,
)
)
results = dask.compute(mhwls)

# Concatenate results and save as dataset
mhw_results = [r[0] for r in results[0]]
mhw = xr.concat(mhw_results, dim=ts.cell)
mhw = mhw.unstack("cell")
# re-assign dimensions previously used to stack arrays
dims = list(ts.cell.coords)
mhw_results = [r[0].assign_coords({d: r[0][d][0].values for d in dims}) for r in results[0]]
mhw = xr.concat(mhw_results, dim='cell')
mhw = mhw.set_xindex(dims)
mhw = mhw.unstack(dim='cell')
if intermediate:
inter_results = [r[1] for r in results[0]]
mhw_inter = xr.concat(inter_results, dim=ts.cell).unstack("cell")
mhw_inter = mhw_inter.rename({"index": "time"})
inter_results = [r[1].assign_coords({d: r[1][d][0].values for d in dims}) for r in results[0]]
mhw_inter = xr.concat(inter_results, dim='cell')
mhw_inter = mhw_inter.set_xindex(dims)
mhw_inter = mhw_inter.unstack('cell')
mhw_inter = mhw_inter.rename({'index': 'time'})
mhw_inter = mhw_inter.squeeze(drop=True)
# if point dimension was added in land_check remove
mhw = mhw.squeeze(drop=True)
Expand Down

0 comments on commit 4a7b67b

Please sign in to comment.