Source code for openadmet.models.anvil.workflow_base
"""Base class for Anvil workflows."""
from abc import abstractmethod
from os import PathLike
from pathlib import Path
from typing import Any, Optional
from pydantic import BaseModel, Field, model_validator
from openadmet.models.active_learning.ensemble_base import (
EnsembleBase,
)
from openadmet.models.anvil.specification import DataSpec, Metadata
from openadmet.models.architecture.model_base import ModelBase
from openadmet.models.eval.eval_base import EvalBase
from openadmet.models.features.feature_base import FeaturizerBase
from openadmet.models.registries import load_all # noqa: F401
from openadmet.models.split.split_base import SplitterBase
from openadmet.models.trainer.trainer_base import TrainerBase
from openadmet.models.transforms.transform_base import (
TransformBase,
)
class AnvilWorkflowBase(BaseModel):
    """
    Base class for Anvil workflows.

    Bundles the components of a workflow (data specification, splitter,
    featurizer, model, trainer, evaluations, ...) into a single pydantic
    model and cross-checks them with ``mode="after"`` validators once all
    fields have been populated. Concrete workflows implement :meth:`run`.

    Attributes
    ----------
    metadata : Metadata
        Metadata for the workflow.
    data_spec : DataSpec
        Data specification for the workflow.
    transform : Optional[TransformBase]
        Optional transform step.
    split : SplitterBase
        Data splitting strategy.
    feat : FeaturizerBase
        Feature extraction method.
    model : ModelBase
        The model to be used.
    ensemble : Optional[EnsembleBase]
        Optional ensemble model.
    trainer : TrainerBase
        The trainer for the model.
    evals : list[EvalBase]
        List of evaluation metrics.
    model_kwargs : dict
        Runtime model settings from the specification domain.
    ensemble_kwargs : dict
        Runtime ensemble settings from the specification domain.
    debug : bool
        Whether to run in debug mode.
    resolved_output_dir : Path | None
        Defaults to None. Presumably the output directory actually used by a
        run; nothing in this base class assigns it (confirm in subclasses).
    """

    metadata: Metadata
    data_spec: DataSpec
    transform: Optional[TransformBase] = None  # Optional transform step
    split: SplitterBase
    feat: FeaturizerBase
    model: ModelBase
    ensemble: EnsembleBase | None = None
    trainer: TrainerBase
    evals: list[EvalBase]
    model_kwargs: dict = Field(default_factory=dict)
    ensemble_kwargs: dict = Field(default_factory=dict)
    debug: bool = False
    resolved_output_dir: Path | None = None

    @abstractmethod
    def run(self, output_dir: PathLike = "anvil_training", debug: bool = False) -> Any:
        """
        Run the workflow.

        Parameters
        ----------
        output_dir : PathLike, optional
            Directory to save outputs, by default "anvil_training"
        debug : bool, optional
            Whether to run in debug mode, by default False

        Returns
        -------
        Any
            Result of the workflow run
        """
        ...

    @model_validator(mode="after")
    def check_multitask_compatibility(self) -> "AnvilWorkflowBase":
        """
        Validate that the model's task count matches the number of target columns.

        Raises
        ------
        ValueError
            If the model's number of tasks differs from the number of target
            columns in the data specification.

        Returns
        -------
        AnvilWorkflowBase
            The validated workflow instance.
        """
        n_targets = len(self.data_spec.target_cols)
        if self.model._n_tasks != n_targets:
            raise ValueError(
                f"The model has {self.model._n_tasks} tasks but the data specification has {n_targets} target columns."
            )
        # An "after" model validator must hand the instance back to pydantic.
        return self

    @model_validator(mode="after")
    def no_ensemble_cross_val(self) -> "AnvilWorkflowBase":
        """
        Validate that ensemble models are not used with cross-validation.

        Raises
        ------
        ValueError
            If an ensemble model is used with cross-validation.

        Returns
        -------
        AnvilWorkflowBase
            The validated workflow instance.
        """
        # Generator form avoids building a throwaway list and short-circuits.
        doing_cv = any(v.is_cross_val for v in self.evals)
        if self.ensemble is not None and doing_cv:
            raise ValueError("Ensemble models cannot be used with cross-validation.")
        return self

    @model_validator(mode="after")
    def check_model_trainer_compatibility(self) -> "AnvilWorkflowBase":
        """
        Validate that the model and trainer are compatible.

        Raises
        ------
        ValueError
            If the model and trainer driver types do not match.

        Returns
        -------
        AnvilWorkflowBase
            The validated workflow instance.
        """
        if self.model._driver_type != self.trainer._driver_type:
            raise ValueError(
                f"Model driver type {self.model._driver_type} does not match trainer driver type {self.trainer._driver_type}."
            )
        return self

    @model_validator(mode="after")
    def check_trainer_cv_compatibility(self) -> "AnvilWorkflowBase":
        """
        Validate that the trainer's driver type matches every cross-validation evaluation.

        Only evaluations whose ``is_cross_val`` flag is truthy are checked;
        other evaluations are ignored by this validator.

        Raises
        ------
        ValueError
            If a cross-validation evaluation has a driver type different from
            the trainer's driver type.

        Returns
        -------
        AnvilWorkflowBase
            The validated workflow instance.
        """
        cv_evals = [v for v in self.evals if v.is_cross_val]
        for eval_instance in cv_evals:
            if self.trainer._driver_type != eval_instance._driver_type:
                raise ValueError(
                    f"Trainer driver type {self.trainer._driver_type} does not match evaluation driver type {eval_instance._driver_type} required for cross-validation."
                )
        return self