from typing import Collection, List, Tuple
from rexmex.metrics.classification import classifications
from rexmex.metrics.coverage import item_coverage, user_coverage
from rexmex.metrics.rating import (
mean_absolute_error,
mean_absolute_percentage_error,
mean_squared_error,
pearson_correlation_coefficient,
r2_score,
root_mean_squared_error,
symmetric_mean_absolute_percentage_error,
)
from rexmex.utils import binarize, normalize
[docs]class MetricSet(dict):
"""
A metric set is a special dictionary that contains metric
name keys and evaluation metric function values.
"""
[docs] def filter_metrics(self, filter: Collection[str]):
"""
A method to keep a list of metrics.
Args:
filter: A list of metric names to keep.
Returns:
self: The metric set after the metrics were filtered out.
"""
for name in list(self.keys()):
if name not in filter:
del self[name]
return self
[docs] def add_metrics(self, metrics: List[Tuple]):
"""
A method to add metric functions from a list of function names and functions.
Args:
metrics (List[Tuple]): A list of metric name and metric function tuples.
Returns:
self: The metric set after the metrics were added.
"""
for metric in metrics:
metric_name, metric_function = metric
self[metric_name] = metric_function
return self
def __repr__(self):
"""
A representation of the MetricSet object.
"""
return "MetricSet()"
[docs] def print_metrics(self):
"""
Printing the name of metrics.
"""
print({k for k in self.keys()})
def __add__(self, other_metric_set):
"""
Adding two metric sets together with the addition syntactic sugar operator.
Args:
other_metric_set (rexmex.metricset.MetricSet): Metric set added from the right.
Returns:
new_metric_set (rexmex.metricset.MetricSet): The combined metric set.
"""
new_metric_set = self
for name, metric in other_metric_set.items():
new_metric_set[name] = metric
return new_metric_set
[docs]class ClassificationMetricSet(MetricSet):
"""
A set of classification metrics with the following metrics included:
| **Area Under the Receiver Operating Characteristic Curve**
| **Area Under the Precision Recall Curve**
| **Average Precision**
| **F-1 Score**
| **Matthew's Correlation Coefficient**
| **Fowlkes-Mallows Index**
| **Precision**
| **Recall**
| **Specificity**
| **Accuracy**
| **Balanced Accuracy**
"""
def __init__(self):
super().__init__()
for func in classifications:
name = func.__name__
if name.endswith("_score"):
name = name[: -len("_score")]
if func.binarize:
func = binarize(func)
self[name] = func
def __repr__(self):
"""
A representation of the ClassificationMetricSet object.
"""
return "ClassificationMetricSet()"
[docs]class RatingMetricSet(MetricSet):
"""
A set of rating metrics with the following metrics included:
| **Mean Absolute Error**
| **Mean Squared Error**
| **Root Mean Squared Error**
| **Mean Absolute Percentage Error**
| **Symmetric Mean Absolute Percentage Error**
| **Coefficient of Determination**
| **Pearson Correlation Coefficient**
"""
def __init__(self):
self["mae"] = mean_absolute_error
self["mse"] = mean_squared_error
self["rmse"] = root_mean_squared_error
self["mape"] = mean_absolute_percentage_error
self["smape"] = symmetric_mean_absolute_percentage_error
self["r_squared"] = r2_score
self["pearson_correlation"] = pearson_correlation_coefficient
[docs] def normalize_metrics(self):
"""
A method to normalize a set of metrics.
Returns:
self: The metric set after the metrics were normalized.
"""
for name, metric in self.items():
self[name] = normalize(metric)
return self
def __repr__(self):
"""
A representation of the RatingMetricSet object.
"""
return "RatingMetricSet()"
[docs]class CoverageMetricSet(MetricSet):
"""
A set of coverage metrics with the following metrics included:
| **Item Coverage**
| **User Coverage**
"""
def __init__(self):
self["item_coverage"] = item_coverage
self["user_coverage"] = user_coverage
def __repr__(self):
"""
A representation of the CoverageMetricSet object.
"""
return "CoverageMetricSet()"
[docs]class RankingMetricSet(MetricSet):
"""
A set of ranking metrics with the following metrics included:
"""
def __repr__(self):
"""
A representation of the RankingMetricSet object.
"""
return "RankingMetricSet()"