"""Regression metrics and plots for model evaluation."""
import json
import numpy as np
import pandas as pd
from pydantic import Field
from openadmet.models.eval.eval_base import (
EvalBase,
evaluators,
get_t_true_and_t_pred,
)
from openadmet.models.eval.utils import _make_stat_caption, _make_stat_dict
def relative_absolute_error(y_true, y_pred):
    """
    Compute Relative Absolute Error (RAE).

    RAE = sum(|y_true - y_pred|) / sum(|y_true - mean(y_true)|).
    Lower is better; RAE < 1.0 means the model outperforms a naive
    mean predictor.

    Parameters
    ----------
    y_true : array-like
        True values.
    y_pred : array-like
        Predicted values.

    Returns
    -------
    float
        Relative absolute error, or ``nan`` when ``y_true`` has no spread
        around its mean (the denominator would be zero).
    """
    # Flatten both inputs so that a column vector (n, 1) paired with a flat
    # vector (n,) is compared element-wise instead of silently broadcasting
    # to an (n, n) matrix of pairwise differences.
    y_true = np.ravel(np.asarray(y_true))
    y_pred = np.ravel(np.asarray(y_pred))

    numerator = np.sum(np.abs(y_true - y_pred))
    denominator = np.sum(np.abs(y_true - np.mean(y_true)))
    if denominator == 0:
        # Constant targets: the naive-mean baseline is already perfect, so
        # the ratio is undefined.
        return np.nan
    return numerator / denominator
def pct_within_1_log_unit(y_true, y_pred):
    """
    Compute the fraction of predictions within +/-1 log unit of the true value.

    Parameters
    ----------
    y_true : array-like
        True values (assumed to be on a log scale, e.g. pXC50).
    y_pred : array-like
        Predicted values.

    Returns
    -------
    float
        Fraction (0-1) of predictions whose absolute error is at most 1.0.
    """
    # Flatten both inputs so that (n, 1) vs (n,) is compared element-wise
    # rather than broadcast into an (n, n) grid of pairwise differences.
    y_true = np.ravel(np.asarray(y_true))
    y_pred = np.ravel(np.asarray(y_pred))
    # The mean of a boolean mask is the fraction of True entries.
    return np.mean(np.abs(y_true - y_pred) <= 1.0)
@evaluators.register("RegressionMetrics")
class RegressionMetrics(EvalBase):
    """
    Compute and report regression metrics such as MSE, MAE, R2, Kendall's tau, and Spearman's rho.

    Attributes
    ----------
    bootstrap_confidence_level : float
        Confidence level for the bootstrap.
    use_wandb : bool
        Whether to use wandb for logging.
    pXC50 : bool
        Whether targets are in pXC50/log units; enables the
        "fraction within 1 log unit" metric.
    _evaluated : bool
        Whether the model has been evaluated.
    """

    bootstrap_confidence_level: float = Field(
        0.95, description="Confidence level for the bootstrap"
    )
    use_wandb: bool = Field(False, description="Whether to use wandb")
    pXC50: bool = Field(
        False,
        description="Whether targets are in pXC50/log units for log-based metrics",
    )
    _evaluated: bool = False

    @classmethod
    def _base_metrics(cls) -> dict:
        """
        Build the base metrics dictionary with deferred 3rd-party imports.

        Returns
        -------
        dict
            Mapping of metric key to ``(callable, is_scipy_statistic, display_label)``
            tuples for MSE, MAE, R², Kendall's τ, Spearman's ρ, and RAE.
        """
        from functools import partial

        from scipy.stats import kendalltau, spearmanr
        from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

        # Rank correlations are configured to drop NaN pairs.
        nan_omit_ktau = partial(kendalltau, nan_policy="omit")
        nan_omit_spearmanr = partial(spearmanr, nan_policy="omit")
        return {
            "mse": (mean_squared_error, False, "MSE"),
            "mae": (mean_absolute_error, False, "MAE"),
            "r2": (r2_score, False, "$R^2$"),
            "ktau": (nan_omit_ktau, True, "Kendall's $\\tau$"),
            "spearmanr": (nan_omit_spearmanr, True, "Spearman's $\\rho$"),
            "rae": (relative_absolute_error, False, "RAE"),
        }

    @property
    def active_metrics(self):
        """
        Return metrics applicable to the current target scale.

        Returns
        -------
        dict
            The base metrics, plus ``"pct_within_1_log"`` when ``pXC50`` is set.
        """
        metrics = self._base_metrics()
        if self.pXC50:
            metrics["pct_within_1_log"] = (
                pct_within_1_log_unit,
                False,
                "Fraction within ±1 log",
            )
        return metrics

    def evaluate(
        self,
        y_true=None,
        y_pred=None,
        use_wandb=False,
        tag=None,
        target_labels=None,
        **kwargs,
    ):
        """
        Evaluate the regression model and compute metrics.

        Parameters
        ----------
        y_true : array-like
            True values, 1D or ``(n_samples, n_tasks)``.
        y_pred : array-like
            Predicted values, 1D or ``(n_samples, n_tasks)``.
        use_wandb : bool, optional
            Whether to log metrics to Weights & Biases.
        tag : str, optional
            Tag for the evaluation run.
        target_labels : list of str, optional
            List of target names; defaults to ``task_0``, ``task_1``, ...
        kwargs : Dict
            Additional keyword arguments (unused here).

        Returns
        -------
        dict
            Dictionary of computed metrics and confidence intervals, keyed by
            task label, plus a ``"tag"`` entry.

        Raises
        ------
        ValueError
            If ``y_true`` or ``y_pred`` is missing, if their task counts
            differ, or if fewer ``target_labels`` than tasks are given.
        """
        if y_true is None or y_pred is None:
            raise ValueError("Must provide y_true and y_pred")

        # Accept pandas objects and plain sequences for BOTH inputs; the
        # shape handling below needs ndarray semantics (.ndim / .reshape).
        if isinstance(y_true, (pd.Series, pd.DataFrame)):
            y_true = y_true.to_numpy()
        if isinstance(y_pred, (pd.Series, pd.DataFrame)):
            y_pred = y_pred.to_numpy()
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)

        # Ensure y_pred and y_true are 2D arrays for consistency
        if y_pred.ndim == 1:
            y_pred = y_pred.reshape(-1, 1)
        if y_true.ndim == 1:
            y_true = y_true.reshape(-1, 1)

        n_tasks = y_true.shape[1]
        if n_tasks != y_pred.shape[1]:
            raise ValueError("y_true and y_pred must have the same number of tasks")

        if target_labels is None:
            target_labels = [f"task_{i}" for i in range(n_tasks)]
        elif len(target_labels) < n_tasks:
            raise ValueError(
                f"target_labels has {len(target_labels)} entries but there are "
                f"{n_tasks} tasks"
            )

        self.data = {"tag": tag}
        if use_wandb:
            self.use_wandb = use_wandb

        # Build the metric table once; the property re-imports and rebuilds
        # the dictionary on every access.
        metrics = self.active_metrics

        for task_id in range(n_tasks):
            # Helper from eval_base that extracts the per-task true/pred pair.
            t_true, t_pred = get_t_true_and_t_pred(task_id, y_true, y_pred, None, None)
            t_label = target_labels[task_id]
            task_results = {}
            for metric_tag, (metric, is_scipy, _) in metrics.items():
                value, lower_ci, upper_ci = self.stat_and_bootstrap(
                    metric_tag,
                    t_pred,
                    t_true,
                    metric,
                    is_scipy_statistic=is_scipy,
                    confidence_level=self.bootstrap_confidence_level,
                )
                task_results[metric_tag] = {
                    "value": value,
                    "lower_ci": lower_ci,
                    "upper_ci": upper_ci,
                    "confidence_level": self.bootstrap_confidence_level,
                }
            self.data[t_label] = task_results

        if self.use_wandb:
            import wandb

            # Only the first n_tasks labels have results; surplus labels are
            # ignored rather than triggering a KeyError.
            for t_label in target_labels[:n_tasks]:
                # make a table for the metrics
                table = wandb.Table(
                    columns=[
                        "Metric",
                        "Value",
                        "Lower CI",
                        "Upper CI",
                        "Confidence Level",
                    ]
                )
                for metric_tag in metrics:
                    entry = self.data[t_label][metric_tag]
                    table.add_data(
                        metric_tag,
                        entry["value"],
                        entry["lower_ci"],
                        entry["upper_ci"],
                        entry["confidence_level"],
                    )
                wandb.log({f"metrics_{t_label}": table})

        self._evaluated = True
        return self.data

    @property
    def metric_names(self):
        """
        Return the metric names.

        Returns
        -------
        list of str
            List of metric names.
        """
        return list(self.active_metrics.keys())

    @property
    def task_names(self):
        """
        Return the task names.

        Returns
        -------
        list of str
            List of task names (the ``"tag"`` bookkeeping entry that
            ``evaluate`` stores alongside the tasks is excluded).
        """
        return [name for name in self.data if name != "tag"]

    def report(self, write=False, output_dir=None):
        """
        Report the evaluation results, optionally writing to disk.

        Parameters
        ----------
        write : bool, optional
            Whether to write the report to disk.
        output_dir : str or path-like, optional
            Output directory for the report.

        Returns
        -------
        dict
            Dictionary of computed metrics.
        """
        if write:
            self.write_report(output_dir)
        return self.data

    def write_report(self, output_dir):
        """
        Write the evaluation report to a JSON file and optionally log to wandb.

        Parameters
        ----------
        output_dir : str or path-like
            Output directory for the report.
        """
        from pathlib import Path

        # Coerce so that a plain string works as well as a Path object.
        json_path = Path(output_dir) / "regression_metrics.json"
        with open(json_path, "w") as f:
            json.dump(self.data, f, indent=2)

        # also log the json to wandb
        if self.use_wandb:
            import wandb

            artifact = wandb.Artifact(name="metrics_json", type="metric_json")
            # Add a file to the artifact
            artifact.add_file(str(json_path))
            # Log the artifact
            wandb.log_artifact(artifact)

    def get_stat_caption(self, t_label):
        """
        Get a formatted statistics caption for a given task.

        Parameters
        ----------
        t_label : str
            Task label.

        Returns
        -------
        str
            Caption string with statistics.

        Raises
        ------
        ValueError
            If ``evaluate`` has not been called yet.
        """
        if not self._evaluated:
            raise ValueError(
                ":( You must evaluate the model before the statistics caption can be made."
            )
        return _make_stat_caption(
            data=self.data,
            task_name=t_label,
            metric_names=self.metric_names,
            metrics=self.active_metrics,
            confidence_level=self.bootstrap_confidence_level,
            cv=False,
        )

    def get_stat_dict(self, t_label):
        """
        Get a statistics dictionary for a given task.

        Parameters
        ----------
        t_label : str
            Task label.

        Returns
        -------
        dict
            Dictionary of statistics for the task.

        Raises
        ------
        ValueError
            If ``evaluate`` has not been called yet.
        """
        if not self._evaluated:
            raise ValueError(
                "R'uh-r'oh! You must evaluate the model before the statistics dict can be made."
            )
        return _make_stat_dict(
            data=self.data,
            task_name=t_label,
            metric_names=self.metric_names,
            # Must match metric_names (derived from active_metrics); the base
            # table lacks "pct_within_1_log" when pXC50 is enabled.
            metrics=self.active_metrics,
            confidence_level=self.bootstrap_confidence_level,
            cv=False,
        )
@evaluators.register("RegressionPlots")
class RegressionPlots(EvalBase):
    """
    Generate and save regression plots such as regression scatter plots and confidence interval plots.

    Attributes
    ----------
    axes_labels : list of str
        Labels for the axes.
    title : str
        Title for the plot.
    do_stats : bool
        Whether to compute and display statistics on the plots.
    pXC50 : bool
        Whether to highlight pXC50 log unit ranges.
    plot_errbars : bool
        Whether to plot error bars for ensemble predictions.
    plots : dict
        Dictionary of plot functions.
    min_val : float
        Minimum value for the axes.
    max_val : float
        Maximum value for the axes.
    fit_reg : bool
        Whether to fit a regression line.
    use_wandb : bool
        Whether to use wandb for logging.
    dpi : int
        DPI for the plot.
    """

    axes_labels: list[str] = Field(
        ["Measured", "Predicted"], description="Labels for the axes"
    )
    title: str = Field("Pred vs ", description="Title for the plot")
    do_stats: bool = Field(True, description="Whether to do stats for the plot")
    pXC50: bool = Field(
        False,
        description="Whether to plot for pXC50, highlighting 0.5 and 1.0 log range unit",
    )
    plot_errbars: bool = Field(
        False, description="Whether to plot error bars for ensemble predictions"
    )
    plots: dict = {}
    # NOTE(review): these default to None although annotated as float —
    # confirm the model validator accepts an explicitly passed None.
    min_val: float = Field(None, description="Minimum value for the axes")
    max_val: float = Field(None, description="Maximum value for the axes")
    fit_reg: bool = Field(False, description="Whether to fit regression line")
    use_wandb: bool = Field(False, description="Whether to use wandb")
    dpi: int = Field(300, description="DPI for the plot")

    def evaluate(
        self,
        y_true=None,
        y_pred=None,
        y_std=None,
        use_wandb=False,
        target_labels=None,
        **kwargs,
    ):
        """
        Generate regression plots and optionally compute statistics.

        Parameters
        ----------
        y_true : array-like
            True values, 1D or ``(n_samples, n_tasks)``.
        y_pred : array-like
            Predicted values, 1D or ``(n_samples, n_tasks)``.
        y_std : array-like
            Standard deviation of predictions if ensemble is specified.
        use_wandb : bool, optional
            Whether to log plots to Weights & Biases.
        target_labels : list of str, optional
            List of target names; defaults to ``task_0``, ``task_1``, ...
        kwargs : Dict
            Additional keyword arguments (unused here).

        Returns
        -------
        dict
            Dictionary of plot figures keyed by ``"<task label>_<plot tag>"``.

        Raises
        ------
        ValueError
            If ``y_true`` or ``y_pred`` is missing, if their task counts
            differ, or if fewer ``target_labels`` than tasks are given.
        """
        if use_wandb:
            self.use_wandb = use_wandb

        if y_true is None or y_pred is None:
            raise ValueError("Must provide y_true and y_pred")

        # Accept pandas objects and plain sequences for BOTH inputs; the
        # shape handling below needs ndarray semantics (.ndim / .reshape).
        if isinstance(y_true, (pd.Series, pd.DataFrame)):
            y_true = y_true.to_numpy()
        if isinstance(y_pred, (pd.Series, pd.DataFrame)):
            y_pred = y_pred.to_numpy()
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)

        # Ensure y_pred and y_true are 2D arrays for consistency
        if y_pred.ndim == 1:
            y_pred = y_pred.reshape(-1, 1)
        if y_true.ndim == 1:
            y_true = y_true.reshape(-1, 1)

        n_tasks = y_true.shape[1]
        if n_tasks != y_pred.shape[1]:
            raise ValueError("y_true and y_pred must have the same number of tasks")

        if target_labels is None:
            target_labels = [f"task_{i}" for i in range(n_tasks)]
        elif len(target_labels) < n_tasks:
            raise ValueError(
                f"target_labels has {len(target_labels)} entries but there are "
                f"{n_tasks} tasks"
            )

        std_arr = None if y_std is None else np.asarray(y_std)

        self.plots = {"regplot": self.regplot, "ciplot": self.ciplot}
        self.plot_data = {}

        for task_id in range(n_tasks):
            # Helper from eval_base that extracts the per-task true/pred pair.
            t_true, t_pred = get_t_true_and_t_pred(task_id, y_true, y_pred, None, None)
            t_label = target_labels[task_id]

            # Use only this task's column of the prediction errors; passing
            # the full multi-task array would not line up with t_pred.
            # NOTE(review): if get_t_true_and_t_pred drops rows (e.g. missing
            # targets), the same rows would need dropping here — confirm.
            if std_arr is not None and std_arr.ndim == 2 and std_arr.shape[1] == n_tasks:
                t_std = std_arr[:, task_id]
            else:
                t_std = std_arr

            if self.do_stats:
                rm = RegressionMetrics(n_resamples=self.n_resamples)
                rm.evaluate(
                    t_true.reshape(-1, 1),
                    t_pred.reshape(-1, 1),
                    target_labels=[t_label],
                )
                stat_dict = rm.get_stat_dict(t_label=t_label)
            else:
                stat_dict = {}

            # create the plots
            for plot_tag, plot in self.plots.items():
                plot_key = f"{t_label}_{plot_tag}"
                if "ciplot" in plot_tag:
                    # The CI plot is built entirely from statistics; without
                    # them there is nothing to draw.
                    if not stat_dict:
                        continue
                    self.plot_data[plot_key] = plot(stat_dict=stat_dict)
                elif "regplot" in plot_tag:
                    self.plot_data[plot_key] = plot(
                        t_true,
                        t_pred,
                        y_pred_err=t_std,
                        xlabel=self.axes_labels[0],
                        ylabel=self.axes_labels[1],
                        title=f"{self.title}\nTask: {t_label}",
                        stat_dict=stat_dict,
                        pXC50=self.pXC50,
                        min_val=self.min_val,
                        max_val=self.max_val,
                        plot_errbars=self.plot_errbars,
                        fit_reg=self.fit_reg,
                    )
        return self.plot_data

    @staticmethod
    def regplot(
        y_true,
        y_pred,
        y_pred_err=None,
        y_true_err=None,
        data_labels=None,
        xlabel="Measured",
        ylabel="Predicted",
        title="",
        stat_dict=None,
        confidence_level=0.95,
        pXC50=False,
        min_val=None,
        max_val=None,
        fit_reg=False,
        plot_errbars=False,
    ):
        """
        Create a regression scatter plot with optional confidence intervals and statistics table.

        Parameters
        ----------
        y_true : array-like
            True values.
        y_pred : array-like
            Predicted values.
        y_pred_err : array-like, optional
            Prediction error bars.
        y_true_err: array-like, optional
            Experimental error bars.
        data_labels : list, optional
            Labels for each data point.
        xlabel : str, optional
            Label for the x-axis.
        ylabel : str, optional
            Label for the y-axis.
        title : str, optional
            Title for the plot.
        stat_dict : dict, optional
            Dictionary of statistics to display on the plot. Read keys:
            ``"metrics"``, ``"means"``, ``"lower_ci"``, ``"upper_ci"``,
            ``"conf_level"``.
        confidence_level : float, optional
            Confidence level for the regression line.
        pXC50 : bool, optional
            Whether to highlight pXC50 log unit ranges.
        min_val : float, optional
            Minimum axis value. When omitted, the data minimum minus 1 is used.
        max_val : float, optional
            Maximum axis value. When omitted, the data maximum plus 1 is used.
        fit_reg : bool, optional
            Whether to fit and plot a regression line.
        plot_errbars : bool, optional
            Whether to plot model error bars from ensemble predictions.

        Returns
        -------
        seaborn.axisgrid.JointGrid
            The regression plot object.
        """
        stat_dict = stat_dict or {}

        title_font = 20
        ax_font = 18
        tick_font = 16

        # Derived limits get a one-unit margin; explicit limits are used as-is.
        if min_val is None:
            min_val = min(np.min(y_true), np.min(y_pred))
            min_ax = min_val - 1
        else:
            min_ax = min_val
        if max_val is None:
            max_val = max(np.max(y_true), np.max(y_pred))
            max_ax = max_val + 1
        else:
            max_ax = max_val

        import seaborn as sns

        x_vals = np.ravel(y_true)
        y_vals = np.ravel(y_pred)

        g = sns.jointplot(
            x=x_vals,
            y=y_vals,
            kind="reg",
            joint_kws={"ci": confidence_level * 100, "fit_reg": fit_reg},
            color="teal",
            height=10,
            scatter_kws={"alpha": 0.3},
        )

        if y_pred_err is not None and plot_errbars:
            g.ax_joint.errorbar(
                x=x_vals,
                y=y_vals,
                yerr=np.ravel(y_pred_err),
                fmt="o",
                color="teal",
                alpha=0.3,
            )
        if y_true_err is not None and plot_errbars:
            g.ax_joint.errorbar(
                x=x_vals,
                y=y_vals,
                xerr=np.ravel(y_true_err),
                fmt="o",
                color="teal",
                alpha=0.3,
            )

        if data_labels is not None:
            for i, label in enumerate(data_labels):
                g.ax_joint.text(
                    x=x_vals[i],
                    y=y_vals[i],
                    s=label,
                    fontsize=6,
                    color="black",
                    ha="right",
                    va="bottom",
                )

        g.figure.suptitle(title, fontsize=title_font)
        # set the limits to be the same for both axes
        g.ax_joint.set_aspect("equal", "box")
        g.ax_joint.set_xlim(min_ax, max_ax)
        g.ax_joint.set_ylim(min_ax, max_ax)
        g.ax_joint.tick_params(axis="both", labelsize=tick_font)

        # plot y = x line in dashed black
        g.ax_joint.plot(
            [min_ax, max_ax], [min_ax, max_ax], linestyle="--", color="black"
        )

        # if pXC50 measure then plot the 0.5 and 1.0 log range unit
        if pXC50:
            for half_width in (0.5, 1):
                g.ax_joint.fill_between(
                    [min_ax, max_ax],
                    [min_ax - half_width, max_ax - half_width],
                    [min_ax + half_width, max_ax + half_width],
                    color="gray",
                    alpha=0.2,
                )

        g.ax_joint.set_xlabel(xlabel, fontsize=ax_font)
        g.ax_joint.set_ylabel(ylabel, fontsize=ax_font)

        # From the stat_dict, parse out the performance metric values and
        # their labels to put into a table printed on the regression plot.
        if stat_dict:
            conf_level = stat_dict.get("conf_level", None)
            metric_names = stat_dict.get("metrics", [])
            values = stat_dict.get("means", [])
            lower_bounds = stat_dict.get("lower_ci", [])
            upper_bounds = stat_dict.get("upper_ci", [])

            # Format the metric values for readability
            table_data = []
            for name, val, low, high in zip(
                metric_names, values, lower_bounds, upper_bounds
            ):
                if val is None or low is None or high is None:
                    val_str = "N/A"
                else:
                    val_str = f"{val:.2f} [{low:.2f}, {high:.2f}]"
                table_data.append([name, val_str])

            # matplotlib cannot build a table from zero rows, so skip it.
            if table_data:
                # Omit the percentage when the confidence level is unknown
                # instead of failing on None * 100.
                if conf_level is None:
                    value_header = "Value ± CI"
                else:
                    value_header = f"Value ± {int(round(conf_level * 100))}% CI"

                table = g.ax_joint.table(
                    cellText=table_data,
                    colLabels=["Metric", value_header],
                    colWidths=[0.2, 0.3],
                    loc="upper left",
                    cellLoc="left",
                )
                table.scale(1, 1.8)
                for cell in table.get_celld().values():
                    cell.set_fontsize(ax_font)
                # Right align the metric values (row 0 is the header)
                for i in range(1, len(table_data) + 1):
                    table[i, 1].get_text().set_horizontalalignment("right")

        g.ax_joint.set_box_aspect(1)
        g.figure.tight_layout()
        return g

    @staticmethod
    def ciplot(stat_dict=None):
        """
        Create a confidence interval plot for regression metrics.

        Parameters
        ----------
        stat_dict : dict
            Dictionary containing ``"metrics"``, ``"means"``, ``"lower_ci"``,
            ``"upper_ci"``, ``"conf_level"`` and ``"task_name"``.

        Returns
        -------
        matplotlib.figure.Figure
            The confidence interval plot figure.

        Raises
        ------
        KeyError
            If any required key is missing from ``stat_dict``.
        """
        stat_dict = stat_dict or {}
        required = ("metrics", "means", "lower_ci", "upper_ci", "conf_level", "task_name")
        missing = [key for key in required if key not in stat_dict]
        if missing:
            raise KeyError(f"stat_dict is missing required keys: {missing}")

        metrics = stat_dict["metrics"]
        means = stat_dict["means"]
        lower_ci = stat_dict["lower_ci"]
        upper_ci = stat_dict["upper_ci"]
        conf_level = stat_dict["conf_level"]
        task_name = stat_dict["task_name"]

        title_font = 16
        tick_font = 12
        ax_font = 14

        # Fixed y-ranges for the metrics listed here.
        y_limits = {
            "$R^2$": (0, 1),
            "Kendall's $\\tau$": (0, 1),
            "Spearman's $\\rho$": (0, 1),
        }

        n_metrics = len(metrics)

        from matplotlib import pyplot as plt

        fig, axes = plt.subplots(1, n_metrics, figsize=(8, n_metrics), sharex=False)
        if n_metrics == 1:
            axes = [axes]  # Ensure it's iterable

        for i, ax in enumerate(axes):
            if i == 0:
                ax.set_ylabel("Performance Metric Value", fontsize=ax_font)

            metric = metrics[i]
            y = means[i]
            # matplotlib rejects negative error-bar lengths; clamp in case a
            # CI bound lies on the wrong side of the point estimate.
            yerr = [[max(y - lower_ci[i], 0)], [max(upper_ci[i] - y, 0)]]
            ax.errorbar([metric], [y], yerr=yerr, fmt="o", capsize=8, color="green")
            ax.tick_params(axis="both", labelsize=tick_font)
            ax.yaxis.grid(True, linestyle="--", color="lightgray", alpha=0.6)
            ax.set_xlim(-0.5, 0.5)

            # Set fixed y-limits
            if metric in ["MSE", "MAE"]:
                upper = upper_ci[i] * 1.1 if upper_ci[i] > 0 else 1
                ax.set_ylim(0, upper)
            elif metric in y_limits:
                ax.set_ylim(y_limits[metric])

        fig.suptitle(
            f"Evaluation of {task_name} with {int(round(conf_level * 100))}% Confidence Intervals",
            fontsize=title_font,
        )
        fig.tight_layout()
        return fig

    def report(self, write=False, output_dir=None):
        """
        Report the generated plots, optionally writing to disk.

        Parameters
        ----------
        write : bool, optional
            Whether to write the plots to disk.
        output_dir : str or path-like, optional
            Output directory for the plots.

        Returns
        -------
        dict
            Dictionary of plot figures.
        """
        if write:
            self.write_report(output_dir)
        return self.plot_data

    def write_report(self, output_dir):
        """
        Write the generated plots to PNG files and optionally log to wandb.

        Parameters
        ----------
        output_dir : str or path-like
            Output directory for the plots.
        """
        from pathlib import Path

        # Coerce so that a plain string works as well as a Path object.
        output_dir = Path(output_dir)
        for plot_tag, plot in self.plot_data.items():
            plot_path = output_dir / f"{plot_tag}.png"
            plot.savefig(plot_path, dpi=self.dpi)
            if self.use_wandb:
                import wandb

                wandb.log({plot_tag: wandb.Image(str(plot_path))})