Source code for openadmet.models.architecture.lgbm
"""LightGBM model implementations."""
from typing import ClassVar
import lightgbm as lgb
import numpy as np
from loguru import logger
from openadmet.models.architecture.model_base import PickleableModelBase, models
[docs]class LGBMModelBase(PickleableModelBase):
"""Base class for LightGBM models."""
# Meta parameters for this class
type: ClassVar[str]
mod_class: ClassVar[type]
# LGBM parameters
boosting_type: str = "gbdt"
num_leaves: int = 31
max_depth: int = -1
learning_rate: float = 0.1
n_estimators: int = 100
subsample_for_bin: int = 200000
objective: str | None = None
class_weight: str | None = None
min_split_gain: float = 0.0
min_child_weight: float = 0.001
min_child_samples: int = 20
subsample: float = 1.0
subsample_freq: int = 0
colsample_bytree: float = 1.0
reg_alpha: float = 0.0
reg_lambda: float = 0.0
random_state: int | None = None
n_jobs: int | None = None
importance_type: str = "split"
verbose: int = -1
[docs] def build(self):
"""Prepare the model."""
if not self.estimator:
self.estimator = self.mod_class(**self.model_dump())
else:
logger.warning("Model already exists, skipping build")
[docs] def train(self, X: np.ndarray, y: np.ndarray):
"""
Train the model.
Parameters
----------
X: np.ndarray
Training data features
y: np.ndarray
Training data values
"""
self.build()
self.estimator = self.estimator.fit(X, y)
[docs] def predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
"""
Predict using the model.
Parameters
----------
X: np.ndarray
Featurized data to predict on
kwargs: dict
Additional keyword arguments to pass to the predict method of the LightGBM model
Returns
-------
np.ndarray
Predictions from the model
"""
if not self.estimator:
raise ValueError("Model not trained")
return np.expand_dims(self.estimator.predict(X), axis=1)
[docs]@models.register("LGBMRegressorModel")
class LGBMRegressorModel(LGBMModelBase):
"""LightGBM regression model."""
# Meta parameters for this class
type: ClassVar[str] = "LGBMRegressorModel"
mod_class: ClassVar[type] = lgb.LGBMRegressor
[docs]@models.register("LGBMClassifierModel")
class LGBMClassifierModel(LGBMModelBase):
"""LightGBM classification model."""
# Meta parameters for this class
type: ClassVar[str] = "LGBMClassifierModel"
mod_class: ClassVar[type] = lgb.LGBMClassifier
[docs] def predict_proba(self, X: np.ndarray) -> np.ndarray:
"""Predict using the model."""
if not self.estimator:
raise ValueError("Model not trained")
return self.estimator.predict_proba(X)