PyTorch Lightning Training
PyTorch Lightning trainer implementation.
- class openadmet.models.trainer.lightning.LightningTrainer(*, max_epochs: int = 20, accelerator: str = 'gpu', devices: int = 1, use_wandb: bool = False, output_dir: Path = None, wandb_project: str = 'openadmet-testing', early_stopping: bool = False, early_stopping_patience: int = 10, early_stopping_mode: str = 'min', early_stopping_min_delta: float = 0.001, gradient_clip_val: float = 0.0, precision: int = 32, accumulate_grad_batches: int = 1, deterministic: bool = False, fast_dev_run: bool = False, limit_train_batches: float = 1.0, limit_val_batches: float = 1.0, wandb_logger: Any = None)[source]
Bases: TrainerBase
Trainer for PyTorch models.
- Variables:
max_epochs (int) – The maximum number of epochs to train for.
accelerator (str) – The accelerator to use, e.g. ‘cpu’, ‘gpu’.
devices (int) – The number of devices to use, e.g. 1 for single GPU, -1 for all available GPUs.
use_wandb (bool) – Whether to use Weights & Biases for logging.
output_dir (Path) – The output directory to save logs and models.
wandb_project (str) – The Weights & Biases project name.
early_stopping (bool) – Whether to use early stopping.
early_stopping_patience (int) – The number of epochs with no improvement after which training will be stopped.
early_stopping_mode (str) – The mode for early stopping, either ‘min’ or ‘max’.
early_stopping_min_delta (float) – The minimum change in the monitored quantity to qualify as an improvement.
gradient_clip_val (float) – The value to clip gradients at.
precision (int) – The precision to use, e.g. 32, 16, or ‘bf16’.
accumulate_grad_batches (int) – The number of batches to accumulate gradients over.
deterministic (bool) – Whether to use deterministic algorithms.
fast_dev_run (bool) – Whether to run a single batch for debugging.
limit_train_batches (float) – The fraction of training batches to use, e.g. 1.0 for all, 0.5 for half, or an integer for a fixed number.
limit_val_batches (float) – The fraction of validation batches to use, e.g. 1.0 for all, 0.5 for half, or an integer for a fixed number. Default is 1.0.
wandb_logger (Any) – The Weights & Biases logger.
_logger (Any) – The logger.
_trainer (Any) – The PyTorch Lightning trainer.
_callbacks (Any) – The callbacks.
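The three early-stopping settings above (early_stopping_patience, early_stopping_mode, early_stopping_min_delta) interact in a way that is easier to see in code. The following is a minimal, simplified sketch of the usual patience rule, written in plain Python. It is not the implementation used by LightningTrainer or by PyTorch Lightning, and the function name epochs_until_stop and the example validation losses are hypothetical, introduced only for illustration.

```python
from typing import Optional, Sequence


def epochs_until_stop(
    scores: Sequence[float],
    patience: int = 10,
    mode: str = "min",
    min_delta: float = 0.001,
) -> Optional[int]:
    """Return the 1-based epoch at which a patience rule would stop training.

    Hypothetical helper for illustration only. Returns None if the rule
    never triggers within the given scores.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")

    best: Optional[float] = None
    wait = 0  # consecutive epochs without a qualifying improvement

    for epoch, score in enumerate(scores, start=1):
        if best is None:
            improved = True  # the first epoch always sets the baseline
        elif mode == "min":
            # Lower is better: must beat the best score by more than min_delta.
            improved = score < best - min_delta
        else:
            # Higher is better: must exceed the best score by more than min_delta.
            improved = score > best + min_delta

        if improved:
            best = score
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch

    return None


# Made-up validation losses: the changes after epoch 3 are smaller than
# min_delta, so they do not count as improvements.
val_losses = [0.90, 0.70, 0.65, 0.6495, 0.6499, 0.6492]
stop_epoch = epochs_until_stop(val_losses, patience=3, mode="min", min_delta=0.001)
print(stop_epoch)
```

With a larger min_delta, small fluctuations in the monitored quantity stop counting as progress, so training ends sooner. With a larger patience, more such epochs are tolerated before stopping.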
- build(no_val: bool = False)[source]
Build the model trainer.
- Parameters:
no_val (bool, optional) – Set to True when no validation set is specified, i.e. when training a model without a data split. Defaults to False.
- make_new() LightningTrainer[source]
Copy parameters to a new LightningTrainer instance.
- model_config: ClassVar[ConfigDict] = {}
Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.
- model_post_init(context: Any, /) None
This function is meant to behave like a BaseModel method to initialize private attributes.
It takes context as an argument since that’s what pydantic-core passes when calling it.
- Parameters:
self – The BaseModel instance.
context – The context.
- train(train_dataloader, val_dataloader)[source]
Train the model.
- Parameters:
train_dataloader (DataLoader) – The training data loader.
val_dataloader (DataLoader) – The validation data loader.
- Returns:
model – The trained model.
- Return type: