"""Source code for rexmex.utils: metric pre-processing decorators and the metric annotation registry."""

from functools import wraps
from typing import Callable, Optional

import numpy as np

# Public API of this module: the metric type alias, the two pre-processing
# decorators, and the annotation registry class.
__all__ = ["Metric", "binarize", "normalize", "Annotator"]

#: A function that can be called on ``y_true`` and ``y_score`` arrays and returns a floating point result.
#: ``np.ndarray`` is the array *type*; ``np.array`` is only a factory function and is not a valid type hint.
Metric = Callable[[np.ndarray, np.ndarray], float]


def binarize(metric):
    """
    Binarize the predictions for a ground-truth - prediction vector pair.

    Scores below 0.5 are mapped to 0 and scores greater than or equal to 0.5
    are mapped to 1. The thresholding is applied to a copy of the scores, so
    the array passed by the caller is never modified.

    Args:
        metric (function): The metric function which needs a binarization pre-processing step.
            It must accept ``y_score`` either as its second positional argument or as
            the ``y_score`` keyword argument.

    Returns:
        metric_wrapper (function): The function which wraps the metric and binarizes the probability scores.
    """

    def _binarized_copy(y_score):
        # Copy first: thresholding in place would overwrite the caller's scores.
        y_score = np.array(y_score, copy=True)
        y_score[y_score < 0.5] = 0
        y_score[y_score >= 0.5] = 1
        return y_score

    @wraps(metric)
    def metric_wrapper(*args, **kwargs):
        # TODO: Move to optimal binning. Youden’s J statistic.
        args = list(args)
        if len(args) > 1:
            args[1] = _binarized_copy(args[1])
        elif "y_score" in kwargs:
            # ``kwargs`` is a fresh dict local to this call, so rebinding is safe.
            kwargs["y_score"] = _binarized_copy(kwargs["y_score"])
        return metric(*args, **kwargs)

    return metric_wrapper
def normalize(metric):
    """
    Normalize the predictions for a ground-truth - prediction vector pair.

    Both vectors are standardized with the mean and (population) standard
    deviation of the ground truth, i.e. ``(v - mean(y_true)) / std(y_true)``.
    The standardized values are computed as new float arrays, so the caller's
    arrays are left untouched. If ``y_true`` is constant, its standard deviation
    is zero; in that case both vectors are only centered, which avoids NaN/inf.

    Args:
        metric (function): The metric function which needs a normalization pre-processing step.
            It must accept ``y_true`` and ``y_score`` either as its first two positional
            arguments or as keyword arguments with those names.

    Returns:
        metric_wrapper (function): The function which wraps the metric and normalizes predictions.
    """

    @wraps(metric)
    def metric_wrapper(*args, **kwargs):
        args = list(args)
        true_is_positional = len(args) > 0
        score_is_positional = len(args) > 1
        y_true = np.asarray(args[0] if true_is_positional else kwargs["y_true"], dtype=float)
        y_score = np.asarray(args[1] if score_is_positional else kwargs["y_score"], dtype=float)

        y_mean = np.mean(y_true)
        y_std = np.std(y_true)
        if y_std == 0:
            # A constant ground truth has no spread; dividing by zero would turn
            # every value into NaN/inf, so only center the vectors instead.
            y_std = 1.0

        # Build new arrays: the inputs are never overwritten, and integer inputs
        # are not truncated by assigning float results back into them.
        y_true = (y_true - y_mean) / y_std
        y_score = (y_score - y_mean) / y_std

        if true_is_positional:
            args[0] = y_true
        else:
            kwargs["y_true"] = y_true
        if score_is_positional:
            args[1] = y_score
        else:
            kwargs["y_score"] = y_score
        return metric(*args, **kwargs)

    return metric_wrapper
class Annotator:
    """Registry of metric functions that are decorated with annotation metadata.

    Every function decorated through :meth:`annotate` (or :meth:`duplicate`) is
    stored in ``funcs`` under its ``__name__``. Iterating over the annotator
    yields the registered functions.
    """

    def __init__(self):
        # Maps a registered function's ``__name__`` to the function object.
        self.funcs = {}

    def __iter__(self):
        return iter(self.funcs.values())

    def annotate(
        self,
        *,
        lower: float,
        upper: float,
        higher_is_better: bool,
        link: str,
        description: str,
        name: Optional[str] = None,
        lower_inclusive: bool = True,
        upper_inclusive: bool = True,
        binarize: bool = False,
        duplicate_of: Optional[Metric] = None,
    ):
        """Annotate a function.

        Args:
            lower: Lower bound of the metric's value range.
            upper: Upper bound of the metric's value range.
            higher_is_better: Whether larger metric values are better.
            link: Reference link stored on the function.
            description: Description stored on the function.
            name: Display name. When falsy, it is derived from the function's
                ``__name__`` by replacing underscores with spaces and title-casing.
            lower_inclusive: Whether ``lower`` is part of the range.
            upper_inclusive: Whether ``upper`` is part of the range.
            binarize: Flag stored on the function as ``binarize``.
            duplicate_of: The metric this function duplicates, if any.

        Returns:
            A decorator that registers the function in ``funcs``, sets the
            metadata above as attributes on it, and returns the same function.
        """

        def _register(func):
            # Registration happens before any attribute is set.
            self.funcs[func.__name__] = func
            metadata = {
                "name": name if name else func.__name__.replace("_", " ").title(),
                "lower": lower,
                "lower_inclusive": lower_inclusive,
                "upper": upper,
                "upper_inclusive": upper_inclusive,
                "higher_is_better": higher_is_better,
                "link": link,
                "description": description,
                "binarize": binarize,
                "duplicate_of": duplicate_of,
            }
            for attribute, value in metadata.items():
                setattr(func, attribute, value)
            return func

        return _register

    def duplicate(self, other, *, name: Optional[str] = None, binarize: Optional[bool] = None):
        """Annotate a function as a duplicate.

        The returned decorator copies the bounds, inclusivity flags, link,
        description and ``higher_is_better`` from ``other`` and records ``other``
        as ``duplicate_of``. The display name is derived from the decorated
        function itself unless ``name`` is given.

        Args:
            other: An already-annotated function whose metadata is reused.
            name: Optional display name for the duplicate.
            binarize: Overrides ``other.binarize`` when not ``None``.

        Returns:
            The decorator produced by :meth:`annotate`.
        """
        # Need to be able to override for sklearn functions.
        if binarize is None:
            binarize = other.binarize
        return self.annotate(
            name=name,
            lower=other.lower,
            lower_inclusive=other.lower_inclusive,
            upper=other.upper,
            upper_inclusive=other.upper_inclusive,
            link=other.link,
            description=other.description,
            duplicate_of=other,
            higher_is_better=other.higher_is_better,
            binarize=binarize,
        )