Source code for openadmet.models.architecture.model_base
"""Base classes for all models."""
import json
from abc import ABC, abstractmethod
from os import PathLike
from typing import Any, ClassVar
from class_registry import ClassRegistry, RegistryKeyError
from loguru import logger
from pydantic import BaseModel
from openadmet.models.drivers import DriverType
# Module-level registry of model classes. ``get_mod_class`` resolves a model
# class from its registered key through ``models.get_class``.
# NOTE(review): ``unique=True`` presumably makes the registry reject duplicate
# keys -- confirm against the class-registry documentation.
models = ClassRegistry(unique=True)
def get_mod_class(model_type):
    """
    Get the model class from the registry by type.

    Parameters
    ----------
    model_type : str
        The registered key for the model (e.g., ``"XGBRegressorModel"``).

    Returns
    -------
    type
        The model class corresponding to the given type.

    Raises
    ------
    ValueError
        If ``model_type`` is not found in the model registry. The underlying
        ``RegistryKeyError`` is chained as ``__cause__``.
    """
    # Function-scope import; presumably kept local to avoid an import cycle or
    # import-time cost -- confirm. ``load_group("models")`` is called before the
    # lookup, presumably so that all model classes are registered in
    # ``models`` by the time we query it.
    from openadmet.models._registry_loader import load_group

    load_group("models")
    try:
        model_class = models.get_class(model_type)
    except RegistryKeyError as err:
        # Translate the registry-specific error into the documented ValueError,
        # keeping the original exception as the cause for debugging.
        raise ValueError(
            f"Model type {model_type} not found in model catalogue"
        ) from err
    return model_class
class ModelBase(BaseModel, ABC):
    """
    Base class for all models.

    Subclasses must implement ``build``, ``save``, ``load``, ``serialize``,
    ``deserialize``, ``train`` and ``predict``. The underlying estimator
    object is held in the private ``_estimator`` attribute and exposed through
    the ``estimator`` property. Calling an instance delegates to ``predict``.
    """

    # Backing storage for the ``estimator`` property (underscore-prefixed, so
    # pydantic keeps it as a private attribute rather than a model field).
    _estimator: Any = None
    # Default filename for the JSON model-parameter file.
    _model_json_name: ClassVar[str] = "model.json"
    # Number of tasks the model handles; defaults to 1.
    # NOTE(review): presumably overridden by multi-task subclasses -- confirm.
    _n_tasks: int = 1

    @property
    def estimator(self):
        """Get the model estimator."""
        return self._estimator

    @estimator.setter
    def estimator(self, value):
        """Set the model estimator."""
        self._estimator = value

    @abstractmethod
    def build(self):
        """Prepare the model, abstract method to be implemented by subclasses."""
        pass

    @abstractmethod
    def save(self, path: PathLike):
        """
        Save the model, abstract method to be implemented by subclasses.

        Parameters
        ----------
        path: PathLike
            Path to save the model to
        """
        pass

    @abstractmethod
    def load(self, path: PathLike):
        """
        Load the model, abstract method to be implemented by subclasses.

        Parameters
        ----------
        path: PathLike
            Path to load the model from
        """
        pass

    @abstractmethod
    def serialize(self, param_path: PathLike, serial_path: PathLike):
        """
        Serialize the model, abstract method to be implemented by subclasses.

        Parameters
        ----------
        param_path: PathLike
            Path to save the model parameters to
        serial_path: PathLike
            Path to save the model serialization to
        """
        pass

    @abstractmethod
    def deserialize(self, param_path: PathLike, serial_path: PathLike):
        """
        Deserialize the model, abstract method to be implemented by subclasses.

        Parameters
        ----------
        param_path: PathLike
            Path to load the model parameters from
        serial_path: PathLike
            Path to load the model serialization from
        """
        pass

    @abstractmethod
    def train(self):
        """Train the model, abstract method to be implemented by subclasses."""

    @abstractmethod
    def predict(self, input: Any):
        """
        Predict using the model, abstract method to be implemented by subclasses.

        Parameters
        ----------
        input: Any
            Input data to predict on
        """
        pass

    def __call__(self, *args, **kwargs):
        """Call the predict method when the model instance is called."""
        return self.predict(*args, **kwargs)

    def __eq__(self, value):
        """
        Compare two model instances by their dumped fields, ignoring the estimator.

        Parameters
        ----------
        value: Any
            Object to compare against.

        Returns
        -------
        bool or NotImplemented
            ``NotImplemented`` when ``value`` is not a ``ModelBase``, letting
            Python fall back to its default comparison (e.g. ``model == None``
            is ``False``) instead of raising ``AttributeError``.
        """
        if not isinstance(value, ModelBase):
            return NotImplemented
        # The estimator is runtime state, not configuration, so it must not
        # influence equality.
        return self.model_dump(exclude={"estimator"}) == value.model_dump(
            exclude={"estimator"}
        )
class PickleableModelBase(ModelBase):
    """
    An sklearn model that can be pickled using joblib.

    Parameters are persisted as JSON (``model_dump_json``) and the estimator
    as a joblib pickle; ``serialize`` writes both and ``deserialize`` rebuilds
    an instance from them.
    """

    # ClassVar flag marking this model family as pickleable.
    pickleable: ClassVar[bool] = True
    # Default filename for the pickled estimator.
    _model_save_name: ClassVar[str] = "model.pkl"
    # Driver type for this model family (private attribute).
    _driver_type: DriverType = DriverType.SKLEARN

    def save(self, path: PathLike):
        """
        Save the model to a pickle file.

        Parameters
        ----------
        path: PathLike
            Path to save the model to

        Raises
        ------
        ValueError
            If the estimator has not been built yet.
        """
        if self.estimator is None:
            raise ValueError("Model is not built, cannot save")
        import joblib

        with open(path, "wb") as f:
            joblib.dump(self.estimator, f)

    def load(self, path: PathLike):
        """
        Load the model from a pickle file.

        Parameters
        ----------
        path: PathLike
            Path to load the model from
        """
        import joblib

        # SECURITY: joblib.load unpickles the file, which can execute arbitrary
        # code. Only load model files from trusted sources.
        with open(path, "rb") as f:
            self.estimator = joblib.load(f)

    def make_new(self) -> "PickleableModelBase":
        """Copy parameters to a new model instance without copying the estimator."""
        return self.__class__(**self.model_dump(exclude={"estimator"}))

    @classmethod
    def deserialize(
        cls, param_path: PathLike = "model.json", serial_path: PathLike = "model.pkl"
    ):
        """
        Create a model from parameters and a pickled model.

        Parameters
        ----------
        param_path: PathLike
            Path to load the model parameters (UTF-8 JSON) from
        serial_path: PathLike
            Path to load the pickled model from

        Returns
        -------
        instance: PickleableModelBase
            An instance of the PickleableModelBase class
        """
        # JSON is UTF-8 by specification; don't depend on the platform locale.
        with open(param_path, encoding="utf-8") as f:
            mod_params = json.load(f)
        instance = cls(**mod_params)
        # Build first, then overwrite the estimator with the pickled one.
        instance.build()
        instance.load(serial_path)
        return instance

    def serialize(
        self, param_path: PathLike = "model.json", serial_path: PathLike = "model.pkl"
    ):
        """
        Save the model to a json file and a pickled file.

        Parameters
        ----------
        param_path: PathLike
            Path to save the model parameters (UTF-8 JSON) to
        serial_path: PathLike
            Path to save the pickled model to
        """
        # Write explicitly as UTF-8 so non-ASCII parameter values round-trip
        # regardless of the platform's default encoding.
        with open(param_path, "w", encoding="utf-8") as f:
            f.write(self.model_dump_json(indent=2))
        self.save(serial_path)
# Lightning base classes are re-exported through a module-level ``__getattr__``
# (PEP 562) so that importing model_base does NOT pull in torch or
# lightning.pytorch. The real definitions live in lightning_model_base.
_LIGHTNING_EXPORTS = frozenset(("LightningModelBase", "LightningModuleBase"))


def __getattr__(name: str):
    """
    Resolve the lazily re-exported Lightning base classes on first access.

    Parameters
    ----------
    name : str
        Attribute name requested from this module.

    Returns
    -------
    Any
        The attribute of that name taken from ``lightning_model_base``.

    Raises
    ------
    AttributeError
        If ``name`` is not one of the lazily re-exported names.
    """
    if name not in _LIGHTNING_EXPORTS:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from openadmet.models.architecture import lightning_model_base as _lmb

    resolved = getattr(_lmb, name)
    # Store the result in the module namespace: later lookups hit it directly
    # and never reach this hook again.
    globals()[name] = resolved
    return resolved
# Public API of this module. ``LightningModuleBase`` and ``LightningModelBase``
# are not defined at module level; they are resolved by the module-level
# ``__getattr__`` on first access.
# NOTE(review): a ``from ... import *`` looks up every name listed here, so it
# will trigger that lazy import (and hence torch/lightning) -- confirm this is
# intended.
__all__ = [
    "ModelBase",
    "PickleableModelBase",
    "LightningModuleBase",
    "LightningModelBase",
    "models",
    "get_mod_class",
]