Source code for openadmet.models.eval.regression

"""Regression metrics and plots for model evaluation."""

import json

import numpy as np
import pandas as pd
from pydantic import Field

from openadmet.models.eval.eval_base import (
    EvalBase,
    evaluators,
    get_t_true_and_t_pred,
)
from openadmet.models.eval.utils import _make_stat_caption, _make_stat_dict


def relative_absolute_error(y_true, y_pred):
    """
    Compute Relative Absolute Error (RAE).

    RAE = sum(|y_true - y_pred|) / sum(|y_true - mean(y_true)|).
    Lower is better; RAE < 1.0 means the model outperforms a naive mean
    predictor.

    Parameters
    ----------
    y_true : array-like
        True values. Flattened before use.
    y_pred : array-like
        Predicted values. Flattened before use; must contain the same number
        of elements as ``y_true``.

    Returns
    -------
    float
        Relative absolute error, or ``nan`` when ``y_true`` is empty or
        constant (the naive-mean baseline then has zero error).

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_pred`` contain different numbers of elements.
    """
    # Flatten both inputs so that e.g. a (n, 1) column vector and a (n,)
    # vector are compared element-wise instead of broadcasting to (n, n).
    y_true = np.ravel(y_true)
    y_pred = np.ravel(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must have the same number of elements, "
            f"got {y_true.size} and {y_pred.size}"
        )
    if y_true.size == 0:
        # Avoid np.mean's "mean of empty slice" warning; RAE is undefined.
        return np.nan
    numerator = np.sum(np.abs(y_true - y_pred))
    denominator = np.sum(np.abs(y_true - np.mean(y_true)))
    if denominator == 0:
        return np.nan
    return numerator / denominator
def pct_within_1_log_unit(y_true, y_pred):
    """
    Compute the fraction of predictions within +/-1 log unit of the true value.

    Parameters
    ----------
    y_true : array-like
        True values (assumed to be on a log scale, e.g. pXC50). Flattened
        before use.
    y_pred : array-like
        Predicted values. Flattened before use; must contain the same number
        of elements as ``y_true``.

    Returns
    -------
    float
        Fraction (0-1) of predictions whose absolute error is <= 1.0, or
        ``nan`` for empty input.

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_pred`` contain different numbers of elements.
    """
    # Flatten so a (n, 1) column vector and a (n,) vector do not broadcast
    # to an (n, n) matrix of all pairwise differences.
    y_true = np.ravel(y_true)
    y_pred = np.ravel(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must have the same number of elements, "
            f"got {y_true.size} and {y_pred.size}"
        )
    if y_true.size == 0:
        # The fraction is undefined for no samples; skip np.mean's warning.
        return np.nan
    return np.mean(np.abs(y_true - y_pred) <= 1.0)
@evaluators.register("RegressionMetrics")
class RegressionMetrics(EvalBase):
    """
    Compute and report regression metrics such as MSE, MAE, R2, Kendall's tau,
    Spearman's rho and RAE.

    Attributes
    ----------
    bootstrap_confidence_level : float
        Confidence level for the bootstrap.
    use_wandb : bool
        Whether to use wandb for logging.
    pXC50 : bool
        Whether targets are on a pXC50/log scale. When True, the fraction of
        predictions within +/-1 log unit is added to the active metrics.
    _evaluated : bool
        Whether the model has been evaluated.
    """

    bootstrap_confidence_level: float = Field(
        0.95, description="Confidence level for the bootstrap"
    )
    use_wandb: bool = Field(False, description="Whether to use wandb")
    pXC50: bool = Field(
        False,
        description="Whether targets are in pXC50/log units for log-based metrics",
    )
    _evaluated: bool = False

    @classmethod
    def _base_metrics(cls) -> dict:
        """
        Build the base metrics dictionary with deferred 3rd-party imports.

        Returns
        -------
        dict
            Mapping of metric key to ``(callable, is_scipy_statistic, display_label)``
            tuples for MSE, MAE, R², Kendall's τ, Spearman's ρ, and RAE.
        """
        from functools import partial

        from scipy.stats import kendalltau, spearmanr
        from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

        nan_omit_ktau = partial(kendalltau, nan_policy="omit")
        nan_omit_spearmanr = partial(spearmanr, nan_policy="omit")
        return {
            "mse": (mean_squared_error, False, "MSE"),
            "mae": (mean_absolute_error, False, "MAE"),
            "r2": (r2_score, False, "$R^2$"),
            "ktau": (nan_omit_ktau, True, "Kendall's $\\tau$"),
            "spearmanr": (nan_omit_spearmanr, True, "Spearman's $\\rho$"),
            "rae": (relative_absolute_error, False, "RAE"),
        }

    @property
    def active_metrics(self):
        """Return metrics applicable to the current target scale."""
        metrics = self._base_metrics()
        if self.pXC50:
            metrics["pct_within_1_log"] = (
                pct_within_1_log_unit,
                False,
                "Fraction within ±1 log",
            )
        return metrics

    def evaluate(
        self,
        y_true=None,
        y_pred=None,
        use_wandb=False,
        tag=None,
        target_labels=None,
        **kwargs,
    ):
        """
        Evaluate the regression model and compute metrics.

        Parameters
        ----------
        y_true : array-like
            True values; 1D (single task) or 2D ``(n_samples, n_tasks)``.
        y_pred : array-like
            Predicted values; same layout as ``y_true``.
        use_wandb : bool, optional
            Whether to log metrics to Weights & Biases.
        tag : str, optional
            Tag for the evaluation run (stored under ``self.data["tag"]``).
        target_labels : list of str, optional
            Names for the task columns; defaults to ``task_0``, ``task_1``, ...
        kwargs : Dict
            Additional keyword arguments (ignored).

        Returns
        -------
        dict
            Dictionary of computed metrics and confidence intervals, keyed by
            task label, plus the ``"tag"`` entry.

        Raises
        ------
        ValueError
            If inputs are missing, the task counts differ, or fewer
            ``target_labels`` than tasks are given.
        """
        if y_true is None or y_pred is None:
            raise ValueError("Must provide y_true and y_pred")

        # Normalise both inputs to numpy; previously only y_true was converted,
        # so a pandas/list y_pred failed on ``.ndim`` / ``.reshape``.
        if isinstance(y_true, (pd.Series, pd.DataFrame)):
            y_true = y_true.to_numpy()
        if isinstance(y_pred, (pd.Series, pd.DataFrame)):
            y_pred = y_pred.to_numpy()
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)

        # Ensure y_pred and y_true are 2D arrays for consistency
        if y_pred.ndim == 1:
            y_pred = y_pred.reshape(-1, 1)
        if y_true.ndim == 1:
            y_true = y_true.reshape(-1, 1)

        n_tasks = y_true.shape[1]
        if not (n_tasks == y_pred.shape[1]):
            raise ValueError("y_true and y_pred must have the same number of tasks")

        if target_labels is None:
            target_labels = [f"task_{i}" for i in range(n_tasks)]
        if len(target_labels) < n_tasks:
            raise ValueError(
                f"Expected {n_tasks} target labels, got {len(target_labels)}"
            )
        # Only the first n_tasks labels correspond to evaluated columns; using
        # this list for wandb logging avoids KeyErrors on surplus labels.
        task_labels = list(target_labels[:n_tasks])

        self.data = {"tag": tag}

        if use_wandb:
            self.use_wandb = use_wandb

        # Build the metric table once: the property re-imports scipy/sklearn
        # and rebuilds the dict on every access.
        metrics = self.active_metrics

        for task_id, t_label in enumerate(task_labels):
            t_true, t_pred = get_t_true_and_t_pred(task_id, y_true, y_pred, None, None)
            task_data = {}
            for metric_tag, (metric, is_scipy, _) in metrics.items():
                value, lower_ci, upper_ci = self.stat_and_bootstrap(
                    metric_tag,
                    t_pred,
                    t_true,
                    metric,
                    is_scipy_statistic=is_scipy,
                    confidence_level=self.bootstrap_confidence_level,
                )
                task_data[metric_tag] = {
                    "value": value,
                    "lower_ci": lower_ci,
                    "upper_ci": upper_ci,
                    "confidence_level": self.bootstrap_confidence_level,
                }
            self.data[t_label] = task_data

        if self.use_wandb:
            import wandb

            for t_label in task_labels:
                # make a table for the metrics
                table = wandb.Table(
                    columns=[
                        "Metric",
                        "Value",
                        "Lower CI",
                        "Upper CI",
                        "Confidence Level",
                    ]
                )
                for metric_tag in metrics:
                    metric = self.data[t_label][metric_tag]
                    table.add_data(
                        metric_tag,
                        metric["value"],
                        metric["lower_ci"],
                        metric["upper_ci"],
                        metric["confidence_level"],
                    )
                wandb.log({f"metrics_{t_label}": table})

        self._evaluated = True
        return self.data

    @property
    def metric_names(self):
        """
        Return the metric names.

        Returns
        -------
        list of str
            List of active metric keys.
        """
        return list(self.active_metrics.keys())

    @property
    def task_names(self):
        """
        Return the task names.

        Returns
        -------
        list of str
            Task labels present in ``self.data``. The ``"tag"`` bookkeeping
            entry is excluded because it is not a task.
        """
        return [key for key in self.data if key != "tag"]

    def report(self, write=False, output_dir=None):
        """
        Report the evaluation results, optionally writing to disk.

        Parameters
        ----------
        write : bool, optional
            Whether to write the report to disk.
        output_dir : str or os.PathLike, optional
            Output directory for the report.

        Returns
        -------
        dict
            Dictionary of computed metrics.
        """
        if write:
            self.write_report(output_dir)
        return self.data

    def write_report(self, output_dir):
        """
        Write the evaluation report to a JSON file and optionally log to wandb.

        Parameters
        ----------
        output_dir : str or os.PathLike
            Output directory for the report.
        """
        from pathlib import Path

        # Wrap in Path so plain string directories work as documented.
        json_path = Path(output_dir) / "regression_metrics.json"
        with open(json_path, "w") as f:
            json.dump(self.data, f, indent=2)

        # also log the json to wandb
        if self.use_wandb:
            import wandb

            artifact = wandb.Artifact(name="metrics_json", type="metric_json")
            # Add a file to the artifact
            artifact.add_file(json_path)
            # Log the artifact
            wandb.log_artifact(artifact)

    def get_stat_caption(self, t_label):
        """
        Get a formatted statistics caption for a given task.

        Parameters
        ----------
        t_label : str
            Task label.

        Returns
        -------
        str
            Caption string with statistics.

        Raises
        ------
        ValueError
            If ``evaluate`` has not been called yet.
        """
        if not self._evaluated:
            raise ValueError(
                ":( You must evaluate the model before the statistics caption can be made."
            )
        return _make_stat_caption(
            data=self.data,
            task_name=t_label,
            metric_names=self.metric_names,
            metrics=self.active_metrics,
            confidence_level=self.bootstrap_confidence_level,
            cv=False,
        )

    def get_stat_dict(self, t_label):
        """
        Get a statistics dictionary for a given task.

        Parameters
        ----------
        t_label : str
            Task label.

        Returns
        -------
        dict
            Dictionary of statistics for the task.

        Raises
        ------
        ValueError
            If ``evaluate`` has not been called yet.
        """
        if not self._evaluated:
            raise ValueError(
                "R'uh-r'oh! You must evaluate the model before the statistics dict can be made."
            )
        # Use the same metric table as metric_names (and get_stat_caption):
        # with pXC50=True, _base_metrics() lacks "pct_within_1_log".
        return _make_stat_dict(
            data=self.data,
            task_name=t_label,
            metric_names=self.metric_names,
            metrics=self.active_metrics,
            confidence_level=self.bootstrap_confidence_level,
            cv=False,
        )
@evaluators.register("RegressionPlots")
class RegressionPlots(EvalBase):
    """
    Generate and save regression plots such as regression scatter plots and confidence interval plots.

    Attributes
    ----------
    axes_labels : list of str
        Labels for the axes.
    title : str
        Title for the plot.
    do_stats : bool
        Whether to compute and display statistics on the plots. The
        confidence-interval plot is only produced when this is True.
    pXC50 : bool
        Whether to highlight pXC50 log unit ranges.
    plot_errbars : bool
        Whether to plot error bars for ensemble predictions.
    plots : dict
        Dictionary of plot functions.
    min_val : float
        Minimum value for the axes.
    max_val : float
        Maximum value for the axes.
    fit_reg : bool
        Whether to fit and draw a regression line on the scatter plot.
    use_wandb : bool
        Whether to use wandb for logging.
    dpi : int
        DPI for the plot.
    """

    axes_labels: list[str] = Field(
        ["Measured", "Predicted"], description="Labels for the axes"
    )
    title: str = Field("Pred vs ", description="Title for the plot")
    do_stats: bool = Field(True, description="Whether to do stats for the plot")
    pXC50: bool = Field(
        False,
        description="Whether to plot for pXC50, highlighting 0.5 and 1.0 log range unit",
    )
    plot_errbars: bool = Field(
        False, description="Whether to plot error bars for ensemble predictions"
    )
    plots: dict = {}
    min_val: float = Field(None, description="Minimum value for the axes")
    max_val: float = Field(None, description="Maximum value for the axes")
    fit_reg: bool = Field(False, description="Whether to fit regression line")
    use_wandb: bool = Field(False, description="Whether to use wandb")
    dpi: int = Field(300, description="DPI for the plot")

    def evaluate(
        self,
        y_true=None,
        y_pred=None,
        y_std=None,
        use_wandb=False,
        target_labels=None,
        **kwargs,
    ):
        """
        Generate regression plots and optionally compute statistics.

        Parameters
        ----------
        y_true : array-like
            True values; 1D (single task) or 2D ``(n_samples, n_tasks)``.
        y_pred : array-like
            Predicted values; same layout as ``y_true``.
        y_std : array-like
            Standard deviation of predictions if ensemble is specified. For
            multi-task input with one column per task, each task's plot
            receives its own column.
        use_wandb : bool, optional
            Whether to log plots to Weights & Biases.
        target_labels : list of str, optional
            Names for the task columns; defaults to ``task_0``, ``task_1``, ...
        kwargs : Dict
            Additional keyword arguments (ignored).

        Returns
        -------
        dict
            Dictionary of plot figures keyed ``"{task}_{plot}"``.

        Raises
        ------
        ValueError
            If inputs are missing, the task counts differ, or fewer
            ``target_labels`` than tasks are given.
        """
        if use_wandb:
            self.use_wandb = use_wandb

        if y_true is None or y_pred is None:
            raise ValueError("Must provide y_true and y_pred")

        # Normalise both inputs to numpy; previously a pandas/list y_pred
        # failed on ``.ndim`` / ``.reshape``.
        if isinstance(y_true, (pd.Series, pd.DataFrame)):
            y_true = y_true.to_numpy()
        if isinstance(y_pred, (pd.Series, pd.DataFrame)):
            y_pred = y_pred.to_numpy()
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)

        # Ensure y_pred and y_true are 2D arrays for consistency
        if y_pred.ndim == 1:
            y_pred = y_pred.reshape(-1, 1)
        if y_true.ndim == 1:
            y_true = y_true.reshape(-1, 1)

        n_tasks = y_true.shape[1]
        if not (n_tasks == y_pred.shape[1]):
            raise ValueError("y_true and y_pred must have the same number of tasks")

        if target_labels is None:
            target_labels = [f"task_{i}" for i in range(n_tasks)]
        if len(target_labels) < n_tasks:
            raise ValueError(
                f"Expected {n_tasks} target labels, got {len(target_labels)}"
            )

        y_std_arr = None if y_std is None else np.asarray(y_std)

        self.plots = {"regplot": self.regplot, "ciplot": self.ciplot}
        self.plot_data = {}

        for task_id in range(n_tasks):
            t_true, t_pred = get_t_true_and_t_pred(task_id, y_true, y_pred, None, None)
            t_label = target_labels[task_id]

            # For multi-task std arrays pick this task's column; otherwise pass
            # y_std through unchanged (single-task behaviour is unchanged).
            if (
                y_std_arr is not None
                and n_tasks > 1
                and y_std_arr.ndim == 2
                and y_std_arr.shape[1] == n_tasks
            ):
                t_std = y_std_arr[:, task_id]
            else:
                t_std = y_std

            if self.do_stats:
                rm = RegressionMetrics(n_resamples=self.n_resamples)
                rm.evaluate(
                    t_true.reshape(-1, 1),
                    t_pred.reshape(-1, 1),
                    target_labels=[t_label],
                )
                stat_dict = rm.get_stat_dict(t_label=t_label)
            else:
                stat_dict = {}

            # create the plots
            for plot_tag, plot in self.plots.items():
                if "ciplot" in plot_tag:
                    # The CI plot is drawn purely from bootstrap statistics;
                    # with do_stats=False there is nothing to plot (and the
                    # empty dict previously raised KeyError).
                    if stat_dict:
                        self.plot_data[f"{t_label}_{plot_tag}"] = plot(
                            stat_dict=stat_dict
                        )
                elif "regplot" in plot_tag:
                    self.plot_data[f"{t_label}_{plot_tag}"] = plot(
                        t_true,
                        t_pred,
                        y_pred_err=t_std,
                        xlabel=self.axes_labels[0],
                        ylabel=self.axes_labels[1],
                        title=f"{self.title}\nTask: {t_label}",
                        stat_dict=stat_dict,
                        pXC50=self.pXC50,
                        min_val=self.min_val,
                        max_val=self.max_val,
                        plot_errbars=self.plot_errbars,
                        fit_reg=self.fit_reg,
                    )

        return self.plot_data

    @staticmethod
    def regplot(
        y_true,
        y_pred,
        y_pred_err=None,
        y_true_err=None,
        data_labels=None,
        xlabel="Measured",
        ylabel="Predicted",
        title="",
        stat_dict=None,
        confidence_level=0.95,
        pXC50=False,
        min_val=None,
        max_val=None,
        fit_reg=False,
        plot_errbars=False,
    ):
        """
        Create a regression scatter plot with optional confidence intervals and statistics table.

        Parameters
        ----------
        y_true : array-like
            True values.
        y_pred : array-like
            Predicted values.
        y_pred_err : array-like, optional
            Prediction error bars.
        y_true_err: array-like, optional
            Experimental error bars.
        data_labels : list, optional
            Labels for each data point.
        xlabel : str, optional
            Label for the x-axis.
        ylabel : str, optional
            Label for the y-axis.
        title : str, optional
            Title for the plot.
        stat_dict : dict, optional
            Dictionary of statistics to display on the plot (keys ``metrics``,
            ``means``, ``lower_ci``, ``upper_ci``, ``conf_level``).
        confidence_level : float, optional
            Confidence level for the regression line; also used as the table
            header's level when ``stat_dict`` has no ``conf_level``.
        pXC50 : bool, optional
            Whether to highlight pXC50 log unit ranges.
        min_val : float, optional
            Minimum axis value. Defaults to the data minimum minus 1.
        max_val : float, optional
            Maximum axis value. Defaults to the data maximum plus 1.
        fit_reg : bool, optional
            Whether to fit and plot a regression line.
        plot_errbars : bool, optional
            Whether to plot model error bars from ensemble predictions.

        Returns
        -------
        seaborn.axisgrid.JointGrid
            The regression plot object.
        """
        import seaborn as sns

        title_font = 20
        ax_font = 18
        tick_font = 16

        x = np.ravel(y_true)
        y = np.ravel(y_pred)

        # Pad auto-computed limits by one unit; explicit limits are used as-is.
        if min_val is None:
            min_val = min(np.min(y_true), np.min(y_pred))
            min_ax = min_val - 1
        else:
            min_ax = min_val

        if max_val is None:
            max_val = max(np.max(y_true), np.max(y_pred))
            max_ax = max_val + 1
        else:
            max_ax = max_val

        g = sns.jointplot(
            x=x,
            y=y,
            kind="reg",
            joint_kws={"ci": confidence_level * 100, "fit_reg": fit_reg},
            color="teal",
            height=10,
            scatter_kws={"alpha": 0.3},
        )

        if y_pred_err is not None and plot_errbars:
            g.ax_joint.errorbar(
                x=x,
                y=y,
                yerr=np.ravel(y_pred_err),
                fmt="o",
                color="teal",
                alpha=0.3,
            )

        if y_true_err is not None and plot_errbars:
            g.ax_joint.errorbar(
                x=x,
                y=y,
                xerr=np.ravel(y_true_err),
                fmt="o",
                color="teal",
                alpha=0.3,
            )

        if data_labels is not None:
            for i, label in enumerate(data_labels):
                g.ax_joint.text(
                    x=x[i],
                    y=y[i],
                    s=label,
                    fontsize=6,
                    color="black",
                    ha="right",
                    va="bottom",
                )

        g.figure.suptitle(title, fontsize=title_font)
        # set the limits to be the same for both axes
        g.ax_joint.set_aspect("equal", "box")
        g.ax_joint.set_xlim(min_ax, max_ax)
        g.ax_joint.set_ylim(min_ax, max_ax)
        g.ax_joint.tick_params(axis="both", labelsize=tick_font)

        # plot y = x line in dashed black
        g.ax_joint.plot(
            [min_ax, max_ax], [min_ax, max_ax], linestyle="--", color="black"
        )

        # if pXC50 measure then plot the 0.5 and 1.0 log range unit
        if pXC50:
            g.ax_joint.fill_between(
                [min_ax, max_ax],
                [min_ax - 0.5, max_ax - 0.5],
                [min_ax + 0.5, max_ax + 0.5],
                color="gray",
                alpha=0.2,
            )
            g.ax_joint.fill_between(
                [min_ax, max_ax],
                [min_ax - 1, max_ax - 1],
                [min_ax + 1, max_ax + 1],
                color="gray",
                alpha=0.2,
            )

        g.ax_joint.set_xlabel(xlabel, fontsize=ax_font)
        g.ax_joint.set_ylabel(ylabel, fontsize=ax_font)

        # From the stat_dict, parse out the performance metric values and their
        # labels to put into a table to print on the regression plot
        if stat_dict:
            conf_level = stat_dict.get("conf_level")
            if conf_level is None:
                # Previously int(None * 100) raised TypeError here.
                conf_level = confidence_level
            metric_names = stat_dict.get("metrics", [])
            values = stat_dict.get("means", [])
            lower_bounds = stat_dict.get("lower_ci", [])
            upper_bounds = stat_dict.get("upper_ci", [])

            table_data = []
            # Format the metric values for readability
            for name, val, low, high in zip(
                metric_names, values, lower_bounds, upper_bounds
            ):
                if None not in (val, low, high):
                    val_str = f"{val:.2f} [{low:.2f}, {high:.2f}]"
                else:
                    val_str = "N/A"
                table_data.append([name, val_str])

            # Create the table
            table = g.ax_joint.table(
                cellText=table_data,
                colLabels=["Metric", f"Value ± {int(conf_level * 100)}% CI"],
                colWidths=[0.2, 0.3],
                loc="upper left",
                cellLoc="left",
            )

            table.scale(1, 1.8)
            for cell in table.get_celld().values():
                cell.set_fontsize(ax_font)

            # Right align the metric values
            for i in range(1, len(table_data) + 1):
                table[i, 1].get_text().set_horizontalalignment("right")

        g.ax_joint.set_box_aspect(1)
        g.figure.tight_layout()
        return g

    @staticmethod
    def ciplot(stat_dict=None):
        """
        Create a confidence interval plot for regression metrics.

        Parameters
        ----------
        stat_dict : dict
            Dictionary containing ``metrics``, ``means``, ``lower_ci``,
            ``upper_ci``, ``conf_level`` and ``task_name``.

        Returns
        -------
        matplotlib.figure.Figure
            The confidence interval plot figure.

        Raises
        ------
        ValueError
            If ``stat_dict`` is missing or empty.
        """
        if not stat_dict:
            raise ValueError("ciplot requires a non-empty stat_dict")

        metrics = stat_dict["metrics"]
        means = stat_dict["means"]
        lower_ci = stat_dict["lower_ci"]
        upper_ci = stat_dict["upper_ci"]
        conf_level = stat_dict["conf_level"]
        task_name = stat_dict["task_name"]

        title_font = 16
        tick_font = 12
        ax_font = 14

        # Default y-ranges for bounded metrics; widened below if values fall
        # outside them (R^2, tau and rho can be negative).
        y_limits = {
            "$R^2$": (0, 1),
            "Kendall's $\\tau$": (0, 1),
            "Spearman's $\\rho$": (0, 1),
        }

        n_metrics = len(metrics)

        from matplotlib import pyplot as plt

        fig, axes = plt.subplots(1, n_metrics, figsize=(8, n_metrics), sharex=False)

        if n_metrics == 1:
            axes = [axes]  # Ensure it's iterable

        for i, ax in enumerate(axes):
            if i == 0:
                ax.set_ylabel("Performance Metric Value", fontsize=ax_font)

            metric = metrics[i]
            y = means[i]
            # Bootstrap intervals need not bracket the point estimate; clamp
            # so errorbar never receives a negative error length.
            yerr = [[max(y - lower_ci[i], 0)], [max(upper_ci[i] - y, 0)]]
            ax.errorbar([metric], [y], yerr=yerr, fmt="o", capsize=8, color="green")
            ax.tick_params(axis="both", labelsize=tick_font)
            ax.yaxis.grid(True, linestyle="--", color="lightgray", alpha=0.6)
            ax.set_xlim(-0.5, 0.5)

            # Set y-limits
            if metric in ["MSE", "MAE"]:
                upper = upper_ci[i] * 1.1 if upper_ci[i] > 0 else 1
                ax.set_ylim(0, upper)
            elif metric in y_limits:
                low, high = y_limits[metric]
                finite = [v for v in (y, lower_ci[i], upper_ci[i]) if np.isfinite(v)]
                if finite:
                    low = min(low, *finite)
                    high = max(high, *finite)
                ax.set_ylim(low, high)

        fig.suptitle(
            f"Evaluation of {task_name} with {int(conf_level * 100)}% Confidence Intervals",
            fontsize=title_font,
        )
        fig.tight_layout()
        return fig

    def report(self, write=False, output_dir=None):
        """
        Report the generated plots, optionally writing to disk.

        Parameters
        ----------
        write : bool, optional
            Whether to write the plots to disk.
        output_dir : str or os.PathLike, optional
            Output directory for the plots.

        Returns
        -------
        dict
            Dictionary of plot figures.
        """
        if write:
            self.write_report(output_dir)
        return self.plot_data

    def write_report(self, output_dir):
        """
        Write the generated plots to PNG files and optionally log to wandb.

        Parameters
        ----------
        output_dir : str or os.PathLike
            Output directory for the plots.
        """
        from pathlib import Path

        # Wrap in Path so plain string directories work as documented.
        output_dir = Path(output_dir)
        for plot_tag, plot in self.plot_data.items():
            plot_path = output_dir / f"{plot_tag}.png"
            plot.savefig(plot_path, dpi=self.dpi)
            if self.use_wandb:
                import wandb

                wandb.log({plot_tag: wandb.Image(str(plot_path))})