Add zero_division to F1 metric #606

Open · wants to merge 2 commits into main
3 changes: 2 additions & 1 deletion metrics/f1/README.md
@@ -48,6 +48,7 @@ At minimum, this metric requires predictions and references as input
- 'weighted': Calculate metrics for each label, and find their average weighted by support (the number of true instances for each label). This alters `'macro'` to account for label imbalance. This option can result in an F-score that is not between precision and recall.
- 'samples': Calculate metrics for each instance, and find their average (only meaningful for multilabel classification).
- **sample_weight** (`list` of `float`): Sample weights. Defaults to None.
- **zero_division** ('warn' or 0.0 or 1.0 or np.nan): Sets the value to return when there is a zero division, i.e. when all predictions and labels are negative. Defaults to 'warn'.


### Output Values
@@ -134,4 +135,4 @@ Example 4-A multiclass example, with different values for the `average` input.
```


## Further References
## Further References
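
For context, here is a minimal usage sketch of the option documented in the README hunk above. It assumes the metric is loaded through the `evaluate` library's `load`/`compute` API; that entry point is an assumption on my part and not shown in this diff.

```python
import evaluate

# Assumes the metric is published under the id "f1" in the evaluate library.
f1_metric = evaluate.load("f1")

# With no positive references and no positive predictions, precision and
# recall are both undefined, so computing F1 hits a zero division. The new
# zero_division keyword fixes the returned value instead of warning.
results = f1_metric.compute(
    predictions=[0, 0, 0],
    references=[0, 0, 0],
    zero_division=1.0,
)
print(results)  # expected: {'f1': 1.0}
```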
5 changes: 3 additions & 2 deletions metrics/f1/f1.py
@@ -39,6 +39,7 @@
- 'weighted': Calculate metrics for each label, and find their average weighted by support (the number of true instances for each label). This alters `'macro'` to account for label imbalance. This option can result in an F-score that is not between precision and recall.
- 'samples': Calculate metrics for each instance, and find their average (only meaningful for multilabel classification).
sample_weight (`list` of `float`): Sample weights. Defaults to None.
zero_division ('warn' or 0.0 or 1.0 or np.nan): Sets the value to return when there is a zero division, i.e. when all predictions and labels are negative. Defaults to 'warn'.

Returns:
f1 (`float` or `array` of `float`): F1 score or list of f1 scores, depending on the value passed to `average`. Minimum possible value is 0. Maximum possible value is 1. Higher f1 scores are better.
@@ -123,8 +124,8 @@ def _info(self):
reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"],
)

def _compute(self, predictions, references, labels=None, pos_label=1, average="binary", sample_weight=None):
def _compute(self, predictions, references, labels=None, pos_label=1, average="binary", zero_division="warn", sample_weight=None):
score = f1_score(
references, predictions, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight
references, predictions, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight, zero_division=zero_division,
)
return {"f1": float(score) if score.size == 1 else score}