"""PyTorch Lightning trainer implementation."""
from pathlib import Path  # used by the ``output_dir`` field annotation and in build()
from typing import Any, Optional, Union

import torch
from lightning import pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger, WandbLogger
from loguru import logger

from openadmet.models.drivers import DriverType
from openadmet.models.trainer.trainer_base import TrainerBase, trainers
@trainers.register("LightningTrainer")
class LightningTrainer(TrainerBase):
    """
    Trainer for PyTorch Lightning models.

    Wraps a ``lightning.pytorch.Trainer`` configured with a checkpoint
    callback, optional early stopping, a CSV logger and an optional
    Weights & Biases logger. :meth:`build` must be called before
    :meth:`train`.

    Attributes
    ----------
    max_epochs : int
        The maximum number of epochs to train for.
    accelerator : str
        The accelerator to use, e.g. 'cpu', 'gpu'.
    devices : int
        The number of devices to use, e.g. 1 for single GPU, -1
        for all available GPUs.
    use_wandb : bool
        Whether to use Weights & Biases for logging.
    output_dir : Path or None
        The output directory to save logs and checkpoints. Must be set
        before :meth:`build` is called.
    wandb_project : str
        The Weights & Biases project name.
    early_stopping : bool
        Whether to use early stopping.
    early_stopping_patience : int
        The number of epochs with no improvement after which
        training will be stopped.
    early_stopping_mode : str
        The mode for early stopping, either 'min' or 'max'.
    early_stopping_min_delta : float
        The minimum change in the monitored quantity to qualify
        as an improvement.
    gradient_clip_val : float
        The value to clip gradients at.
    precision : int or str
        The precision to use, e.g. 32, 16, or 'bf16'.
    accumulate_grad_batches : int
        The number of batches to accumulate gradients over.
    deterministic : bool
        Whether to use deterministic algorithms.
    fast_dev_run : bool
        Whether to run a single batch for debugging.
    limit_train_batches : float
        The fraction of training batches to use, e.g. 1.0 for all,
        0.5 for half, or an integer for a fixed number.
    limit_val_batches : float
        The fraction of validation batches to use, e.g. 1.0 for all,
        0.5 for half, or an integer for a fixed number. Default is 1.0.
    wandb_logger : Any
        The Weights & Biases logger; set by :meth:`build` when
        ``use_wandb`` is True.
    _logger : list
        The Lightning loggers handed to the trainer (CSV logger, plus the
        W&B logger when enabled).
    _trainer : lightning.pytorch.Trainer
        The PyTorch Lightning trainer created by :meth:`build`.
    _callbacks : dict
        The callbacks handed to the trainer, keyed by class name.
    """

    max_epochs: int = 20
    accelerator: str = "gpu"
    devices: int = 1
    use_wandb: bool = False
    output_dir: Optional[Path] = None
    wandb_project: str = "openadmet-testing"
    early_stopping: bool = False
    early_stopping_patience: int = 10
    early_stopping_mode: str = "min"
    early_stopping_min_delta: float = 0.001
    gradient_clip_val: float = 0.0
    # int (e.g. 32, 16) or a string such as "bf16", as documented above
    precision: Union[int, str] = 32
    accumulate_grad_batches: int = 1
    deterministic: bool = False
    fast_dev_run: bool = False
    limit_train_batches: float = 1.0
    limit_val_batches: float = 1.0
    wandb_logger: Any = None
    _logger: Any
    _trainer: Any
    _callbacks: Any = None
    _driver_type: DriverType = DriverType.LIGHTNING

    def build(self, no_val: bool = False):
        """
        Build the Lightning trainer together with its callbacks and loggers.

        Parameters
        ----------
        no_val : bool, optional
            Set when there is no validation set (e.g. a model trained
            without a train/validation split), by default False. The
            estimator's ``monitor_metric`` is then switched to
            ``"train_loss"`` so checkpointing and early stopping still have
            a metric to monitor. Note that this mutates the estimator.

        Raises
        ------
        ValueError
            If ``output_dir`` has not been set.
        """
        if self.output_dir is None:
            raise ValueError("output_dir must be set before calling build().")
        # Normalise so the "/" joins below also work when a str was supplied.
        output_dir = Path(self.output_dir)

        # Loggers handed to the Lightning trainer
        self._logger = []
        # Callbacks handed to the Lightning trainer, keyed by class name
        self._callbacks = {}

        estimator = self.model.estimator
        if no_val:
            # Without a validation set there is no val_loss to monitor.
            estimator.monitor_metric = "train_loss"
        monitor = estimator.monitor_metric

        # Embed the monitored loss value in the best-checkpoint filename.
        fmtstring = (
            "best-{epoch}-{val_loss:.4f}"
            if monitor == "val_loss"
            else "best-{epoch}-{train_loss:.4f}"
        )
        checkpoint_callback = ModelCheckpoint(
            dirpath=output_dir / "checkpoints",
            filename=fmtstring,
            monitor=monitor,
            # The monitored metric is treated as a loss: lower is better.
            mode="min",
            save_last=True,  # also keep the most recent checkpoint
            save_top_k=1,  # keep only the single best checkpoint
        )
        # train() looks this callback up under its class name.
        self._callbacks[type(checkpoint_callback).__name__] = checkpoint_callback

        if self.early_stopping:
            early_stopping_callback = EarlyStopping(
                monitor=monitor,  # same metric as the checkpoint callback
                min_delta=self.early_stopping_min_delta,
                patience=self.early_stopping_patience,
                mode=self.early_stopping_mode,
            )
            self._callbacks[type(early_stopping_callback).__name__] = (
                early_stopping_callback
            )

        # Weights & Biases logger, if requested
        if self.use_wandb:
            self.wandb_logger = WandbLogger(
                log_model=True, save_dir=output_dir, project=self.wandb_project
            )
            self._logger.append(self.wandb_logger)

        # A CSV logger is always attached
        self._logger.append(CSVLogger(output_dir / "logs", name="model"))

        self._trainer = pl.Trainer(
            logger=self._logger,
            enable_progress_bar=True,
            accelerator=self.accelerator,
            devices=self.devices,
            max_epochs=self.max_epochs,
            callbacks=list(self._callbacks.values()),
            gradient_clip_val=self.gradient_clip_val,
            precision=self.precision,
            accumulate_grad_batches=self.accumulate_grad_batches,
            deterministic=self.deterministic,
            fast_dev_run=self.fast_dev_run,
            limit_train_batches=self.limit_train_batches,
            limit_val_batches=self.limit_val_batches,
        )

    def train(self, train_dataloader, val_dataloader):
        """
        Fit the estimator, then restore its best checkpointed weights.

        Parameters
        ----------
        train_dataloader : DataLoader
            The training data loader.
        val_dataloader : DataLoader
            The validation data loader, passed straight to ``Trainer.fit``.

        Returns
        -------
        model
            ``self.model``. Its estimator holds the best checkpointed
            weights when they could be loaded, otherwise the weights left
            at the end of training.
        """
        logger.debug(f"Training model {self.model.estimator}")
        self._trainer.fit(self.model.estimator, train_dataloader, val_dataloader)
        self._load_best_checkpoint()
        return self.model

    def _load_best_checkpoint(self) -> None:
        """Load the best checkpoint's state dict into the estimator (best-effort)."""
        best_path = self._callbacks["ModelCheckpoint"].best_model_path
        if not best_path:
            # No checkpoint was recorded during fit.
            logger.warning(
                "Training did not generate a best checkpoint. "
                "Using the final in-memory weights for evaluation."
            )
            return
        try:
            # weights_only=False unpickles arbitrary objects; acceptable only
            # because this checkpoint was just written by this training run.
            checkpoint = torch.load(best_path, weights_only=False)
            self.model.estimator.load_state_dict(checkpoint["state_dict"])
        except Exception as err:  # best-effort: keep the trained weights
            logger.warning(
                f"Could not load best checkpoint '{best_path}' ({err!r}). "
                "Using the final in-memory weights for evaluation."
            )

    def make_new(self) -> "LightningTrainer":
        """Copy parameters to a new LightningTrainer instance (via ``__dict__``)."""
        return self.__class__(**self.__dict__)