Source code for openadmet.models.trainer.lightning

"""PyTorch Lightning trainer implementation."""

from pathlib import Path  # Used below for the ``output_dir`` annotation, so this import must stay.
from typing import Any

import torch
from lightning import pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger, WandbLogger
from loguru import logger
from openadmet.models.drivers import DriverType
from openadmet.models.trainer.trainer_base import TrainerBase, trainers


@trainers.register("LightningTrainer")
class LightningTrainer(TrainerBase):
    """
    Trainer for PyTorch models.

    Attributes
    ----------
    max_epochs : int
        The maximum number of epochs to train for.
    accelerator : str
        The accelerator to use, e.g. 'cpu', 'gpu'.
    devices : int
        The number of devices to use, e.g. 1 for single GPU, -1 for all
        available GPUs.
    use_wandb : bool
        Whether to use Weights & Biases for logging.
    output_dir : Path
        The output directory to save logs and models. Must be set before
        ``build`` is called.
    wandb_project : str
        The Weights & Biases project name.
    early_stopping : bool
        Whether to use early stopping.
    early_stopping_patience : int
        The number of epochs with no improvement after which training will
        be stopped.
    early_stopping_mode : str
        The mode for early stopping, either 'min' or 'max'.
    early_stopping_min_delta : float
        The minimum change in the monitored quantity to qualify as an
        improvement.
    gradient_clip_val : float
        The value to clip gradients at.
    precision : int
        The precision to use, e.g. 32, 16, or 'bf16'.
        NOTE(review): the field is annotated ``int``, so string values such
        as 'bf16' may be rejected depending on how ``TrainerBase`` validates
        fields -- confirm.
    accumulate_grad_batches : int
        The number of batches to accumulate gradients over.
    deterministic : bool
        Whether to use deterministic algorithms.
    fast_dev_run : bool
        Whether to run a single batch for debugging.
    limit_train_batches : float
        The fraction of training batches to use, e.g. 1.0 for all, 0.5 for
        half, or an integer for a fixed number. Default is 1.0.
    limit_val_batches : float
        The fraction of validation batches to use, e.g. 1.0 for all, 0.5 for
        half, or an integer for a fixed number. Default is 1.0.
    wandb_logger : Any
        The Weights & Biases logger (created in ``build`` when ``use_wandb``
        is true).
    _logger : Any
        The list of loggers handed to the Lightning trainer.
    _trainer : Any
        The PyTorch Lightning trainer.
    _callbacks : Any
        The callbacks, as a dict keyed by callback class name.

    Notes
    -----
    ``self.model`` is read by ``build`` and ``train`` but is not declared in
    this class; it is presumably provided by ``TrainerBase`` -- confirm.

    """

    max_epochs: int = 20
    accelerator: str = "gpu"
    devices: int = 1
    use_wandb: bool = False
    output_dir: Path = None
    wandb_project: str = "openadmet-testing"
    early_stopping: bool = False
    early_stopping_patience: int = 10
    early_stopping_mode: str = "min"
    early_stopping_min_delta: float = 0.001
    gradient_clip_val: float = 0.0
    precision: int = 32
    accumulate_grad_batches: int = 1
    deterministic: bool = False
    fast_dev_run: bool = False
    limit_train_batches: float = 1.0
    limit_val_batches: float = 1.0
    wandb_logger: Any = None

    _logger: Any
    _trainer: Any
    _callbacks: Any = None
    _driver_type: DriverType = DriverType.LIGHTNING

    def build(self, no_val: bool = False):
        """
        Build the model trainer.

        Creates the checkpoint (and optionally early-stopping) callbacks, the
        loggers, and the underlying ``pl.Trainer``.

        Parameters
        ----------
        no_val : bool, optional
            If no validation set specified, aka training a no split model,
            by default False. When true, the estimator's ``monitor_metric``
            is overwritten with ``"train_loss"``.

        Raises
        ------
        ValueError
            If ``output_dir`` has not been set.

        """
        # Fail early with a clear message instead of the opaque TypeError that
        # ``None / "checkpoints"`` would raise further down.
        if self.output_dir is None:
            raise ValueError(
                "LightningTrainer.output_dir must be set before calling build()."
            )
        output_dir = Path(self.output_dir)

        # Initialize logging container
        self._logger = []

        # Initialize the callbacks dict (keyed by callback class name)
        self._callbacks = {}

        # Without a validation set there is no validation metric to monitor
        if no_val:
            self.model.estimator.monitor_metric = "train_loss"

        monitor_metric = self.model.estimator.monitor_metric
        fmtstring = (
            "best-{epoch}-{val_loss:.4f}"
            if monitor_metric == "val_loss"
            else "best-{epoch}-{train_loss:.4f}"
        )

        # Configure checkpoint callback
        checkpoint_callback = ModelCheckpoint(
            dirpath=output_dir / "checkpoints",  # Where checkpoints are saved
            filename=fmtstring,  # Filename format, including epoch and loss
            monitor=monitor_metric,  # Metric used to select the best checkpoint
            mode="min",  # Keep the checkpoint with the lowest monitored value
            save_last=True,  # Always save the most recent checkpoint as well
            save_top_k=1,  # Keep only the single best checkpoint
        )

        # Register the checkpointing callback in the callbacks dict
        self._callbacks[type(checkpoint_callback).__name__] = checkpoint_callback

        # Configure early stopping callback
        if self.early_stopping:
            early_stopping_callback = EarlyStopping(
                min_delta=self.early_stopping_min_delta,  # Minimum change that counts as an improvement
                monitor=monitor_metric,  # Same metric as the checkpoint callback
                patience=self.early_stopping_patience,  # Epochs without improvement before stopping
                mode=self.early_stopping_mode,  # 'min' or 'max'
            )
            self._callbacks[type(early_stopping_callback).__name__] = (
                early_stopping_callback
            )

        # Append wandb logger if requested
        if self.use_wandb:
            self.wandb_logger = WandbLogger(
                log_model=True, save_dir=output_dir, project=self.wandb_project
            )
            self._logger.append(self.wandb_logger)

        # Append CSV logger
        self._logger.append(CSVLogger(output_dir / "logs", name="model"))

        # Initialize the PyTorch Lightning trainer
        self._trainer = pl.Trainer(
            logger=self._logger,
            enable_progress_bar=True,
            accelerator=self.accelerator,
            devices=self.devices,  # number of devices to train on
            max_epochs=self.max_epochs,  # number of epochs to train for
            callbacks=list(self._callbacks.values()),
            gradient_clip_val=self.gradient_clip_val,
            precision=self.precision,
            accumulate_grad_batches=self.accumulate_grad_batches,
            deterministic=self.deterministic,
            fast_dev_run=self.fast_dev_run,
            limit_train_batches=self.limit_train_batches,
            limit_val_batches=self.limit_val_batches,
        )

    def train(self, train_dataloader, val_dataloader):
        """
        Train the model.

        Fits ``self.model.estimator`` and then tries to restore the weights
        of the best checkpoint recorded by the ``ModelCheckpoint`` callback.
        If no best checkpoint exists, or it cannot be loaded, a warning is
        logged and the estimator keeps the weights it had when fitting ended.

        Parameters
        ----------
        train_dataloader : DataLoader
            The training data loader.
        val_dataloader : DataLoader
            The validation data loader.

        Returns
        -------
        model
            ``self.model``, whose estimator has been trained.

        """
        # Indicate that the model is being trained
        logger.debug(f"Training model {self.model.estimator}")

        # Fit model
        self._trainer.fit(self.model.estimator, train_dataloader, val_dataloader)

        # Load best checkpoint after training
        best_model_path = self._callbacks["ModelCheckpoint"].best_model_path
        if not best_model_path:
            # No best checkpoint was recorded, so there is nothing to restore.
            logger.warning(
                "Training did not generate a best checkpoint. Using the latest checkpoint state-dict for evaluation."
            )
            return self.model

        try:
            # NOTE: weights_only=False unpickles arbitrary objects. That is
            # acceptable here only because the checkpoint was written by this
            # trainer; never point this at an untrusted file.
            checkpoint = torch.load(best_model_path, weights_only=False)
            self.model.estimator.load_state_dict(checkpoint["state_dict"])
        except Exception as exc:
            # Best-effort restore: keep going with the current weights, but
            # report the failure instead of hiding it.
            logger.warning(
                "Could not load best checkpoint {!r}: {}. Using the latest checkpoint state-dict for evaluation.",
                best_model_path,
                exc,
            )

        return self.model

    def make_new(self) -> "LightningTrainer":
        """Copy parameters to a new LightningTrainer instance."""
        # The instance ``__dict__`` is passed as keyword arguments, so values
        # are shared with (not deep-copied from) this instance.
        return self.__class__(**self.__dict__)