"""Regression metrics and plots for model evaluation."""
import json
from functools import partial
import numpy as np
import pandas as pd
import seaborn as sns
import wandb
from matplotlib import pyplot as plt
from pydantic import Field
from scipy.stats import kendalltau, spearmanr
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from openadmet.models.eval.eval_base import (
EvalBase,
evaluators,
get_t_true_and_t_pred,
)
from openadmet.models.eval.utils import _make_stat_caption, _make_stat_dict
# create partial functions for the scipy stats
nan_omit_ktau = partial(kendalltau, nan_policy="omit")
nan_omit_spearmanr = partial(spearmanr, nan_policy="omit")
[docs]def relative_absolute_error(y_true, y_pred):
"""
Compute Relative Absolute Error (RAE).
RAE = sum(|y_true - y_pred|) / sum(|y_true - mean(y_true)|).
Lower is better; RAE < 1.0 means the model outperforms a naive
mean predictor.
Parameters
----------
y_true : array-like
True values.
y_pred : array-like
Predicted values.
Returns
-------
float
Relative absolute error.
"""
y_true = np.asarray(y_true)
y_pred = np.asarray(y_pred)
numerator = np.sum(np.abs(y_true - y_pred))
denominator = np.sum(np.abs(y_true - np.mean(y_true)))
if denominator == 0:
return np.nan
return numerator / denominator
[docs]def pct_within_1_log_unit(y_true, y_pred):
"""
Compute the fraction of predictions within +/-1 log unit of the true value.
Parameters
----------
y_true : array-like
True values (assumed to be on a log scale, e.g. pXC50).
y_pred : array-like
Predicted values.
Returns
-------
float
Fraction (0-1) of predictions within 1 log unit.
"""
y_true = np.asarray(y_true)
y_pred = np.asarray(y_pred)
return np.mean(np.abs(y_true - y_pred) <= 1.0)
[docs]@evaluators.register("RegressionMetrics")
class RegressionMetrics(EvalBase):
"""
Compute and report regression metrics such as MSE, MAE, R2, Kendall's tau, and Spearman's rho.
Attributes
----------
bootstrap_confidence_level : float
Confidence level for the bootstrap.
use_wandb : bool
Whether to use wandb for logging.
_evaluated : bool
Whether the model has been evaluated.
_metrics : dict
Dictionary of metrics to compute.
"""
bootstrap_confidence_level: float = Field(
0.95, description="Confidence level for the bootstrap"
)
use_wandb: bool = Field(False, description="Whether to use wandb")
pXC50: bool = Field(
False,
description="Whether targets are in pXC50/log units for log-based metrics",
)
_evaluated: bool = False
_metrics: dict = {
"mse": (mean_squared_error, False, "MSE"),
"mae": (mean_absolute_error, False, "MAE"),
"r2": (r2_score, False, "$R^2$"),
"ktau": (nan_omit_ktau, True, "Kendall's $\\tau$"),
"spearmanr": (nan_omit_spearmanr, True, "Spearman's $\\rho$"),
"rae": (relative_absolute_error, False, "RAE"),
}
@property
def active_metrics(self):
"""Return metrics applicable to the current target scale."""
metrics = dict(self._metrics)
if self.pXC50:
metrics["pct_within_1_log"] = (
pct_within_1_log_unit,
False,
"Fraction within ±1 log",
)
return metrics
[docs] def evaluate(
self,
y_true=None,
y_pred=None,
use_wandb=False,
tag=None,
target_labels=None,
**kwargs,
):
"""
Evaluate the regression model and compute metrics.
Parameters
----------
y_true : array-like
True values.
y_pred : array-like
Predicted values.
use_wandb : bool, optional
Whether to log metrics to Weights & Biases.
tag : str, optional
Tag for the evaluation run.
target_labels : list of str, optional
List of target names.
kwargs : Dict
Additional keyword arguments.
Returns
-------
dict
Dictionary of computed metrics and confidence intervals.
"""
if y_true is None or y_pred is None:
raise ValueError("Must provide y_true and y_pred")
if isinstance(y_true, (pd.Series, pd.DataFrame)):
y_true = y_true.to_numpy()
# Ensure y_pred and y_true are 2D arrays for consistency
if y_pred.ndim == 1:
y_pred = y_pred.reshape(-1, 1)
if y_true.ndim == 1:
y_true = y_true.reshape(-1, 1)
n_tasks = y_true.shape[1]
if not (n_tasks == y_pred.shape[1]):
raise ValueError("y_true and y_pred must have the same number of tasks")
if target_labels is None:
target_labels = [f"task_{i}" for i in range(n_tasks)]
self.data = {"tag": tag}
if use_wandb:
self.use_wandb = use_wandb
for task_id in range(n_tasks):
t_true, t_pred = get_t_true_and_t_pred(task_id, y_true, y_pred, None, None)
t_label = target_labels[task_id]
self.data[t_label] = {}
for metric_tag, (metric, is_scipy, _) in self.active_metrics.items():
value, lower_ci, upper_ci = self.stat_and_bootstrap(
metric_tag,
t_pred,
t_true,
metric,
is_scipy_statistic=is_scipy,
confidence_level=self.bootstrap_confidence_level,
)
self.data[t_label][metric_tag] = {
"value": value,
"lower_ci": lower_ci,
"upper_ci": upper_ci,
"confidence_level": self.bootstrap_confidence_level,
}
if self.use_wandb:
for t_label in target_labels:
# make a table for the metrics
table = wandb.Table(
columns=[
"Metric",
"Value",
"Lower CI",
"Upper CI",
"Confidence Level",
]
)
for metric_tag in self.metric_names:
metric = self.data[t_label][metric_tag]
table.add_data(
metric_tag,
metric["value"],
metric["lower_ci"],
metric["upper_ci"],
metric["confidence_level"],
)
wandb.log({f"metrics_{t_label}": table})
self._evaluated = True
return self.data
@property
def metric_names(self):
"""
Return the metric names.
Returns
-------
list of str
List of metric names.
"""
return list(self.active_metrics.keys())
@property
def task_names(self):
"""
Return the task names.
Returns
-------
list of str
List of task names.
"""
return list(self.data.keys())
[docs] def report(self, write=False, output_dir=None):
"""
Report the evaluation results, optionally writing to disk.
Parameters
----------
write : bool, optional
Whether to write the report to disk.
output_dir : str, optional
Output directory for the report.
Returns
-------
dict
Dictionary of computed metrics.
"""
if write:
self.write_report(output_dir)
return self.data
[docs] def write_report(self, output_dir):
"""
Write the evaluation report to a JSON file and optionally log to wandb.
Parameters
----------
output_dir : str
Output directory for the report.
"""
# write to JSON
json_path = output_dir / "regression_metrics.json"
with open(json_path, "w") as f:
json.dump(self.data, f, indent=2)
# also log the json to wandb
if self.use_wandb:
artifact = wandb.Artifact(name="metrics_json", type="metric_json")
# Add a file to the artifact
artifact.add_file(json_path)
# Log the artifact
wandb.log_artifact(artifact)
[docs] def get_stat_caption(self, t_label):
"""
Get a formatted statistics caption for a given task.
Parameters
----------
t_label : str
Task label.
Returns
-------
str
Caption string with statistics.
"""
if not self._evaluated:
raise ValueError(
":( You must evaluate the model before the statistics caption can be made."
)
return _make_stat_caption(
data=self.data,
task_name=t_label,
metric_names=self.metric_names,
metrics=self.active_metrics,
confidence_level=self.bootstrap_confidence_level,
cv=False,
)
[docs] def get_stat_dict(self, t_label):
"""
Get a statistics dictionary for a given task.
Parameters
----------
t_label : str
Task label.
Returns
-------
dict
Dictionary of statistics for the task.
"""
if not self._evaluated:
raise ValueError(
"R'uh-r'oh! You must evaluate the model before the statistics dict can be made."
)
return _make_stat_dict(
data=self.data,
task_name=t_label,
metric_names=self.metric_names,
metrics=self._metrics,
confidence_level=self.bootstrap_confidence_level,
cv=False,
)
[docs]@evaluators.register("RegressionPlots")
class RegressionPlots(EvalBase):
"""
Generate and save regression plots such as regression scatter plots and confidence interval plots.
Attributes
----------
axes_labels : list of str
Labels for the axes.
title : str
Title for the plot.
do_stats : bool
Whether to compute and display statistics on the plots.
pXC50 : bool
Whether to highlight pXC50 log unit ranges.
plot_errbars : bool
Whether to plot error bars for ensemble predictions.
plots : dict
Dictionary of plot functions.
min_val : float
Minimum value for the axes.
max_val : float
Maximum value for the axes.
use_wandb : bool
Whether to use wandb for logging.
dpi : int
DPI for the plot.
"""
axes_labels: list[str] = Field(
["Measured", "Predicted"], description="Labels for the axes"
)
title: str = Field("Pred vs ", description="Title for the plot")
do_stats: bool = Field(True, description="Whether to do stats for the plot")
pXC50: bool = Field(
False,
description="Whether to plot for pXC50, highlighting 0.5 and 1.0 log range unit",
)
plot_errbars: bool = Field(
False, description="Whether to plot error bars for ensemble predictions"
)
plots: dict = {}
min_val: float = Field(None, description="Minimum value for the axes")
max_val: float = Field(None, description="Maximum value for the axes")
fit_reg: bool = Field(False, description="Whether to fit regression line")
use_wandb: bool = Field(False, description="Whether to use wandb")
dpi: int = Field(300, description="DPI for the plot")
[docs] def evaluate(
self,
y_true=None,
y_pred=None,
y_std=None,
use_wandb=False,
target_labels=None,
**kwargs,
):
"""
Generate regression plots and optionally compute statistics.
Parameters
----------
y_true : array-like
True values.
y_pred : array-like
Predicted values.
y_std : array-like
Standard deviation of predictions if ensemble is specified.
use_wandb : bool, optional
Whether to log plots to Weights & Biases.
target_labels : list of str, optional
List of target names.
kwargs : Dict
Additional keyword arguments.
Returns
-------
dict
Dictionary of plot figures.
"""
if use_wandb:
self.use_wandb = use_wandb
if y_true is None or y_pred is None:
raise ValueError("Must provide y_true and y_pred")
if isinstance(y_true, (pd.Series, pd.DataFrame)):
y_true = y_true.to_numpy()
# Ensure y_pred and y_true are 2D arrays for consistency
if y_pred.ndim == 1:
y_pred = y_pred.reshape(-1, 1)
if y_true.ndim == 1:
y_true = y_true.reshape(-1, 1)
n_tasks = y_true.shape[1]
if not (n_tasks == y_pred.shape[1]):
raise ValueError("y_true and y_pred must have the same number of tasks")
if target_labels is None:
target_labels = [f"task_{i}" for i in range(n_tasks)]
self.plots = {"regplot": self.regplot, "ciplot": self.ciplot}
self.plot_data = {}
for task_id in range(n_tasks):
t_true, t_pred = get_t_true_and_t_pred(task_id, y_true, y_pred, None, None)
t_label = target_labels[task_id]
if self.do_stats:
rm = RegressionMetrics()
rm.evaluate(
t_true.reshape(-1, 1),
t_pred.reshape(-1, 1),
target_labels=[t_label],
)
stat_dict = rm.get_stat_dict(t_label=t_label)
else:
stat_dict = {}
# create the plots
for plot_tag, plot in self.plots.items():
if "ciplot" in plot_tag:
self.plot_data[f"{t_label}_{plot_tag}"] = plot(stat_dict=stat_dict)
elif "regplot" in plot_tag:
self.plot_data[f"{t_label}_{plot_tag}"] = plot(
t_true,
t_pred,
y_pred_err=y_std,
xlabel=self.axes_labels[0],
ylabel=self.axes_labels[1],
title=f"{self.title}\nTask: {t_label}",
stat_dict=stat_dict,
pXC50=self.pXC50,
min_val=self.min_val,
max_val=self.max_val,
plot_errbars=self.plot_errbars,
fit_reg=self.fit_reg,
)
return self.plot_data
[docs] @staticmethod
def regplot(
y_true,
y_pred,
y_pred_err=None,
y_true_err=None,
data_labels=None,
xlabel="Measured",
ylabel="Predicted",
title="",
stat_dict={},
confidence_level=0.95,
pXC50=False,
min_val=None,
max_val=None,
fit_reg=False,
plot_errbars=False,
):
"""
Create a regression scatter plot with optional confidence intervals and statistics table.
Parameters
----------
y_true : array-like
True values.
y_pred : array-like
Predicted values.
y_pred_err : array-like, optional
Prediction error bars.
y_true_err: array-like, optional
Experimental error bars.
data_labels : list, optional
Labels for each data point.
xlabel : str, optional
Label for the x-axis.
ylabel : str, optional
Label for the y-axis.
title : str, optional
Title for the plot.
stat_dict : dict, optional
Dictionary of statistics to display on the plot.
confidence_level : float, optional
Confidence level for the regression line.
pXC50 : bool, optional
Whether to highlight pXC50 log unit ranges.
min_val : float, optional
Minimum axis value.
max_val : float, optional
Maximum axis value.
fit_reg : bool, optional
Whether to fit and plot a regression line.
plot_errbars : bool, optional
Whether to plot model error bars from ensemble predictions.
Returns
-------
seaborn.axisgrid.JointGrid
The regression plot object.
"""
title_font = 20
ax_font = 18
tick_font = 16
if min_val is None:
min_val = min(np.min(y_true), np.min(y_pred))
min_ax = min_val - 1
else:
min_ax = min_val
if max_val is None:
max_val = max(np.max(y_true), np.max(y_pred))
max_ax = max_val + 1
else:
max_ax = max_val
# set the limits to be the same for both axes
g = sns.jointplot(
x=np.ravel(y_true),
y=np.ravel(y_pred),
kind="reg",
joint_kws={"ci": confidence_level * 100, "fit_reg": fit_reg},
color="teal",
height=10,
scatter_kws={"alpha": 0.3},
)
if y_pred_err is not None and plot_errbars:
g.ax_joint.errorbar(
x=np.ravel(y_true),
y=np.ravel(y_pred),
yerr=np.ravel(y_pred_err),
fmt="o",
color="teal",
alpha=0.3,
)
if y_true_err is not None and plot_errbars:
g.ax_joint.errorbar(
x=np.ravel(y_true),
y=np.ravel(y_pred),
xerr=np.ravel(y_true_err),
fmt="o",
color="teal",
alpha=0.3,
)
if data_labels is not None:
for i, label in enumerate(data_labels):
g.ax_joint.text(
x=np.ravel(y_true)[i],
y=np.ravel(y_pred)[i],
s=label,
fontsize=6,
color="black",
ha="right",
va="bottom",
)
g.figure.suptitle(title, fontsize=title_font)
g.ax_joint.set_aspect("equal", "box")
g.ax_joint.set_xlim(min_ax, max_ax)
g.ax_joint.set_ylim(min_ax, max_ax)
g.ax_joint.tick_params(axis="both", labelsize=tick_font)
# plot y = x line in dashed grey
g.ax_joint.plot(
[min_ax, max_ax], [min_ax, max_ax], linestyle="--", color="black"
)
# if pXC50 measure then plot the 0.5 and 1.0 log range unit
if pXC50:
g.ax_joint.fill_between(
[min_ax, max_ax],
[min_ax - 0.5, max_ax - 0.5],
[min_ax + 0.5, max_ax + 0.5],
color="gray",
alpha=0.2,
)
g.ax_joint.fill_between(
[min_ax, max_ax],
[min_ax - 1, max_ax - 1],
[min_ax + 1, max_ax + 1],
color="gray",
alpha=0.2,
)
g.ax_joint.set_xlabel(xlabel, fontsize=ax_font)
g.ax_joint.set_ylabel(ylabel, fontsize=ax_font)
# From the stat_dict, parse out the performance metric values and their labels to put into a table to print on the regression plot
if stat_dict:
conf_level = stat_dict.get("conf_level", None)
metric_names = stat_dict.get("metrics", [])
values = stat_dict.get("means", [])
lower_bounds = stat_dict.get("lower_ci", [])
upper_bounds = stat_dict.get("upper_ci", [])
table_data = []
# Format the metric values for readability
for name, val, low, high in zip(
metric_names, values, lower_bounds, upper_bounds
):
if None not in (val, low, high):
val_str = f"{val:.2f} [{low:.2f}, {high:.2f}]"
else:
val_str = "N/A"
table_data.append([name, val_str])
# Create the table
table = g.ax_joint.table(
cellText=table_data,
colLabels=["Metric", f"Value ± {int(conf_level * 100)}% CI"],
colWidths=[0.2, 0.3],
loc="upper left",
cellLoc="left",
)
table.scale(1, 1.8)
for key, cell in table.get_celld().items():
cell.set_fontsize(ax_font)
# Right align the metric values
for i in range(1, len(table_data) + 1):
table[i, 1].get_text().set_horizontalalignment("right")
g.ax_joint.set_box_aspect(1)
g.figure.tight_layout()
return g
[docs] @staticmethod
def ciplot(stat_dict={}):
"""
Create a confidence interval plot for regression metrics.
Parameters
----------
stat_dict : dict
Dictionary containing metrics, means, confidence intervals, and task name.
Returns
-------
matplotlib.figure.Figure
The confidence interval plot figure.
"""
metrics = stat_dict["metrics"]
means = stat_dict["means"]
lower_ci = stat_dict["lower_ci"]
upper_ci = stat_dict["upper_ci"]
conf_level = stat_dict["conf_level"]
task_name = stat_dict["task_name"]
title_font = 16
tick_font = 12
ax_font = 14
y_limits = {
"$R^2$": (0, 1),
"Kendall's $\\tau$": (0, 1),
"Spearman's $\\rho$": (0, 1),
}
n_metrics = len(metrics)
fig, axes = plt.subplots(1, n_metrics, figsize=(8, n_metrics), sharex=False)
if n_metrics == 1:
axes = [axes] # Ensure it's iterable
for i, ax in enumerate(axes):
if i == 0:
ax.set_ylabel("Performance Metric Value", fontsize=ax_font)
metric = metrics[i]
y = means[i]
yerr = [[y - lower_ci[i]], [upper_ci[i] - y]]
ax.errorbar([metric], [y], yerr=yerr, fmt="o", capsize=8, color="green")
ax.tick_params(axis="both", labelsize=tick_font)
ax.yaxis.grid(True, linestyle="--", color="lightgray", alpha=0.6)
ax.set_xlim(-0.5, 0.5)
# Set fixed y-limits
if metric in ["MSE", "MAE"]:
upper = upper_ci[i] * 1.1 if upper_ci[i] > 0 else 1
ax.set_ylim(0, upper)
elif metric in y_limits:
ax.set_ylim(y_limits[metric])
fig.suptitle(
f"Evaluation of {task_name} with {int(conf_level * 100)}% Confidence Intervals",
fontsize=title_font,
)
fig.tight_layout()
return fig
[docs] def report(self, write=False, output_dir=None):
"""
Report the generated plots, optionally writing to disk.
Parameters
----------
write : bool, optional
Whether to write the plots to disk.
output_dir : str, optional
Output directory for the plots.
Returns
-------
dict
Dictionary of plot figures.
"""
if write:
self.write_report(output_dir)
return self.plot_data
[docs] def write_report(self, output_dir):
"""
Write the generated plots to PNG files and optionally log to wandb.
Parameters
----------
output_dir : str
Output directory for the plots.
"""
for plot_tag, plot in self.plot_data.items():
plot_path = output_dir / f"{plot_tag}.png"
plot.savefig(plot_path, dpi=self.dpi)
if self.use_wandb:
wandb.log({plot_tag: wandb.Image(str(plot_path))})