diff --git a/climsim_utils/data_utils.py b/climsim_utils/data_utils.py index 520ef5c..e812e84 100644 --- a/climsim_utils/data_utils.py +++ b/climsim_utils/data_utils.py @@ -142,6 +142,9 @@ def find_keys(dictionary, value): self.target_train = None self.preds_train = None self.samples_train = None + self.target_weighted_train = None + self.preds_weighted_train = None + self.samples_weighted_train = None self.metrics_idx_train = None self.metrics_var_train = None @@ -149,6 +152,9 @@ def find_keys(dictionary, value): self.target_val = None self.preds_val = None self.samples_val = None + self.target_weighted_val = None + self.preds_weighted_val = None + self.samples_weighted_val = None self.metrics_idx_val = None self.metrics_var_val = None @@ -156,6 +162,9 @@ def find_keys(dictionary, value): self.target_scoring = None self.preds_scoring = None self.samples_scoring = None + self.target_weighted_scoring = None + self.preds_weighted_scoring = None + self.samples_weighted_scoring = None self.metrics_idx_scoring = None self.metrics_var_scoring = None @@ -163,6 +172,9 @@ def find_keys(dictionary, value): self.target_test = None self.preds_test = None self.samples_test = None + self.target_weighted_test = None + self.preds_weighted_test = None + self.samples_weighted_test = None self.metrics_test = None self.metrics_idx_test = None self.metrics_var_test = None @@ -544,16 +556,16 @@ def reweight_target(self, data_split): assert data_split in ['train', 'val', 'scoring', 'test'], 'Provided data_split is not valid. Available options are train, val, scoring, and test.' if data_split == 'train': assert self.target_train is not None - self.target_train = self.output_weighting(self.target_train) + self.target_weighted_train = self.output_weighting(self.target_train) elif data_split == 'val': assert self.target_val is not None - self.target_val = self.output_weighting(self.target_val) + self.target_weighted_val = self.output_weighting(self.target_val) elif data_split == 'scoring': assert self.target_scoring is not None - self.target_scoring = self.output_weighting(self.target_scoring) + self.target_weighted_scoring = self.output_weighting(self.target_scoring) elif data_split == 'test': assert self.target_test is not None - self.target_test = self.output_weighting(self.target_test) + self.target_weighted_test = self.output_weighting(self.target_test) def reweight_preds(self, data_split): '''