diff --git a/t5/evaluation/eval_utils.py b/t5/evaluation/eval_utils.py index 9cd0340d..791eb024 100644 --- a/t5/evaluation/eval_utils.py +++ b/t5/evaluation/eval_utils.py @@ -233,7 +233,7 @@ def metric_group_max(df, metric_names=None): for group, metrics in group_to_metrics.items(): if not all(m in df for m in metrics): continue - group_df[group] = df[metrics].mean(axis=1) + group_df[group] = df[list(metrics)].mean(axis=1) # Need to replace nan with large negative value for idxmax group_max_step = group_df.fillna(-1e9).idxmax(axis=0) metric_max = pd.Series()