Source code for openadmet.models.trainer.trainer_base

"""Base class for trainers, allows for arbitrary training of models."""

from abc import ABC, abstractmethod
from typing import Any

from class_registry import ClassRegistry, RegistryKeyError
from pydantic import BaseModel

from openadmet.models.architecture.model_base import ModelBase

# Module-level registry of trainer classes; get_trainer_class looks trainer
# classes up in it by key. ``unique=True`` is passed to ClassRegistry —
# presumably this rejects registering two classes under the same key; confirm
# against the class_registry documentation.
trainers: ClassRegistry = ClassRegistry(unique=True)


def get_trainer_class(model_type: str):
    """
    Retrieve a trainer class from the registry by type.

    Parameters
    ----------
    model_type : str
        The registry key of the trainer to retrieve.

    Returns
    -------
    type[TrainerBase]
        The trainer class registered under ``model_type``.

    Raises
    ------
    ValueError
        If no trainer is registered under ``model_type``.
    """
    try:
        trainer_class = trainers.get_class(model_type)
    except RegistryKeyError as err:
        # Surface a ValueError (the documented contract) but keep the original
        # registry error chained as the cause to aid debugging.
        raise ValueError(
            f"Trainer type {model_type} not found in trainer catalogue, "
            f"available trainers are {list(trainers.classes())}"
        ) from err
    return trainer_class
class TrainerBase(BaseModel, ABC):
    """
    Base class for trainers, allows for arbitrary training of models.

    Subclasses must implement :meth:`build` and :meth:`train`.

    Attributes
    ----------
    _model : ModelBase
        The model to be trained. The leading underscore makes this a pydantic
        private attribute rather than a validated field; read and assign it
        through the :attr:`model` property.
    """

    _model: ModelBase

    @property
    def model(self):
        """Return the model to be trained."""
        return self._model

    @model.setter
    def model(self, model):
        """Set the model to be trained."""
        self._model = model

    @abstractmethod
    def build(self):
        """Build the trainer; must be implemented by subclasses."""

    @abstractmethod
    def train(self, X: Any, y: Any):
        """
        Train the model; must be implemented by subclasses.

        Parameters
        ----------
        X : Any
            Feature data.
        y : Any
            Target data.
        """