Source code for openadmet.models.trainer.sklearn

"""Trainers for sklearn models."""

from typing import Any

from loguru import logger
from sklearn.model_selection import GridSearchCV

from openadmet.models.trainer.trainer_base import TrainerBase, trainers
from openadmet.models.drivers import DriverType


class SKLearnTrainer(TrainerBase):
    """
    Common parent class for trainers that operate on sklearn models.

    The only thing defined here is the private ``_driver_type`` marker,
    which is fixed to ``DriverType.SKLEARN`` for every subclass.
    """

    # Driver tag shared by all sklearn trainers.
    _driver_type: DriverType = DriverType.SKLEARN
@trainers.register("SKLearnBasicTrainer")
class SKlearnBasicTrainer(SKLearnTrainer):
    """
    Plain trainer for sklearn models.

    Fits the estimator wrapped by ``self.model`` directly on the supplied
    data, without any hyperparameter search.
    """

    def train(self, X: Any, y: Any):
        """
        Fit the wrapped estimator on the given data.

        Parameters
        ----------
        X : Any
            Feature data.
        y : Any
            Target data.

        Returns
        -------
        ModelBase
            ``self.model``, with its estimator fitted on ``X`` and ``y``.

        """
        wrapped = self.model
        estimator = wrapped.estimator
        estimator.fit(X, y)
        # Store the fitted estimator back on the model before returning it.
        wrapped.estimator = estimator
        return wrapped

    def build(self):
        """Do nothing; sklearn models need no separate build step."""
        return None
class SKLearnSearchTrainer(SKLearnTrainer):
    """
    Base trainer for sklearn models tuned through a search object.

    Attributes
    ----------
    search : Any
        The search object (e.g., GridSearchCV), stored privately in
        ``_search`` and exposed through the ``search`` property.

    """

    # Backing storage for the ``search`` property; no default is assigned.
    _search: Any

    @property
    def search(self):
        """Search object (e.g., GridSearchCV) currently held by the trainer."""
        return self._search

    @search.setter
    def search(self, value):
        """Replace the held search object (e.g., GridSearchCV) with ``value``."""
        self._search = value

    def build(self):
        """Do nothing; sklearn models need no separate build step."""
        return None
@trainers.register("SKLearnGridSearchTrainer")
class SKLearnGridSearchTrainer(SKLearnSearchTrainer):
    """
    Trainer that tunes an sklearn model with an exhaustive grid search.

    Attributes
    ----------
    param_grid : dict
        The parameter grid handed to ``GridSearchCV``.

    """

    param_grid: dict = {}

    def train(self, X: Any, y: Any):
        """
        Run a grid search and keep the best estimator on the model.

        Parameters
        ----------
        X : Any
            Featurized data.
        y : Any
            Target data.

        Returns
        -------
        ModelBase
            ``self.model``, with its estimator replaced by the best one
            found by the search.

        """
        base_estimator = self.model.estimator
        self.search = GridSearchCV(base_estimator, param_grid=self.param_grid)
        self.search.fit(X, y)

        # Adopt the winning estimator, then mirror its parameters onto the
        # model object itself.
        self.model.estimator = self.search.best_estimator_
        tuned_params = self.model.estimator.get_params()
        # NOTE(review): this copies every estimator parameter, not only the
        # searched ones; a parameter literally named "estimator" could shadow
        # the fitted estimator depending on how the model stores it - confirm.
        vars(self.model).update(tuned_params)
        logger.info(f"Best params: {self.model.estimator.get_params()}")
        return self.model