PyTorch Lightning Training

PyTorch Lightning trainer implementation.

class openadmet.models.trainer.lightning.LightningTrainer(*, max_epochs: int = 20, accelerator: str = 'gpu', devices: int = 1, use_wandb: bool = False, output_dir: Path = None, wandb_project: str = 'openadmet-testing', early_stopping: bool = False, early_stopping_patience: int = 10, early_stopping_mode: str = 'min', early_stopping_min_delta: float = 0.001, gradient_clip_val: float = 0.0, precision: int = 32, accumulate_grad_batches: int = 1, deterministic: bool = False, fast_dev_run: bool = False, limit_train_batches: float = 1.0, limit_val_batches: float = 1.0, wandb_logger: Any = None)

Bases: TrainerBase

Trainer for PyTorch models.

Variables:
  • max_epochs (int) – The maximum number of epochs to train for.

  • accelerator (str) – The accelerator to use, e.g. ‘cpu’, ‘gpu’.

  • devices (int) – The number of devices to use, e.g. 1 for single GPU, -1 for all available GPUs.

  • use_wandb (bool) – Whether to use Weights & Biases for logging.

  • output_dir (Path) – The output directory to save logs and models.

  • wandb_project (str) – The Weights & Biases project name.

  • early_stopping (bool) – Whether to use early stopping.

  • early_stopping_patience (int) – The number of epochs with no improvement after which training will be stopped.

  • early_stopping_mode (str) – The mode for early stopping, either ‘min’ or ‘max’.

  • early_stopping_min_delta (float) – The minimum change in the monitored quantity to qualify as an improvement.

  • gradient_clip_val (float) – The value at which to clip gradients. In Lightning, 0.0 (the default) disables clipping.

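  • precision (int) – The numerical precision to train with, e.g. 32 or 16. Lightning itself also accepts strings such as ‘bf16’, but because this field is annotated int, such values may be rejected by validation.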

  • accumulate_grad_batches (int) – The number of batches to accumulate gradients over.

  • deterministic (bool) – Whether to use deterministic algorithms.

  • fast_dev_run (bool) – Whether to run a single batch for debugging.

  • limit_train_batches (float) – The fraction of training batches to use, e.g. 1.0 for all, 0.5 for half, or an integer for a fixed number.

  • limit_val_batches (float) – The fraction of validation batches to use, e.g. 1.0 for all, 0.5 for half, or an integer for a fixed number. Default is 1.0.

  • wandb_logger (Any) – The Weights & Biases logger.

  • _logger (Any) – The logger.

  • _trainer (Any) – The PyTorch Lightning trainer.

  • _callbacks (Any) – The callbacks.
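The three early-stopping fields interact in the way Lightning's `EarlyStopping` callback documents, assuming this trainer forwards them to that callback: a new value counts as an improvement only if it beats the best value so far by more than `early_stopping_min_delta`, in the direction set by `early_stopping_mode`, and training stops after `early_stopping_patience` consecutive validation checks (by default, one per epoch) without improvement. The pure-Python sketch below simulates that rule on a list of per-epoch validation values. The function `stopping_epoch` is hypothetical, and it simplifies Lightning's callback (no NaN checks, no stopping or divergence thresholds).

```python
from typing import Optional, Sequence


def stopping_epoch(
    values: Sequence[float],
    patience: int = 10,
    mode: str = "min",
    min_delta: float = 0.001,
) -> Optional[int]:
    """Return the 0-based epoch at which training would stop, or None.

    Hypothetical helper: a simplified model of patience-based early stopping,
    not the Lightning implementation itself.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    best = float("inf") if mode == "min" else float("-inf")
    wait = 0  # consecutive checks without improvement
    for epoch, value in enumerate(values):
        # An improvement must beat the best value by more than min_delta.
        if mode == "min":
            improved = value < best - min_delta
        else:
            improved = value > best + min_delta
        if improved:
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None
```

For example, with `patience=2` and the default `min_delta=0.001`, the losses `[1.0, 0.8, 0.7995, 0.7992]` stop training at epoch 3: both of the last two values improve on 0.8, but by less than the minimum delta.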

accelerator: str
accumulate_grad_batches: int
build(no_val: bool = False)

Build the model trainer.

Parameters:

no_val (bool, optional) – Set to True when no validation set is available, i.e. when training a model on the full dataset without a train/validation split. By default False.
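How `build` assembles the Lightning trainer is not spelled out here. As a hedged sketch, assuming the fields are forwarded under their `pytorch_lightning.Trainer` keyword names and that `no_val=True` disables both the validation loop and early stopping (which would otherwise have no validation metric to monitor), the keyword arguments might be collected as follows. The dataclass `TrainerConfig`, the helper `trainer_kwargs`, and the monitored metric name `"val_loss"` are all hypothetical; the sketch returns plain dicts so it runs without Lightning installed.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


@dataclass
class TrainerConfig:
    """Hypothetical stand-in for the LightningTrainer fields used below."""
    max_epochs: int = 20
    accelerator: str = "gpu"
    devices: int = 1
    early_stopping: bool = False
    early_stopping_patience: int = 10
    early_stopping_mode: str = "min"
    early_stopping_min_delta: float = 0.001
    gradient_clip_val: float = 0.0
    precision: int = 32
    accumulate_grad_batches: int = 1
    deterministic: bool = False
    fast_dev_run: bool = False
    limit_train_batches: float = 1.0
    limit_val_batches: float = 1.0


def trainer_kwargs(
    cfg: TrainerConfig, no_val: bool = False
) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
    """Return (Trainer kwargs, EarlyStopping kwargs or None)."""
    kwargs: Dict[str, Any] = {
        "max_epochs": cfg.max_epochs,
        "accelerator": cfg.accelerator,
        "devices": cfg.devices,
        "gradient_clip_val": cfg.gradient_clip_val,
        "precision": cfg.precision,
        "accumulate_grad_batches": cfg.accumulate_grad_batches,
        "deterministic": cfg.deterministic,
        "fast_dev_run": cfg.fast_dev_run,
        "limit_train_batches": cfg.limit_train_batches,
        # Assumption: with no validation set, the validation loop is skipped.
        "limit_val_batches": 0 if no_val else cfg.limit_val_batches,
    }
    early_stop: Optional[Dict[str, Any]] = None
    # Early stopping needs a validation metric, so only configure it with a val set.
    if cfg.early_stopping and not no_val:
        early_stop = {
            "monitor": "val_loss",  # hypothetical metric name
            "patience": cfg.early_stopping_patience,
            "mode": cfg.early_stopping_mode,
            "min_delta": cfg.early_stopping_min_delta,
        }
    return kwargs, early_stop
```

With Lightning installed, the first dict could be unpacked into `pytorch_lightning.Trainer(**kwargs)` and the second into `pytorch_lightning.callbacks.EarlyStopping(**early_stop)`, whose `monitor`, `patience`, `mode`, and `min_delta` parameters match these keys.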

deterministic: bool
devices: int
early_stopping: bool
early_stopping_min_delta: float
early_stopping_mode: str
early_stopping_patience: int
fast_dev_run: bool
gradient_clip_val: float
limit_train_batches: float
limit_val_batches: float
make_new() → LightningTrainer

Copy parameters to a new LightningTrainer instance.
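Copying parameters into a fresh instance, rather than reusing one, keeps runtime state such as `_trainer` and `_callbacks` from leaking into a new run (for example, across cross-validation folds); whether `make_new` is used that way here is an assumption. The stdlib sketch below illustrates the idea with a dataclass: only constructor fields are copied, and private state starts out empty. The class `SketchTrainer` and its `build` placeholder are hypothetical stand-ins, not the openadmet implementation.

```python
from dataclasses import dataclass, field, fields
from typing import Any, Optional


@dataclass
class SketchTrainer:
    """Hypothetical stand-in: public config fields plus private runtime state."""
    max_epochs: int = 20
    accelerator: str = "gpu"
    devices: int = 1
    # Private runtime state is excluded from __init__, so it is never copied.
    _trainer: Optional[Any] = field(default=None, init=False, repr=False)

    def build(self) -> None:
        # Placeholder for constructing the real Lightning trainer.
        self._trainer = object()

    def make_new(self) -> "SketchTrainer":
        # Copy only the constructor parameters into a new, unbuilt instance.
        params = {f.name: getattr(self, f.name) for f in fields(self) if f.init}
        return type(self)(**params)
```

Calling `make_new()` on a built `SketchTrainer(max_epochs=5, accelerator="cpu")` returns an instance with the same settings and `_trainer` set to `None`.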

max_epochs: int
model_config: ClassVar[ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to pydantic.config.ConfigDict.

model_post_init(context: Any, /) → None

This function is meant to behave like a BaseModel method to initialize private attributes.

It takes context as an argument since that’s what pydantic-core passes when calling it.

Args:

  • self – The BaseModel instance.

  • context – The context.

output_dir: Path
precision: int
train(train_dataloader, val_dataloader)

Train the model.

Parameters:
  • train_dataloader (DataLoader) – The training data loader.

  • val_dataloader (DataLoader) – The validation data loader.

Returns:

model – The trained model.

Return type:

TrainerBase
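Before calling `train`, it can help to work out how much work one epoch does under the current settings. The sketch below follows the rules Lightning documents for these options: a float `limit_train_batches` keeps that fraction of batches (rounded down) and an int keeps at most that many; with gradient accumulation the optimizer steps once every `accumulate_grad_batches` batches, and in current Lightning versions also once for a trailing partial group at the end of the epoch; and under data-parallel training each optimizer step covers `batch_size × accumulate_grad_batches × devices` samples. The helper `epoch_budget` and its `batch_size` argument are hypothetical, and edge cases such as iterable datasets of unknown length and `fast_dev_run` are ignored.

```python
import math
from typing import Dict, Union


def epoch_budget(
    num_batches: int,
    batch_size: int,
    limit_train_batches: Union[int, float] = 1.0,
    accumulate_grad_batches: int = 1,
    devices: int = 1,
) -> Dict[str, int]:
    """Hypothetical helper: per-epoch batch and optimizer-step counts.

    num_batches is the length of the (per-device) training dataloader.
    """
    if isinstance(limit_train_batches, int):
        # An int caps the number of batches directly.
        used = min(num_batches, limit_train_batches)
    else:
        # A float keeps that fraction of the batches, rounded down.
        used = int(num_batches * limit_train_batches)
    # One step per full accumulation group, plus one for a trailing partial group.
    steps = math.ceil(used / accumulate_grad_batches)
    return {
        "batches_per_epoch": used,
        "optimizer_steps": steps,
        "samples_per_step": batch_size * accumulate_grad_batches * devices,
    }
```

For example, 100 batches with `limit_train_batches=0.5` and `accumulate_grad_batches=4` gives 50 batches and 13 optimizer steps per epoch, each step covering 128 samples at `batch_size=32` on one device.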

use_wandb: bool
wandb_logger: Any
wandb_project: str