Skip to content

Commit

Permalink
Merge pull request #59 from leap-stc/jerrydev
Browse files Browse the repository at this point in the history
Mostly cleanup from Monday + some edits to data_utils.py
  • Loading branch information
jerrylin96 committed Aug 25, 2023
2 parents dcced39 + 679557e commit 4ec2765
Show file tree
Hide file tree
Showing 7 changed files with 404 additions and 424 deletions.
93 changes: 71 additions & 22 deletions climsim_utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,17 @@ def find_keys(dictionary, value):
self.hyam = self.grid_info['hyam'].values
self.hybm = self.grid_info['hybm'].values
self.p0 = 1e5 # code assumes this will always be a scalar
self.pressure_grid = None
self.dp = None

self.pressure_grid_train = None
self.pressure_grid_val = None
self.pressure_grid_scoring = None
self.pressure_grid_test = None

self.dp_train = None
self.dp_val = None
self.dp_scoring = None
self.dp_test = None

self.train_regexps = None
self.train_stride_sample = None
self.train_filelist = None
Expand Down Expand Up @@ -454,18 +463,49 @@ def load_h5_file(load_path = ''):
pred = np.array(hf.get('pred'))
return pred

def set_pressure_grid(self, input_arr):
def set_pressure_grid(self, data_split):
'''
This function sets the pressure weighting for metrics.
'''
state_ps = input_arr[:,120]*(self.input_max['state_ps'].values - self.input_min['state_ps'].values) + self.input_mean['state_ps'].values
state_ps = np.reshape(state_ps, (-1, self.latlonnum))
pressure_grid_p1 = np.array(self.grid_info['P0']*self.grid_info['hyai'])[:,np.newaxis,np.newaxis]
pressure_grid_p2 = self.grid_info['hybi'].values[:, np.newaxis, np.newaxis] * state_ps[np.newaxis, :, :]
self.pressure_grid = pressure_grid_p1 + pressure_grid_p2
self.dp = self.pressure_grid[1:61,:,:] - self.pressure_grid[0:60,:,:]
self.dp = self.dp.transpose((1,2,0))

assert data_split in ['train', 'val', 'scoring', 'test'], 'Provided data_split is not valid. Available options are train, val, scoring, and test.'

if data_split == 'train':
assert self.input_train is not None
state_ps = self.input_train[:,120]*(self.input_max['state_ps'].values - self.input_min['state_ps'].values) + self.input_mean['state_ps'].values
state_ps = np.reshape(state_ps, (-1, self.latlonnum))
pressure_grid_p1 = np.array(self.grid_info['P0']*self.grid_info['hyai'])[:,np.newaxis,np.newaxis]
pressure_grid_p2 = self.grid_info['hybi'].values[:, np.newaxis, np.newaxis] * state_ps[np.newaxis, :, :]
self.pressure_grid_train = pressure_grid_p1 + pressure_grid_p2
self.dp_train = self.pressure_grid_train[1:61,:,:] - self.pressure_grid_train[0:60,:,:]
self.dp_train = self.dp_train.transpose((1,2,0))
elif data_split == 'val':
assert self.input_val is not None
state_ps = self.input_val[:,120]*(self.input_max['state_ps'].values - self.input_min['state_ps'].values) + self.input_mean['state_ps'].values
state_ps = np.reshape(state_ps, (-1, self.latlonnum))
pressure_grid_p1 = np.array(self.grid_info['P0']*self.grid_info['hyai'])[:,np.newaxis,np.newaxis]
pressure_grid_p2 = self.grid_info['hybi'].values[:, np.newaxis, np.newaxis] * state_ps[np.newaxis, :, :]
self.pressure_grid_val = pressure_grid_p1 + pressure_grid_p2
self.dp_val = self.pressure_grid_val[1:61,:,:] - self.pressure_grid_val[0:60,:,:]
self.dp_val = self.dp_val.transpose((1,2,0))
elif data_split == 'scoring':
assert self.input_scoring is not None
state_ps = self.input_scoring[:,120]*(self.input_max['state_ps'].values - self.input_min['state_ps'].values) + self.input_mean['state_ps'].values
state_ps = np.reshape(state_ps, (-1, self.latlonnum))
pressure_grid_p1 = np.array(self.grid_info['P0']*self.grid_info['hyai'])[:,np.newaxis,np.newaxis]
pressure_grid_p2 = self.grid_info['hybi'].values[:, np.newaxis, np.newaxis] * state_ps[np.newaxis, :, :]
self.pressure_grid_scoring = pressure_grid_p1 + pressure_grid_p2
self.dp_scoring = self.pressure_grid_scoring[1:61,:,:] - self.pressure_grid_scoring[0:60,:,:]
self.dp_scoring = self.dp_scoring.transpose((1,2,0))
elif data_split == 'test':
assert self.input_test is not None
state_ps = self.input_test[:,120]*(self.input_max['state_ps'].values - self.input_min['state_ps'].values) + self.input_mean['state_ps'].values
state_ps = np.reshape(state_ps, (-1, self.latlonnum))
pressure_grid_p1 = np.array(self.grid_info['P0']*self.grid_info['hyai'])[:,np.newaxis,np.newaxis]
pressure_grid_p2 = self.grid_info['hybi'].values[:, np.newaxis, np.newaxis] * state_ps[np.newaxis, :, :]
self.pressure_grid_test = pressure_grid_p1 + pressure_grid_p2
self.dp_test = self.pressure_grid_test[1:61,:,:] - self.pressure_grid_test[0:60,:,:]
self.dp_test = self.dp_test.transpose((1,2,0))

def get_pressure_grid_plotting(self, data_split):
'''
This function creates the temporally and zonally averaged pressure grid corresponding to a given data split.
Expand All @@ -488,14 +528,15 @@ def find_keys(dictionary, value):
pressure_grid_plotting = np.concatenate(pg_lats, axis = 1)
return pressure_grid_plotting

def output_weighting(self, output):
def output_weighting(self, output, data_split):
'''
This function does four transformations, and assumes we are using V1 variables:
[0] Undos the output scaling
[1] Weight vertical levels by dp/g
[2] Weight horizontal area of each grid cell by a[x]/mean(a[x])
[3] Unit conversion to a common energy unit
'''
assert data_split in ['train', 'val', 'scoring', 'test'], 'Provided data_split is not valid. Available options are train, val, scoring, and test.'
num_samples = output.shape[0]
heating = output[:,:60].reshape((int(num_samples/self.latlonnum), self.latlonnum, 60))
moistening = output[:,60:120].reshape((int(num_samples/self.latlonnum), self.latlonnum, 60))
Expand Down Expand Up @@ -527,8 +568,16 @@ def output_weighting(self, output):
# [1] Weight vertical levels by dp/g
# only for vertically-resolved variables, e.g. ptend_{t,q0001}
# dp/g = -\rho * dz
heating = heating * self.dp/self.grav
moistening = moistening * self.dp/self.grav
if data_split == 'train':
dp = self.dp_train
elif data_split == 'val':
dp = self.dp_val
elif data_split == 'scoring':
dp = self.dp_scoring
elif data_split == 'test':
dp = self.dp_test
heating = heating * dp/self.grav
moistening = moistening * dp/self.grav

# [2] weight by area
heating = heating * self.area_wgt[np.newaxis, :, np.newaxis]
Expand Down Expand Up @@ -573,16 +622,16 @@ def reweight_target(self, data_split):
assert data_split in ['train', 'val', 'scoring', 'test'], 'Provided data_split is not valid. Available options are train, val, scoring, and test.'
if data_split == 'train':
assert self.target_train is not None
self.target_weighted_train = self.output_weighting(self.target_train)
self.target_weighted_train = self.output_weighting(self.target_train, data_split)
elif data_split == 'val':
assert self.target_val is not None
self.target_weighted_val = self.output_weighting(self.target_val)
self.target_weighted_val = self.output_weighting(self.target_val, data_split)
elif data_split == 'scoring':
assert self.target_scoring is not None
self.target_weighted_scoring = self.output_weighting(self.target_scoring)
self.target_weighted_scoring = self.output_weighting(self.target_scoring, data_split)
elif data_split == 'test':
assert self.target_test is not None
self.target_weighted_test = self.output_weighting(self.target_test)
self.target_weighted_test = self.output_weighting(self.target_test, data_split)

def reweight_preds(self, data_split):
'''
Expand All @@ -594,19 +643,19 @@ def reweight_preds(self, data_split):
if data_split == 'train':
assert self.preds_train is not None
for model_name in self.model_names:
self.preds_weighted_train[model_name] = self.output_weighting(self.preds_train[model_name])
self.preds_weighted_train[model_name] = self.output_weighting(self.preds_train[model_name], data_split)
elif data_split == 'val':
assert self.preds_val is not None
for model_name in self.model_names:
self.preds_weighted_val[model_name] = self.output_weighting(self.preds_val[model_name])
self.preds_weighted_val[model_name] = self.output_weighting(self.preds_val[model_name], data_split)
elif data_split == 'scoring':
assert self.preds_scoring is not None
for model_name in self.model_names:
self.preds_weighted_scoring[model_name] = self.output_weighting(self.preds_scoring[model_name])
self.preds_weighted_scoring[model_name] = self.output_weighting(self.preds_scoring[model_name], data_split)
elif data_split == 'test':
assert self.preds_test is not None
for model_name in self.model_names:
self.preds_weighted_test[model_name] = self.output_weighting(self.preds_test[model_name])
self.preds_weighted_test[model_name] = self.output_weighting(self.preds_test[model_name], data_split)

def calc_MAE(self, pred, target, avg_grid = True):
'''
Expand Down
Loading

0 comments on commit 4ec2765

Please sign in to comment.