Source code for openadmet.models.architecture.dummy
"""Dummy model implementations."""
from typing import ClassVar
import numpy as np
from loguru import logger
from sklearn.dummy import DummyClassifier, DummyRegressor
from openadmet.models.architecture.model_base import PickleableModelBase, models
class DummyModelBase(PickleableModelBase):
    """
    Common base for wrappers around scikit-learn Dummy estimators.

    Subclasses point ``mod_class`` at the estimator class to wrap; the estimator
    is instantiated from the values returned by ``self.model_dump()``.
    """

    # Meta parameters for this class
    type: ClassVar[str]
    mod_class: ClassVar[type]

    def build(self):
        """Instantiate the wrapped estimator unless one is already present."""
        if self.estimator:
            logger.warning("Model already exists, skipping build")
            return
        # Every dumped parameter is forwarded as a keyword argument to the estimator class.
        params = self.model_dump()
        self.estimator = self.mod_class(**params)

    def train(self, X: np.ndarray, y: np.ndarray):
        """
        Fit the wrapped estimator.

        Parameters
        ----------
        X: np.ndarray
            Training data features
        y: np.ndarray
            Training data labels; a single-column 2-D array is flattened to 1-D
            before fitting

        """
        self.build()
        targets = np.asarray(y)
        # Collapse an (n, 1) column into a flat (n,) vector; other shapes pass through.
        is_single_column = targets.ndim == 2 and targets.shape[1] == 1
        if is_single_column:
            targets = targets.ravel()
        self.estimator = self.estimator.fit(X, targets)

    def predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
        """
        Generate predictions with the wrapped estimator.

        Parameters
        ----------
        X: np.ndarray
            Featurized data to predict on
        kwargs: dict
            Additional keyword arguments; accepted but not used by this
            implementation (they are not forwarded to the estimator)

        Returns
        -------
        np.ndarray
            Predictions from the model; 1-D estimator output is returned as a
            column with a trailing axis of length 1

        Raises
        ------
        ValueError
            If no estimator is available (the model has not been built/trained)

        """
        if not self.estimator:
            raise ValueError("Model not trained")
        predictions = self.estimator.predict(X)
        if predictions.ndim != 1:
            return predictions
        # Promote a flat vector to shape (n, 1).
        return np.expand_dims(predictions, axis=1)
@models.register("DummyRegressorModel")
class DummyRegressorModel(DummyModelBase):
    """
    Dummy regression model wrapping ``sklearn.dummy.DummyRegressor``.

    Common parameters for dummy models can be found at:
    https://scikit-learn.org/stable/api/sklearn.dummy.html
    """

    # Meta parameters for this class.
    # NOTE: ``mod_class`` is declared before ``type`` on purpose. Once ``type`` is
    # assigned in the class body it shadows the builtin, and ``ClassVar[type]``
    # would then be evaluated as ``ClassVar["DummyRegressorModel"]``.
    mod_class: ClassVar[type] = DummyRegressor
    type: ClassVar[str] = "DummyRegressorModel"

    # DummyRegressor parameters (names match the scikit-learn constructor arguments)
    # Prediction strategy name; defaults to predicting the training mean.
    strategy: str = "mean"
    # Explicit constant to predict; see the scikit-learn docs for the strategy that uses it.
    constant: float | None = None
    # Quantile to predict; see the scikit-learn docs for the strategy that uses it.
    quantile: float | None = None
@models.register("DummyClassifierModel")
class DummyClassifierModel(DummyModelBase):
    """
    Dummy classification model wrapping ``sklearn.dummy.DummyClassifier``.

    Common parameters for dummy models can be found at:
    https://scikit-learn.org/stable/api/sklearn.dummy.html
    """

    # Meta parameters for this class.
    # NOTE: ``mod_class`` is declared before ``type`` on purpose. Once ``type`` is
    # assigned in the class body it shadows the builtin, and ``ClassVar[type]``
    # would then be evaluated as ``ClassVar["DummyClassifierModel"]``.
    mod_class: ClassVar[type] = DummyClassifier
    type: ClassVar[str] = "DummyClassifierModel"

    # DummyClassifier parameters (names match the scikit-learn constructor arguments)
    # Prediction strategy name; defaults to predicting the most frequent class.
    strategy: str = "most_frequent"
    # Seed for the estimator's random state; None leaves it unseeded.
    random_state: int | None = None
    # Explicit class label to predict; see the scikit-learn docs for the strategy that uses it.
    constant: int | str | None = None