Source code for openadmet.models.active_learning.ensemble_base
"""Base class for ensemble models."""
from typing import ClassVar
from class_registry import ClassRegistry, RegistryKeyError
from openadmet.models.architecture.model_base import ModelBase
# Module-level registry of ensemble classes; `get_ensemble_class` resolves
# ensemble type keys against it.
# NOTE(review): `unique=True` presumably makes the registry reject duplicate
# keys at registration time -- confirm against the class_registry docs.
ensemblers = ClassRegistry(unique=True)
def get_ensemble_class(ensemble_type):
    """
    Look up an ensemble class in the module-level ``ensemblers`` registry.

    Parameters
    ----------
    ensemble_type
        Registry key of the requested ensemble (presumably the ``type``
        string of a registered ensemble class -- confirm against the
        registration site).

    Returns
    -------
    The class that ``ensemblers.get_class`` returns for ``ensemble_type``.

    Raises
    ------
    ValueError
        If the registry lookup raises ``RegistryKeyError``. The original
        error is attached as ``__cause__``.
    """
    try:
        ensemble_class = ensemblers.get_class(ensemble_type)
    except RegistryKeyError as err:
        # Adjacent f-string literals are concatenated verbatim, so the
        # separating space has to live inside the first literal.
        raise ValueError(
            f"Ensemble type {ensemble_type} not found in ensemble catalogue, "
            f"available ensembles are {list(ensemblers.classes())}"
        ) from err
    return ensemble_class
class EnsembleBase(ModelBase):
    """
    Common parent class for ensemble models.

    Attributes
    ----------
    type : ClassVar[str]
        Identifier string for this kind of ensemble model.
    models : list
        Member models that make up the ensemble.
    _calibration_model_save_name : ClassVar[str]
        File name under which the calibration model is saved.

    """

    type: ClassVar[str] = "EnsembleBase"
    models: list = []
    _calibration_model_save_name: ClassVar[str] = "calibration_model.pkl"

    @property
    def n_models(self):
        """Return how many member models the ensemble currently holds."""
        member_count = len(self.models)
        return member_count

    def build(self):
        """No-op placeholder: the committee is assembled from models that are provided."""

    def from_params(self):
        """
        Placeholder that always raises ``NotImplementedError``.

        Constructing from parameters is not meaningful for this class: an
        ensemble is instantiated either from already-trained models or via
        the `train` method.
        """
        raise NotImplementedError()