# Source code for openadmet.models.eval.classification

"""Classification metrics and plots for model evaluation."""

import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import wandb
from pydantic import Field
from sklearn.metrics import (
    accuracy_score,
    auc,
    f1_score,
    precision_recall_curve,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
)

from openadmet.models.eval.eval_base import EvalBase, evaluators


def pr_auc_score(y_true, y_pred):
    """
    Compute the area under the precision-recall curve.

    The curve is obtained from ``precision_recall_curve`` and integrated with
    ``auc``, using recall as the x-coordinate and precision as the
    y-coordinate.

    Parameters
    ----------
    y_true : array-like
        True binary labels or binary label indicators.
    y_pred : array-like
        Target scores, probability estimates of the positive class.

    Returns
    -------
    float
        Area under the precision-recall curve.

    """
    curve = precision_recall_curve(y_true, y_pred)
    # The third element (the decision thresholds) is not needed for the area.
    precision_values, recall_values = curve[0], curve[1]
    return auc(recall_values, precision_values)
@evaluators.register("ClassificationMetrics")
class ClassificationMetrics(EvalBase):
    """
    Compute and report classification metrics with bootstrap confidence intervals.

    The metrics computed are accuracy, precision, recall, F1, ROC AUC, and
    PR AUC.

    Attributes
    ----------
    bootstrap_confidence_level : float
        Confidence level for the bootstrap.
    use_wandb : bool
        Whether to log metrics to Weights & Biases.
    _evaluated : bool
        Whether the evaluation has been performed.
    _metrics : dict
        Maps each metric name to a 4-tuple
        ``(metric function, is_scipy, is_class_pred, label)``. ``is_scipy`` is
        forwarded to ``stat_and_bootstrap`` as ``is_scipy_statistic``;
        ``is_class_pred`` selects whether the metric receives arg-maxed class
        predictions (True) or raw scores (False). The fourth element is not
        used in this class (presumably a display label).

    """

    bootstrap_confidence_level: float = Field(
        0.95, description="Confidence level for the bootstrap"
    )
    use_wandb: bool = Field(False, description="Whether to use wandb")
    _evaluated: bool = False
    _metrics: dict = {
        "accuracy": (accuracy_score, False, True, "Accuracy"),
        "precision": (precision_score, False, True, "Precision"),
        "recall": (recall_score, False, True, "Recall"),
        "f1": (f1_score, False, True, "F1 Score"),
        "roc_auc": (roc_auc_score, False, False, "ROC AUC"),
        "pr_auc": (pr_auc_score, False, False, "PR AUC"),
    }

    def evaluate(self, y_true=None, y_pred=None, use_wandb=False, tag=None, **kwargs):
        """
        Evaluate the classification model and compute metrics.

        Parameters
        ----------
        y_true : array-like
            True labels. A 1-D array or a single-column 2-D array is treated
            as the binary case; any other shape is treated as multiclass with
            one column per class.
        y_pred : array-like
            Per-class scores. It is indexed as ``y_pred[:, 1]`` and arg-maxed
            along axis 1, so a 2-D array with one column per class is required.
        use_wandb : bool, optional
            Whether to log metrics to Weights & Biases. A truthy value is
            stored on the instance; a falsy value leaves the instance setting
            unchanged.
        tag : str, optional
            Tag for the evaluation run; stored under the ``"tag"`` key of the
            result.
        kwargs : dict
            Additional keyword arguments (ignored).

        Returns
        -------
        dict
            Dictionary holding the tag and, per metric name, a dict with
            ``value``, ``lower_ci``, ``upper_ci`` and ``confidence_level``.

        Raises
        ------
        ValueError
            If ``y_true`` or ``y_pred`` is None.

        """
        if y_true is None or y_pred is None:
            raise ValueError("Must provide y_true and y_pred")

        # Cast as numpy arrays
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)

        self.data = {"tag": tag}

        if use_wandb:
            self.use_wandb = use_wandb

        # The label layout does not change between metrics, so decide once
        # whether this is the binary or the multiclass case.
        is_binary = (y_true.ndim == 1) or (y_true.ndim == 2 and y_true.shape[1] == 1)

        for metric_tag, (metric, is_scipy, is_class_pred, _) in self._metrics.items():
            if is_binary:
                _y_true = y_true.ravel()
                if is_class_pred is True:
                    # Cast to class predictions before calculating the metric
                    _y_pred = np.argmax(y_pred, axis=1).ravel()
                else:
                    # Compare positive-class probabilities with labels
                    _y_pred = y_pred[:, 1].ravel()
            elif is_class_pred is True:
                # Multiclass: cast both to class indices
                _y_pred = np.argmax(y_pred, axis=1).ravel()
                _y_true = np.argmax(y_true, axis=1).ravel()
            else:
                # Multiclass: micro-averaged one-versus-rest
                _y_pred = y_pred.ravel()
                _y_true = y_true.ravel()

            value, lower_ci, upper_ci = self.stat_and_bootstrap(
                metric_tag,
                _y_pred,
                _y_true,
                metric,
                is_scipy_statistic=is_scipy,
                confidence_level=self.bootstrap_confidence_level,
            )

            self.data[f"{metric_tag}"] = {
                "value": value,
                "lower_ci": lower_ci,
                "upper_ci": upper_ci,
                "confidence_level": self.bootstrap_confidence_level,
            }

        if self.use_wandb:
            # make a table for the metrics
            table = wandb.Table(
                columns=["Metric", "Value", "Lower CI", "Upper CI", "Confidence Level"]
            )
            for metric in self.metric_names:
                table.add_data(
                    metric,
                    self.data[metric]["value"],
                    self.data[metric]["lower_ci"],
                    self.data[metric]["upper_ci"],
                    self.data[metric]["confidence_level"],
                )
            wandb.log({"metrics": table})

            # Also log each metric value on its own
            for metric in self.metric_names:
                wandb.log({metric: self.data[metric]["value"]})

        self._evaluated = True

        return self.data

    @property
    def metric_names(self):
        """
        Return the metric names.

        Returns
        -------
        list of str
            List of metric names.

        """
        return list(self._metrics.keys())

    def report(self, write=False, output_dir=None):
        """
        Report the evaluation results, optionally writing to disk.

        Parameters
        ----------
        write : bool, optional
            Whether to write the report to disk.
        output_dir : str or pathlib.Path, optional
            Output directory for the report; used only when ``write`` is True.

        Returns
        -------
        dict
            Dictionary of computed metrics, as stored by ``evaluate``.

        """
        if write:
            self.write_report(output_dir)
        return self.data

    def write_report(self, output_dir):
        """
        Write the evaluation report to a JSON file and optionally log to wandb.

        The file is named ``classification_metrics.json`` and is placed in
        ``output_dir``.

        Parameters
        ----------
        output_dir : str or pathlib.Path
            Output directory for the report.

        """
        # Coerce to Path so a plain string directory works with the "/" operator.
        json_path = Path(output_dir) / "classification_metrics.json"
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(self.data, f, indent=2)

        # also log the json to wandb
        if self.use_wandb:
            artifact = wandb.Artifact(name="metrics_json", type="metric_json")
            # Add a file to the artifact
            artifact.add_file(str(json_path))
            # Log the artifact
            wandb.log_artifact(artifact)
@evaluators.register("ClassificationPlots")
class ClassificationPlots(EvalBase):
    """
    Generate and save classification plots: the ROC and precision-recall curves.

    Attributes
    ----------
    plots : dict
        Dictionary mapping plot names to the methods that draw them; filled
        in by ``evaluate``.
    use_wandb : bool
        Whether to log plots to Weights & Biases.
    dpi : int
        DPI for the plots.

    """

    plots: dict = {}
    use_wandb: bool = Field(False, description="Whether to use wandb")
    dpi: int = Field(300, description="DPI for the plot")

    def evaluate(self, y_true=None, y_pred=None, use_wandb=False, **kwargs):
        """
        Generate classification plots.

        The resulting figures are stored in ``self.plot_data`` under the keys
        ``"roc_curve"`` and ``"pr_curve"``.

        Parameters
        ----------
        y_true : array-like
            True labels.
        y_pred : array-like
            Per-class scores. In the binary case column 1 is used as the
            positive-class score, so a 2-D array is required.
        use_wandb : bool, optional
            Whether to log plots to Weights & Biases. A truthy value is stored
            on the instance; a falsy value leaves the instance setting
            unchanged.
        kwargs : dict
            Additional keyword arguments (ignored).

        Raises
        ------
        ValueError
            If ``y_true`` or ``y_pred`` is None.

        """
        # Validate before casting: np.asarray(None) yields a 0-d object array,
        # so a None check made after the cast could never trigger.
        if y_true is None or y_pred is None:
            raise ValueError("Must provide y_true and y_pred")

        # Cast as numpy arrays
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)

        if use_wandb:
            self.use_wandb = use_wandb

        self.plots = {
            "roc_curve": self.roc_curve,
            "pr_curve": self.pr_curve,
        }

        self.plot_data = {}

        # Create the plots
        for plot_tag, plot in self.plots.items():
            self.plot_data[plot_tag] = plot(
                y_true,
                y_pred,
            )

    def roc_curve(
        self,
        y_true,
        y_pred,
        xlabel="False Positive Rate",
        ylabel="True Positive Rate",
        title="Receiver Operating Characteristic Curve",
    ):
        """
        Plot the ROC curve.

        Parameters
        ----------
        y_true : numpy.ndarray
            True labels. A 1-D or single-column 2-D array is treated as
            binary; otherwise all entries are flattened for a micro-averaged
            one-versus-rest curve.
        y_pred : numpy.ndarray
            Per-class scores; column 1 is used in the binary case.
        xlabel : str, optional
            Label for the x-axis.
        ylabel : str, optional
            Label for the y-axis.
        title : str, optional
            Title for the plot.

        Returns
        -------
        matplotlib.figure.Figure
            The ROC curve figure.

        """
        # Inside the method body, ``roc_curve`` resolves to the module-level
        # scikit-learn function, not to this method.
        if (y_true.ndim == 1) or (y_true.ndim == 2 and y_true.shape[1] == 1):
            # Binary
            fpr, tpr, _ = roc_curve(y_true.ravel(), y_pred[:, 1].ravel())
        else:
            # Micro-averaged one-versus-rest
            fpr, tpr, _ = roc_curve(y_true.ravel(), y_pred.ravel())

        fig, ax = plt.subplots(dpi=self.dpi)
        ax.set_title(title, fontsize=10)
        ax.plot(fpr, tpr)
        # Diagonal reference line from (0, 0) to (1, 1)
        ax.plot([0, 1], [0, 1], linestyle="--", color="black")
        ax.set_aspect("equal", "box")
        ax.set_xlabel(xlabel, fontsize=10)
        ax.set_ylabel(ylabel, fontsize=10)

        return fig

    def pr_curve(
        self,
        y_true,
        y_pred,
        xlabel="Recall",
        ylabel="Precision",
        title="Precision-Recall Curve",
    ):
        """
        Plot the precision-recall curve.

        Parameters
        ----------
        y_true : numpy.ndarray
            True labels. A 1-D or single-column 2-D array is treated as
            binary; otherwise all entries are flattened for a micro-averaged
            one-versus-rest curve.
        y_pred : numpy.ndarray
            Per-class scores; column 1 is used in the binary case.
        xlabel : str, optional
            Label for the x-axis.
        ylabel : str, optional
            Label for the y-axis.
        title : str, optional
            Title for the plot.

        Returns
        -------
        matplotlib.figure.Figure
            The precision-recall curve figure.

        """
        if (y_true.ndim == 1) or (y_true.ndim == 2 and y_true.shape[1] == 1):
            # Binary
            precision, recall, _ = precision_recall_curve(
                y_true.ravel(), y_pred[:, 1].ravel()
            )
        else:
            # Micro-averaged one-versus-rest
            precision, recall, _ = precision_recall_curve(
                y_true.ravel(), y_pred.ravel()
            )

        fig, ax = plt.subplots(dpi=self.dpi)
        ax.set_title(title, fontsize=10)
        ax.plot(recall, precision)
        # Dashed reference lines along precision = 1 and recall = 1
        ax.plot([0, 1], [1, 1], linestyle="--", color="black")
        ax.plot([1, 1], [0, 1], linestyle="--", color="black")
        ax.set_aspect("equal", "box")
        ax.set_xlabel(xlabel, fontsize=10)
        ax.set_ylabel(ylabel, fontsize=10)

        return fig

    def report(self, write=False, output_dir=None):
        """
        Report the generated plots, optionally writing to disk.

        Parameters
        ----------
        write : bool, optional
            Whether to write the plots to disk.
        output_dir : str or pathlib.Path, optional
            Output directory for the plots; used only when ``write`` is True.

        Returns
        -------
        dict
            Dictionary of plot figures, as stored by ``evaluate``.

        """
        if write:
            self.write_report(output_dir)
        return self.plot_data

    def write_report(self, output_dir):
        """
        Write the generated plots to PNG files and optionally log to wandb.

        Each figure is saved as ``<plot name>.png`` inside ``output_dir``.

        Parameters
        ----------
        output_dir : str or pathlib.Path
            Output directory for the plots.

        """
        # Coerce to Path so a plain string directory works with the "/" operator.
        output_dir = Path(output_dir)

        for plot_tag, plot in self.plot_data.items():
            plot_path = output_dir / f"{plot_tag}.png"
            plot.savefig(plot_path, dpi=self.dpi)

            if self.use_wandb:
                wandb.log({plot_tag: wandb.Image(str(plot_path))})