Source code for openadmet.models.trainer.lightning

"""PyTorch Lightning trainer implementation."""

from pathlib import Path  # Needed for the `output_dir` field annotation below; do not remove.
from typing import Any

from loguru import logger
from openadmet.models.drivers import DriverType
from openadmet.models.trainer.trainer_base import TrainerBase, trainers


@trainers.register("LightningTrainer")
class LightningTrainer(TrainerBase):
    """
    Trainer that fits a model's estimator with PyTorch Lightning.

    ``build`` assembles the callbacks, loggers and ``lightning.pytorch.Trainer``;
    ``train`` runs the fit and then tries to restore the best checkpoint into
    the estimator.

    Attributes
    ----------
    max_epochs : int
        The maximum number of epochs to train for.
    accelerator : str
        The accelerator to use, e.g. 'cpu', 'gpu'.
    devices : int
        The number of devices to use, e.g. 1 for single GPU, -1 for all
        available GPUs.
    use_wandb : bool
        Whether to use Weights & Biases for logging.
    output_dir : Path
        The output directory to save logs and models. Must be set before
        ``build`` is called.
    wandb_project : str
        The Weights & Biases project name.
    early_stopping : bool
        Whether to use early stopping.
    early_stopping_patience : int
        The number of epochs with no improvement after which training will be
        stopped.
    early_stopping_mode : str
        The mode for early stopping, either 'min' or 'max'.
    early_stopping_min_delta : float
        The minimum change in the monitored quantity to qualify as an
        improvement.
    gradient_clip_val : float
        The value to clip gradients at.
    precision : int
        The precision to use, e.g. 32 or 16.
        NOTE(review): Lightning also accepts strings such as 'bf16', but the
        field is annotated ``int``; confirm whether the base class validates
        annotations before passing a string.
    accumulate_grad_batches : int
        The number of batches to accumulate gradients over.
    deterministic : bool
        Whether to use deterministic algorithms.
    fast_dev_run : bool
        Whether to run a single batch for debugging.
    limit_train_batches : float
        The fraction of training batches to use, e.g. 1.0 for all, 0.5 for
        half.
        NOTE(review): Lightning also accepts an integer batch count, but the
        field is annotated ``float``; confirm how the base class treats ints.
    limit_val_batches : float
        The fraction of validation batches to use, e.g. 1.0 for all, 0.5 for
        half. Default is 1.0. The same integer caveat as
        ``limit_train_batches`` applies.
    wandb_logger : Any
        The Weights & Biases logger, set by ``build`` when ``use_wandb`` is
        true.
    _logger : Any
        List of Lightning loggers created by ``build``.
    _trainer : Any
        The PyTorch Lightning trainer created by ``build``.
    _callbacks : Any
        Dict of Lightning callbacks keyed by callback class name.

    """

    max_epochs: int = 20
    accelerator: str = "gpu"
    devices: int = 1
    use_wandb: bool = False
    output_dir: Path = None
    wandb_project: str = "openadmet-testing"
    early_stopping: bool = False
    early_stopping_patience: int = 10
    early_stopping_mode: str = "min"
    early_stopping_min_delta: float = 0.001
    gradient_clip_val: float = 0.0
    precision: int = 32
    accumulate_grad_batches: int = 1
    deterministic: bool = False
    fast_dev_run: bool = False
    limit_train_batches: float = 1.0
    limit_val_batches: float = 1.0
    wandb_logger: Any = None

    _logger: Any
    _trainer: Any
    _callbacks: Any = None
    _driver_type: DriverType = DriverType.LIGHTNING

    def build(self, no_val: bool = False):
        """
        Build the model trainer.

        Creates the checkpoint callback (and, if enabled, the early-stopping
        callback), the CSV logger (and, if enabled, the W&B logger), and the
        ``lightning.pytorch.Trainer`` stored on ``self._trainer``.

        Parameters
        ----------
        no_val : bool, optional
            If no validation set specified, aka training a no split model,
            by default False. When true, the estimator's ``monitor_metric``
            is overwritten with ``"train_loss"``.

        Raises
        ------
        ValueError
            If ``output_dir`` has not been set.

        """
        # Fail early with a clear message instead of a TypeError from
        # ``None / "checkpoints"`` further down.
        if self.output_dir is None:
            raise ValueError(
                "LightningTrainer.output_dir must be set before calling build(); "
                "it is where checkpoints and logs are written."
            )
        # Accept a plain string as well as a Path.
        output_dir = Path(self.output_dir)

        # Lightning is imported lazily so importing this module does not
        # require it.
        from lightning import pytorch as pl
        from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
        from lightning.pytorch.loggers import CSVLogger, WandbLogger

        # Fresh containers on every build
        self._logger = []
        self._callbacks = {}

        # Without a validation set there is no validation metric to monitor
        if no_val:
            self.model.estimator.monitor_metric = "train_loss"
        monitor_metric = self.model.estimator.monitor_metric

        # Checkpoint filename template embeds the epoch and the monitored loss
        fmtstring = (
            "best-{epoch}-{val_loss:.4f}"
            if monitor_metric == "val_loss"
            else "best-{epoch}-{train_loss:.4f}"
        )

        # Configure the checkpoint callback
        checkpoint_callback = ModelCheckpoint(
            dirpath=output_dir / "checkpoints",  # where checkpoints are saved
            filename=fmtstring,
            monitor=monitor_metric,  # metric used to select the best checkpoint
            mode="min",  # lower monitored value is better (fixed, unlike early stopping)
            save_last=True,  # always keep the most recent checkpoint as well
            save_top_k=1,  # keep only the single best checkpoint
        )
        # Keyed by class name; ``train`` looks this up as "ModelCheckpoint"
        self._callbacks[type(checkpoint_callback).__name__] = checkpoint_callback

        # Configure the early stopping callback
        if self.early_stopping:
            early_stopping_callback = EarlyStopping(
                monitor=monitor_metric,
                min_delta=self.early_stopping_min_delta,
                patience=self.early_stopping_patience,
                mode=self.early_stopping_mode,
            )
            self._callbacks[type(early_stopping_callback).__name__] = (
                early_stopping_callback
            )

        # Append the wandb logger if requested
        if self.use_wandb:
            self.wandb_logger = WandbLogger(
                log_model=True, save_dir=output_dir, project=self.wandb_project
            )
            self._logger.append(self.wandb_logger)

        # The CSV logger is always attached
        self._logger.append(CSVLogger(output_dir / "logs", name="model"))

        # Initialize the PyTorch Lightning trainer
        self._trainer = pl.Trainer(
            logger=self._logger,
            enable_progress_bar=True,
            accelerator=self.accelerator,
            devices=self.devices,  # number of devices on the chosen accelerator
            max_epochs=self.max_epochs,  # number of epochs to train for
            callbacks=list(self._callbacks.values()),
            gradient_clip_val=self.gradient_clip_val,
            precision=self.precision,
            accumulate_grad_batches=self.accumulate_grad_batches,
            deterministic=self.deterministic,
            fast_dev_run=self.fast_dev_run,
            limit_train_batches=self.limit_train_batches,
            limit_val_batches=self.limit_val_batches,
        )

    def train(self, train_dataloader, val_dataloader):
        """
        Train the model.

        Fits ``self.model.estimator`` with the trainer created by ``build``
        and then attempts to load the best checkpoint's ``state_dict`` back
        into the estimator. Restoring is best-effort: if it fails, the
        estimator keeps the weights it has in memory and a warning is logged.

        Parameters
        ----------
        train_dataloader : DataLoader
            The training data loader.
        val_dataloader : DataLoader
            The validation data loader.

        Returns
        -------
        model
            ``self.model``, whose estimator has been fitted.

        """
        # Indicate that the model is being trained
        logger.debug(f"Training model {self.model.estimator}")

        import torch

        # Fit model
        self._trainer.fit(self.model.estimator, train_dataloader, val_dataloader)

        # Load best checkpoint after training
        try:
            best_model_path = self._callbacks["ModelCheckpoint"].best_model_path
            if best_model_path:
                # The checkpoint was written by the checkpoint callback of this
                # trainer; weights_only=False unpickles arbitrary objects, so
                # this must never be pointed at an untrusted file.
                checkpoint = torch.load(best_model_path, weights_only=False)
                self.model.estimator.load_state_dict(checkpoint["state_dict"])
            else:
                # An empty path means no best checkpoint was recorded
                logger.warning(
                    "Warning: Training did not generate a best checkpoint. Using the latest checkpoint state-dict for evaluation."
                )
        except Exception as exc:
            # Deliberately best-effort: a failed restore must not discard a
            # finished training run. Unlike a bare ``except``, this lets
            # KeyboardInterrupt/SystemExit propagate and records the cause.
            logger.warning(
                "Could not restore the best checkpoint ({!r}). "
                "Keeping the estimator's current weights for evaluation.",
                exc,
            )

        return self.model

    def make_new(self) -> "LightningTrainer":
        """
        Copy parameters to a new LightningTrainer instance.

        The new instance is constructed from this instance's ``__dict__``, so
        the same value objects are handed to the constructor (no deep copy is
        made here).
        """
        return type(self)(**self.__dict__)