Source code for openadmet.models.eval.cross_validation

"""Cross-validation evaluators for regression models."""

import json
from collections import defaultdict
from functools import partial
from pathlib import Path
from typing import Any, ClassVar

import numpy as np
import pandas as pd
from loguru import logger
from pydantic import Field
from scipy.stats import norm
from sklearn.metrics import (
    make_scorer,
    mean_absolute_error,
    mean_squared_error,
    r2_score,
)
from sklearn.model_selection import GroupKFold, RepeatedKFold, cross_validate

from openadmet.models.eval.eval_base import EvalBase, evaluators, get_t_true_and_t_pred
from openadmet.models.eval.regression import (
    RegressionPlots,
    nan_omit_ktau,
    nan_omit_spearmanr,
    pct_within_1_log_unit,
    relative_absolute_error,
)
from openadmet.models.trainer.lightning import LightningTrainer
from openadmet.models.eval.utils import _make_stat_caption, _make_stat_dict
from openadmet.models.drivers import DriverType


def wrap_ktau(y_true, y_pred):
    """
    Return only the statistic produced by ``nan_omit_ktau``.

    Adapter that lets the Kendall's tau helper be used where a metric
    callable returning a single value is expected.

    Parameters
    ----------
    y_true : array-like
        Reference values, forwarded unchanged to ``nan_omit_ktau``.
    y_pred : array-like
        Predicted values, forwarded unchanged to ``nan_omit_ktau``.

    Returns
    -------
    object
        The ``statistic`` attribute of the ``nan_omit_ktau`` result.

    """
    ktau_result = nan_omit_ktau(y_true, y_pred)
    return ktau_result.statistic
def wrap_spearmanr(y_true, y_pred):
    """
    Return only the correlation produced by ``nan_omit_spearmanr``.

    Adapter that lets the Spearman helper be used where a metric callable
    returning a single value is expected.

    Parameters
    ----------
    y_true : array-like
        Reference values, forwarded unchanged to ``nan_omit_spearmanr``.
    y_pred : array-like
        Predicted values, forwarded unchanged to ``nan_omit_spearmanr``.

    Returns
    -------
    object
        The ``correlation`` attribute of the ``nan_omit_spearmanr`` result.

    """
    spearman_result = nan_omit_spearmanr(y_true, y_pred)
    return spearman_result.correlation
def repeated_group_k_fold(X, y, groups, n_splits, n_repeats, random_state):
    """
    Build train/test index lists for repeated, shuffled group K-fold splitting.

    A shuffled ``GroupKFold`` is run ``n_repeats`` times, each time with its
    own seed, and the indices of every fold of every repeat are collected in
    order (all folds of repeat 0, then all folds of repeat 1, ...).

    Parameters
    ----------
    X : array-like
        Feature data.
    y : array-like
        Target data.
    groups : array-like
        Group labels for the samples used while splitting the dataset.
    n_splits : int
        Number of folds per repeat.
    n_repeats : int
        Number of times the K-fold split is repeated.
    random_state : int
        Seed from which the per-repeat seeds are drawn.

    Returns
    -------
    train_inds : list of np.ndarray
        Training set indices, one entry per fold per repeat.
    test_inds : list of np.ndarray
        Test set indices, aligned with ``train_inds``.

    """
    # One seeded generator hands out a seed per repeat, so the whole sequence
    # of splits is reproducible while each repeat gets its own shuffle seed.
    seed_source = np.random.RandomState(random_state)
    repeat_seeds = seed_source.randint(0, 10000, size=n_repeats)

    train_inds, test_inds = [], []
    for repeat_seed in repeat_seeds:
        splitter = GroupKFold(
            n_splits=n_splits,
            shuffle=True,
            random_state=repeat_seed,
        )
        for fold_train, fold_test in splitter.split(X, y, groups=groups):
            train_inds.append(fold_train)
            test_inds.append(fold_test)

    return train_inds, test_inds
class CrossValidationBase(EvalBase):
    """
    Base class for cross-validation evaluators.

    Holds the plotting/configuration fields shared by the concrete
    cross-validation evaluators and the table of metrics they compute.

    Attributes
    ----------
    is_cross_val : bool
        Class-level marker, always ``True`` for cross-validation evaluators.
    _evaluated : bool
        Whether the evaluator has been run.
    axes_labels : list[str]
        Labels for the axes in plots.
    title : str
        Title for the plots.
    pXC50 : bool
        Whether to plot for pXC50, highlighting 0.5 and 1.0 log range unit.
        When set, the ``pct_within_1_log`` metric is added to the active
        metrics.
    plot_errbars : bool
        Whether to plot error bars for ensemble predictions.
    confidence_level : float
        Confidence level for the confidence interval.
    _metrics : dict
        Maps metric key to a ``(scorer, flag, display label)`` tuple. Here the
        first element is an sklearn scorer built with ``make_scorer``.
    min_val : float
        Minimum value for the axes.
    max_val : float
        Maximum value for the axes.

    """

    is_cross_val: ClassVar[bool] = True
    _evaluated: bool = False
    axes_labels: list[str] = Field(
        ["Measured", "Predicted"], description="Labels for the axes"
    )
    title: str = Field("Pred vs ", description="Title for the plot")
    pXC50: bool = Field(
        False,
        description="Whether to plot for pXC50, highlighting 0.5 and 1.0 log range unit",
    )
    plot_errbars: bool = Field(
        False, description="Whether to plot error bars for ensemble predictions"
    )
    confidence_level: float = Field(
        0.95, description="Confidence level for the confidence interval"
    )
    _metrics: dict = {
        "mse": (make_scorer(mean_squared_error), False, "MSE"),
        "mae": (make_scorer(mean_absolute_error), False, "MAE"),
        "r2": (make_scorer(r2_score), False, "$R^2$"),
        "ktau": (make_scorer(wrap_ktau), True, "Kendall's $\\tau$"),
        "spearmanr": (make_scorer(wrap_spearmanr), True, "Spearman's $\\rho$"),
        # These scorers are used for reporting, not model selection, so RAE is
        # built like MSE/MAE. Passing greater_is_better=False would make
        # sklearn sign-flip the value and report a negative RAE.
        "rae": (make_scorer(relative_absolute_error), False, "RAE"),
    }
    min_val: float = Field(None, description="Minimum value for the axes")
    max_val: float = Field(None, description="Maximum value for the axes")

    @property
    def active_metrics(self):
        """
        Return metrics applicable to the current target scale.

        Returns
        -------
        dict
            A copy of ``_metrics``, extended with ``pct_within_1_log`` when
            ``pXC50`` is set.

        """
        metrics = dict(self._metrics)
        if self.pXC50:
            metrics["pct_within_1_log"] = (
                make_scorer(pct_within_1_log_unit),
                False,
                "Fraction within ±1 log",
            )
        return metrics

    @property
    def metric_names(self):
        """
        Get the list of metric names.

        Returns
        -------
        list of str
            Keys of ``active_metrics``, in insertion order.

        """
        return list(self.active_metrics)
@evaluators.register("SKLearnRepeatedKFoldCrossValidation")
class SKLearnRepeatedKFoldCrossValidation(CrossValidationBase):
    """
    Cross-validation evaluator for sklearn models (single-task regression).

    Attributes
    ----------
    n_splits : int
        Number of splits for cross-validation.
    n_repeats : int
        Number of repeats for cross-validation.
    random_state : int
        Random state for reproducibility.

    """

    n_splits: int = Field(5, description="Number of splits for cross-validation")
    n_repeats: int = Field(1, description="Number of repeats for cross-validation")
    random_state: int = Field(42, description="Random state for reproducibility")
    _driver_type: DriverType = DriverType.SKLEARN

    def evaluate(
        self,
        model=None,
        X_train=None,
        y_train=None,
        y_pred=None,
        y_true=None,
        X_all=None,
        y_all=None,
        groups=None,
        tag=None,
        target_labels=None,
        **kwargs,
    ):
        """
        Evaluate the regression model using repeated K-fold cross-validation.

        Parameters
        ----------
        model : object
            Wrapper exposing the sklearn estimator as ``model.estimator``.
        X_train : array-like
            Training features (only checked for presence).
        y_train : array-like
            Training targets (only checked for presence).
        y_pred : array-like
            Predicted values. Not used by the cross-validation itself; passed
            to the regression plot.
        y_true : array-like
            True values. Not used by the cross-validation itself; passed to
            the regression plot. Pandas objects are converted to numpy.
        X_all : array-like
            All data features, used for the cross-validation.
        y_all : array-like
            All data targets, used for the cross-validation.
        groups : array-like, optional
            Group labels for the samples used while splitting the dataset.
            When omitted every sample is its own group.
        tag : str, optional
            Tag for the evaluation run, stored in the result.
        target_labels : list of str, optional
            List of target names; must contain exactly one entry.
        kwargs : Dict
            Additional keyword arguments (ignored).

        Returns
        -------
        dict
            ``{"shape": [n_splits, n_repeats], "tag": tag, <task label>:
            {<metric>: {"value", "mean", "lower_ci", "upper_ci",
            "confidence_level"}}}``.

        Raises
        ------
        ValueError
            If a required argument is ``None`` or ``target_labels`` does not
            have exactly one entry.

        """
        required = (model, X_train, y_train, y_pred, y_true, X_all, y_all)
        if any(arg is None for arg in required):
            raise ValueError(
                "model, X_train, y_train, y_pred, y_true, X_all, y_all must be provided"
            )

        if isinstance(y_true, (pd.Series, pd.DataFrame)):
            y_true = y_true.to_numpy()

        # store the metric names and scorers in a dict suitable for cross_validate
        self.sklearn_metrics = {
            name: spec[0] for name, spec in self.active_metrics.items()
        }

        logger.info("Starting cross-validation")

        # this evaluator is single-task only
        n_tasks = 1
        if target_labels is None:
            target_labels = [f"task_{i}" for i in range(n_tasks)]
        if len(target_labels) != n_tasks:
            raise ValueError(
                f"Number of target labels ({len(target_labels)}) must match number of tasks ({n_tasks})"
            )

        # without explicit groups, each sample forms its own group
        if groups is None:
            groups = np.arange(X_all.shape[0])
        train_inds, test_inds = repeated_group_k_fold(
            X_all, y_all, groups, self.n_splits, self.n_repeats, self.random_state
        )
        cv = list(zip(train_inds, test_inds))

        # we do one job here to avoid issues with double parallelization
        # we prefer to parallelize model training over cross-validation
        scores = cross_validate(
            model.estimator,
            X_all,
            y_all,
            cv=cv,
            n_jobs=1,
            scoring=self.sklearn_metrics,
        )
        logger.info("Cross-validation complete")

        # Drop the timing entries and strip the 'test_' prefix from the rest.
        # The filter is applied per key, so fit_time/score_time never reach
        # the report.
        timing_keys = ("fit_time", "score_time")
        metric_scores = {
            key.replace("test_", ""): values
            for key, values in scores.items()
            if key not in timing_keys
        }

        self.data = {"shape": [self.n_splits, self.n_repeats], "tag": tag}
        for t_label in target_labels:
            task_summary = {}
            for metric_name, fold_values in metric_scores.items():
                # interval of a normal distribution parameterised by the fold
                # scores' mean and sample standard deviation
                mean = fold_values.mean()
                sigma = fold_values.std(ddof=1)
                lower_ci, upper_ci = norm.interval(
                    self.confidence_level, loc=mean, scale=sigma
                )
                task_summary[metric_name] = {
                    # lists (not arrays) so the report can be dumped to JSON
                    "value": fold_values.tolist(),
                    "mean": mean,
                    "lower_ci": lower_ci,
                    "upper_ci": upper_ci,
                    "confidence_level": self.confidence_level,
                }
            self.data[t_label] = task_summary

        self._evaluated = True

        self.plots = {
            "cross_validation_regplot": RegressionPlots.regplot,
            "cross_validation_ciplot": RegressionPlots.ciplot,
        }
        self.plot_data = {}

        # single task: the plots are made for the one (last) task label
        t_label = target_labels[-1]
        stat_dict = self.get_stat_dict(t_label=t_label)

        for plot_tag, plot in self.plots.items():
            if "ciplot" in plot_tag:
                self.plot_data[plot_tag] = plot(stat_dict=stat_dict)
            elif "regplot" in plot_tag:
                self.plot_data[plot_tag] = plot(
                    y_true,
                    y_pred,
                    xlabel=self.axes_labels[0],
                    ylabel=self.axes_labels[1],
                    title=f"{self.title}\nTask: {t_label}",
                    stat_dict=stat_dict,
                    pXC50=self.pXC50,
                    min_val=self.min_val,
                    max_val=self.max_val,
                    plot_errbars=self.plot_errbars,
                )

        return self.data

    def get_stat_caption(self, t_label):
        """
        Get a formatted statistics caption for a given task.

        Parameters
        ----------
        t_label : str
            Task label.

        Returns
        -------
        str
            Caption string with statistics.

        Raises
        ------
        ValueError
            If ``evaluate`` has not been run yet.

        """
        if not self._evaluated:
            raise ValueError(
                ":( You must evaluate the model before the statistics caption can be made."
            )
        return _make_stat_caption(
            data=self.data,
            task_name=t_label,
            metric_names=self.metric_names,
            metrics=self.active_metrics,
            confidence_level=self.confidence_level,
            cv=True,
        )

    def get_stat_dict(self, t_label):
        """
        Get a statistics dictionary for a given task.

        Parameters
        ----------
        t_label : str
            Task label.

        Returns
        -------
        dict
            Dictionary of statistics for the task.

        Raises
        ------
        ValueError
            If ``evaluate`` has not been run yet.

        """
        if not self._evaluated:
            raise ValueError(
                "R'uh-r'oh! You must evaluate the model before the statistics dict can be made."
            )
        return _make_stat_dict(
            data=self.data,
            task_name=t_label,
            metric_names=self.metric_names,
            metrics=self.active_metrics,
            confidence_level=self.confidence_level,
            cv=True,
        )

    def report(self, write=False, output_dir=None):
        """
        Report the evaluation results, optionally writing to disk.

        Parameters
        ----------
        write : bool, optional
            Whether to write the report to disk.
        output_dir : str or path-like, optional
            Output directory for the report; required when ``write`` is set.

        Returns
        -------
        dict
            Dictionary of computed metrics.

        """
        if write:
            self.write_report(output_dir)
        return self.data

    def write_report(self, output_dir):
        """
        Write the evaluation report and plots to disk.

        Parameters
        ----------
        output_dir : str or path-like
            Output directory for the report and plots.

        """
        # accept plain strings as well as Path objects
        output_dir = Path(output_dir)

        # write to JSON
        with open(output_dir / "cross_validation_metrics.json", "w") as f:
            json.dump(self.data, f, indent=2)

        # write each plot to a file
        for plot_tag, plot in self.plot_data.items():
            plot.savefig(output_dir / f"{plot_tag}.png", bbox_inches="tight", dpi=900)
@evaluators.register("PytorchLightningRepeatedKFoldCrossValidation")
class PytorchLightningRepeatedKFoldCrossValidation(CrossValidationBase):
    """
    Cross-validation evaluator for PyTorch Lightning models.

    Unlike the sklearn evaluator, metrics here are stored as raw callables
    (not ``make_scorer`` objects) and are applied per task to each fold's
    predictions.

    Attributes
    ----------
    n_splits : int
        Number of splits for cross-validation.
    n_repeats : int
        Number of repeats for cross-validation.
    random_state : int
        Random state for reproducibility.
    _evaluated : bool
        Whether the evaluator has been run.
    axes_labels : list[str]
        Labels for the axes in plots.
    title : str
        Title for the plots.
    pXC50 : bool
        Whether to plot for pXC50, highlighting 0.5 and 1.0 log range unit.
    confidence_level : float
        Confidence level for the confidence interval.
    _metrics : dict
        Maps metric key to a ``(callable, flag, display label)`` tuple.
    min_val : float
        Minimum value for the axes.
    max_val : float
        Maximum value for the axes.
    use_wandb : bool
        Whether to use wandb for logging.

    """

    n_splits: int = Field(5, description="Number of splits for cross-validation")
    n_repeats: int = Field(1, description="Number of repeats for cross-validation")
    random_state: int = Field(42, description="Random state for reproducibility")
    _evaluated: bool = False
    _driver_type: DriverType = DriverType.LIGHTNING
    axes_labels: list[str] = Field(
        ["Measured", "Predicted"], description="Labels for the axes"
    )
    title: str = Field("Pred vs ", description="Title for the plot")
    pXC50: bool = Field(
        False,
        description="Whether to plot for pXC50, highlighting 0.5 and 1.0 log range unit",
    )
    confidence_level: float = Field(
        0.95, description="Confidence level for the confidence interval"
    )
    _metrics: dict = {
        "mse": (mean_squared_error, False, "MSE"),
        "mae": (mean_absolute_error, False, "MAE"),
        "r2": (r2_score, False, "$R^2$"),
        "ktau": (wrap_ktau, True, "Kendall's $\\tau$"),
        "spearmanr": (wrap_spearmanr, True, "Spearman's $\\rho$"),
        "rae": (relative_absolute_error, False, "RAE"),
    }
    min_val: float = Field(None, description="Minimum value for the axes")
    max_val: float = Field(None, description="Maximum value for the axes")
    use_wandb: bool = Field(False, description="Whether to use wandb")

    @property
    def active_metrics(self):
        """Return metrics applicable to Lightning CV using raw metric callables."""
        metrics = dict(self._metrics)
        if self.pXC50:
            metrics["pct_within_1_log"] = (
                pct_within_1_log_unit,
                False,
                "Fraction within ±1 log",
            )
        return metrics

    def evaluate(
        self,
        model=None,
        X_train=None,
        y_true=None,
        y_pred=None,
        y_train=None,
        X_all=None,
        y_all=None,
        groups=None,
        featurizer=None,
        trainer=None,
        tag=None,
        use_wandb=False,
        target_labels=None,
        **kwargs,
    ):
        """
        Evaluate the regression model using repeated K-fold cross-validation with PyTorch Lightning.

        For every fold a fresh featurizer, model and ``LightningTrainer`` are
        created, the model is trained on the fold's training rows and scored
        on its validation rows with every active metric, per task.

        Parameters
        ----------
        model : LightningModelBase
            The PyTorch Lightning model to evaluate; ``make_new()`` is called
            on it once per fold.
        X_train : array-like
            Training features (only checked for presence).
        y_true : array-like
            True values for the full dataset. Pandas objects are converted to
            numpy.
        y_pred : array-like
            Predicted values for the full dataset.
        y_train : array-like
            Training targets (only checked for presence).
        X_all : pandas object
            All data features; converted with ``to_numpy()``.
        y_all : pandas object
            All data targets; converted with ``to_numpy()`` and expected to be
            two-dimensional (one column per task).
        groups : array-like, optional
            Group labels for the samples used while splitting the dataset.
            When omitted every sample is its own group.
        featurizer : object
            Featurizer instance; ``make_new()`` is called on it once per fold.
        trainer : LightningTrainer
            Trainer whose settings are copied into each per-fold trainer.
        tag : str
            Tag for the evaluation run, stored in the result.
        use_wandb : bool, optional
            When true, sets ``self.use_wandb``. The per-fold trainers are
            always created with ``use_wandb=False``.
        target_labels : list of str, optional
            List of target names, one per task.
        kwargs : Dict
            Additional keyword arguments (ignored).

        Returns
        -------
        dict
            ``{"shape": [n_splits, n_repeats], "tag": tag, <task label>:
            {<metric>: {"value", "mean", "lower_ci", "upper_ci",
            "confidence_level"}}}``.

        Raises
        ------
        ValueError
            If a required argument is ``None``, if ``target_labels`` does not
            match the number of tasks, or if a fold's predictions have a
            different number of tasks than ``y_all``.

        """
        logger.info("Starting cross-validation")

        required = (
            model,
            X_train,
            y_train,
            y_pred,
            y_true,
            tag,
            featurizer,
            trainer,
            X_all,
            y_all,
        )
        if any(arg is None for arg in required):
            raise ValueError(
                "model, X_train, y_train, y_pred, y_true, X_all, y_all, tag, "
                "featurizer, and trainer must be provided"
            )

        if isinstance(y_true, (pd.Series, pd.DataFrame)):
            y_true = y_true.to_numpy()

        if use_wandb:
            self.use_wandb = use_wandb

        # property rebuilds the dict on every access, so read it once
        active_metrics = self.active_metrics

        # kept for parity with the sklearn evaluator; not used by the loop below
        self.sklearn_metrics = {name: spec[0] for name, spec in active_metrics.items()}

        # without explicit groups, each sample forms its own group
        if groups is None:
            groups = np.arange(X_all.shape[0])
        train_inds, test_inds = repeated_group_k_fold(
            X_all, y_all, groups, self.n_splits, self.n_repeats, self.random_state
        )

        self.data = {
            "shape": [self.n_splits, self.n_repeats],
            "tag": tag,
        }
        self._metric_data = {}

        # cast to numpy arrays so folds can be selected by integer index
        X_all = X_all.to_numpy()
        y_all = y_all.to_numpy()

        # prepare containers for metrics
        n_tasks = y_all.shape[1]
        if target_labels is None:
            target_labels = [f"task_{i}" for i in range(n_tasks)]
        if len(target_labels) != n_tasks:
            raise ValueError(
                f"Number of target labels ({len(target_labels)}) must match number of tasks ({n_tasks})"
            )
        for t_label in target_labels:
            self._metric_data[t_label] = defaultdict(list)

        for fold, (fold_train_ids, fold_val_ids) in enumerate(
            zip(train_inds, test_inds)
        ):
            logger.info(f"Fold {fold}")

            X_fold_train = X_all[fold_train_ids]
            y_fold_train = y_all[fold_train_ids]
            X_val = X_all[fold_val_ids]
            y_val = y_all[fold_val_ids]

            logger.debug(f"X_train shape: {X_fold_train.shape}")
            logger.debug(f"y_train shape: {y_fold_train.shape}")
            logger.debug(f"X_val shape: {X_val.shape}")
            logger.debug(f"y_val shape: {y_val.shape}")

            # Create a new featurizer and model for each fold
            fold_featurizer = featurizer.make_new()
            fold_train_dataloader, _, fold_train_scaler, _ = fold_featurizer.featurize(
                X_fold_train, y_fold_train
            )
            fold_val_dataloader, _, _, _ = fold_featurizer.featurize(X_val, y_val)

            fold_model = model.make_new()
            fold_model.build(scaler=fold_train_scaler)

            fold_trainer = LightningTrainer(
                max_epochs=trainer.max_epochs,
                accelerator=trainer.accelerator,
                devices=trainer.devices,
                use_wandb=False,
                output_dir=trainer.output_dir / "cv" / f"fold_{fold}",
                wandb_project=trainer.wandb_project,
            )
            fold_trainer.model = fold_model
            fold_trainer.build()

            fold_model = fold_trainer.train(fold_train_dataloader, fold_val_dataloader)

            y_pred_fold = fold_model.predict(
                fold_val_dataloader,
                accelerator=trainer.accelerator,
                devices=trainer.devices,
            )

            if n_tasks != y_pred_fold.shape[1]:
                raise ValueError("y_true and y_pred must have the same number of tasks")

            # score every task of this fold with every active metric
            for task_id, t_label in enumerate(target_labels):
                t_true, t_pred = get_t_true_and_t_pred(
                    task_id, y_true, y_pred, y_val, y_pred_fold
                )
                for metric_name, (metric_func, _, _) in active_metrics.items():
                    self._metric_data[t_label][metric_name].append(
                        metric_func(t_true, t_pred)
                    )

            logger.info(f"Fold {fold} complete")

        # summarise the per-fold values of each metric for each task
        for t_label in target_labels:
            task_summary = {}
            for metric_name, fold_values in self._metric_data[t_label].items():
                fold_values = np.array(fold_values)
                # interval of a normal distribution parameterised by the fold
                # scores' mean and sample standard deviation
                mean = fold_values.mean()
                sigma = fold_values.std(ddof=1)
                lower_ci, upper_ci = norm.interval(
                    self.confidence_level, loc=mean, scale=sigma
                )
                task_summary[metric_name] = {
                    "value": fold_values.tolist(),
                    "mean": mean,
                    "lower_ci": lower_ci,
                    "upper_ci": upper_ci,
                    "confidence_level": self.confidence_level,
                }
            self.data[t_label] = task_summary

        self._evaluated = True

        self.plots = {
            "cross_validation_regplot": RegressionPlots.regplot,
            "cross_validation_ciplot": RegressionPlots.ciplot,
        }
        self.plot_data = {}

        for task_id, t_label in enumerate(target_labels):
            # NOTE(review): y_val / y_pred_fold are whatever the last fold left
            # behind; how get_t_true_and_t_pred combines them with the
            # full-dataset y_true / y_pred is not visible here - confirm.
            t_true, t_pred = get_t_true_and_t_pred(
                task_id, y_true, y_pred, y_val, y_pred_fold
            )
            stat_dict = self.get_stat_dict(t_label=t_label)

            for plot_tag, plot in self.plots.items():
                plot_tag_task = f"{plot_tag}_{t_label}"
                if "ciplot" in plot_tag_task:
                    self.plot_data[plot_tag_task] = plot(stat_dict=stat_dict)
                elif "regplot" in plot_tag_task:
                    self.plot_data[plot_tag_task] = plot(
                        t_true,
                        t_pred,
                        xlabel=self.axes_labels[0],
                        ylabel=self.axes_labels[1],
                        title=f"{self.title}\nTask: {t_label}",
                        stat_dict=stat_dict,
                        pXC50=self.pXC50,
                        min_val=self.min_val,
                        max_val=self.max_val,
                    )

        return self.data

    @property
    def task_names(self):
        """
        Get the task names after evaluation.

        Returns
        -------
        list of str
            Task labels present in the results; the bookkeeping entries
            ``"shape"`` and ``"tag"`` are not included.

        Raises
        ------
        ValueError
            If ``evaluate`` has not been run yet.

        """
        if not self._evaluated:
            raise ValueError("Must evaluate before getting task names")
        reserved = ("shape", "tag")
        return [key for key in self.data if key not in reserved]

    def report(self, write=False, output_dir=None):
        """
        Report the evaluation results, optionally writing to disk.

        Parameters
        ----------
        write : bool, optional
            Whether to write the report to disk.
        output_dir : str or path-like, optional
            Output directory for the report; required when ``write`` is set.

        Returns
        -------
        dict
            Dictionary of computed metrics.

        """
        if write:
            self.write_report(output_dir)
        return self.data

    def write_report(self, output_dir):
        """
        Write the evaluation report and plots to disk.

        Parameters
        ----------
        output_dir : str or path-like
            Output directory for the report and plots.

        """
        # accept plain strings as well as Path objects
        output_dir = Path(output_dir)

        # write to JSON
        with open(output_dir / "cross_validation_metrics.json", "w") as f:
            json.dump(self.data, f, indent=2)

        # write each plot to a file
        for plot_tag, plot in self.plot_data.items():
            plot.savefig(output_dir / f"{plot_tag}.png", bbox_inches="tight", dpi=900)

    def get_stat_caption(self, t_label):
        """
        Get a formatted statistics caption for a given task.

        Parameters
        ----------
        t_label : str
            Task label.

        Returns
        -------
        str
            Caption string with statistics.

        Raises
        ------
        ValueError
            If ``evaluate`` has not been run yet.

        """
        if not self._evaluated:
            raise ValueError(
                "You must evaluate the model before the statistics caption can be made."
            )
        return _make_stat_caption(
            data=self.data,
            task_name=t_label,
            metric_names=self.metric_names,
            metrics=self.active_metrics,
            confidence_level=self.confidence_level,
            cv=True,
        )

    def get_stat_dict(self, t_label):
        """
        Get a statistics dictionary for a given task.

        Parameters
        ----------
        t_label : str
            Task label.

        Returns
        -------
        dict
            Dictionary of statistics for the task.

        Raises
        ------
        ValueError
            If ``evaluate`` has not been run yet.

        """
        if not self._evaluated:
            raise ValueError(
                "You must evaluate the model before the statistics dict can be made."
            )
        return _make_stat_dict(
            data=self.data,
            task_name=t_label,
            metric_names=self.metric_names,
            metrics=self.active_metrics,
            confidence_level=self.confidence_level,
            cv=True,
        )