Skip to content

Commit

Permalink
bug fix, separated weighted output/target from original
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerry Lin committed Aug 21, 2023
1 parent f1ed40a commit e2d5efd
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions climsim_utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,27 +142,39 @@ def find_keys(dictionary, value):
self.target_train = None
self.preds_train = None
self.samples_train = None
self.target_weighted_train = None
self.preds_weighted_train = None
self.samples_weighted_train = None
self.metrics_idx_train = None
self.metrics_var_train = None

self.input_val = None
self.target_val = None
self.preds_val = None
self.samples_val = None
self.target_weighted_val = None
self.preds_weighted_val = None
self.samples_weighted_val = None
self.metrics_idx_val = None
self.metrics_var_val = None

self.input_scoring = None
self.target_scoring = None
self.preds_scoring = None
self.samples_scoring = None
self.target_weighted_scoring = None
self.preds_weighted_scoring = None
self.samples_weighted_scoring = None
self.metrics_idx_scoring = None
self.metrics_var_scoring = None

self.input_test = None
self.target_test = None
self.preds_test = None
self.samples_test = None
self.target_weighted_test = None
self.preds_weighted_test = None
self.samples_weighted_test = None
self.metrics_test = None
self.metrics_idx_test = None
self.metrics_var_test = None
Expand Down Expand Up @@ -544,16 +556,16 @@ def reweight_target(self, data_split):
assert data_split in ['train', 'val', 'scoring', 'test'], 'Provided data_split is not valid. Available options are train, val, scoring, and test.'
if data_split == 'train':
assert self.target_train is not None
self.target_train = self.output_weighting(self.target_train)
self.target_weighted_train = self.output_weighting(self.target_train)
elif data_split == 'val':
assert self.target_val is not None
self.target_val = self.output_weighting(self.target_val)
self.target_weighted_val = self.output_weighting(self.target_val)
elif data_split == 'scoring':
assert self.target_scoring is not None
self.target_scoring = self.output_weighting(self.target_scoring)
self.target_weighted_scoring = self.output_weighting(self.target_scoring)
elif data_split == 'test':
assert self.target_test is not None
self.target_test = self.output_weighting(self.target_test)
self.target_weighted_test = self.output_weighting(self.target_test)

def reweight_preds(self, data_split):
'''
Expand Down

0 comments on commit e2d5efd

Please sign in to comment.