"""Specification models for Anvil workflows."""
from os import PathLike
from pathlib import Path
from typing import ClassVar, Literal, Optional, Union
import fsspec
import intake
import jinja2
import pandas as pd
import yaml
from loguru import logger
from pydantic import BaseModel, EmailStr, Field, field_validator, model_validator
from openadmet.models.active_learning.ensemble_base import (
get_ensemble_class,
)
from openadmet.models.architecture.model_base import get_mod_class
from openadmet.models.drivers import DriverType
from openadmet.models.eval.eval_base import get_eval_class
from openadmet.models.features.feature_base import get_featurizer_class
from openadmet.models.registries import load_all # noqa: F401
from openadmet.models.split.split_base import get_splitter_class
from openadmet.models.trainer.trainer_base import get_trainer_class
from openadmet.models.transforms.transform_base import (
get_transform_class,
)
# Registry dispatch table keyed by ``AnvilSection.section_name``.
# Each value takes a section ``type`` string and returns a callable that
# ``AnvilSection.to_class`` instantiates with the section ``params``.
# The "INVALID" entry backs the base ``AnvilSection`` and resolves to None.
_SECTION_CLASS_GETTERS = dict(
    feat=get_featurizer_class,
    model=get_mod_class,
    ensemble=get_ensemble_class,
    split=get_splitter_class,
    eval=get_eval_class,
    train=get_trainer_class,
    transform=get_transform_class,
    INVALID=lambda _type: None,
)
class DataSpec(BaseModel):
    """
    Data specification for the workflow.

    One of two data layouts is used: either a single ``resource`` (CSV,
    Parquet, or an intake YAML catalog), or separate ``train_resource`` and
    ``test_resource`` files with an optional ``val_resource``. Supplying
    both layouts is rejected at validation time.

    Attributes
    ----------
    type : str
        The type of data source (e.g., 'csv', 'yaml').
    resource : Optional[str]
        The path or URL to the data resource.
    cat_entry : Optional[str]
        The catalog entry name if the resource is a YAML catalog.
    target_cols : Union[str, list[str]]
        The target column(s) in the dataset. A single string is wrapped in a
        list during validation.
    input_col : str
        The input column in the dataset.
    anvil_dir : Optional[str]
        The base directory substituted for ``{{ ANVIL_DIR }}`` in resource paths.
    dropna : Optional[bool]
        Whether to drop rows containing NaN in the selected columns.
    train_resource : Optional[str]
        The path or URL to the training data resource (if using separate train/test).
    test_resource : Optional[str]
        The path or URL to the testing data resource (if using separate train/test).
    val_resource : Optional[str]
        The path or URL to the validation data resource (if using separate train/test).
    _catalog : Optional[intake.catalog.Catalog]
        The intake catalog object, populated when a YAML catalog resource is read.
    _using_train_test : bool
        Whether using separate train and test resources.
    """

    type: str
    resource: Optional[str] = None
    cat_entry: Optional[str] = None
    target_cols: Union[str, list[str]]
    input_col: str
    anvil_dir: Optional[str] = None
    dropna: Optional[bool] = False
    train_resource: Optional[str] = None
    test_resource: Optional[str] = None
    val_resource: Optional[str] = None
    _catalog: Optional[intake.catalog.Catalog] = None
    _using_train_test: bool = False

    @property
    def using_train_test(self):
        """Whether using separate train and test resources."""
        return self._using_train_test

    @model_validator(mode="after")
    def check_resource_test_train(self):
        """
        Ensure that either resource or train/test/val resources are provided, not both.

        Also records whether the separate train/test layout is in use.

        Raises
        ------
        ValueError
            If both layouts are given, or if separate resources are used
            without both ``train_resource`` and ``test_resource``.
        """
        if self.resource and (
            self.train_resource or self.test_resource or self.val_resource
        ):
            raise ValueError(
                "Specify either `resource` or `train_resource`/`test_resource`/`val_resource`, not both."
            )
        if self.train_resource or self.test_resource or self.val_resource:
            if not (self.train_resource and self.test_resource):
                raise ValueError(
                    "`train_resource` and `test_resource` must both be specified when using separate resources."
                )
            self._using_train_test = True
        return self

    @field_validator("target_cols", mode="before")
    @classmethod
    def check_target_cols_input(cls, v):
        """Ensure target_cols is always a list."""
        if isinstance(v, str):
            return [v]
        return v

    @model_validator(mode="after")
    def template_resource(self):
        """
        Template every resource path with ANVIL_DIR if present.

        Previously only ``resource`` was rendered here, so a constructor-supplied
        ``anvil_dir`` left ``{{ ANVIL_DIR }}`` placeholders in the
        train/test/val paths. All four resource fields are now rendered,
        matching ``template_anvil_dir``.

        Returns
        -------
        self : DataSpec
            The DataSpec instance with the templated resources.
        """
        if self.anvil_dir:
            self._render_resources(self.anvil_dir)
        return self

    def template_anvil_dir(self, anvil_dir: Path):
        """
        Set ``anvil_dir`` and template all resources with it.

        Parameters
        ----------
        anvil_dir : Path or str
            Directory substituted for ``{{ ANVIL_DIR }}``.
        """
        # Store a plain string: the field is declared Optional[str], and
        # ``to_yaml`` uses yaml.safe_dump, which cannot represent Path objects.
        self.anvil_dir = None if anvil_dir is None else str(anvil_dir)
        self._render_resources(anvil_dir)

    def _render_resources(self, anvil_dir) -> None:
        """Render ``{{ ANVIL_DIR }}`` in every non-empty resource field."""
        for attr in ("resource", "train_resource", "test_resource", "val_resource"):
            value = getattr(self, attr, None)
            if value:
                setattr(self, attr, jinja2.Template(value).render(ANVIL_DIR=anvil_dir))

    def read(self) -> tuple:
        """
        Read the data from the configured resource(s).

        Returns
        -------
        tuple
            For a single ``resource``: ``(X, y)``, the input column and the
            target column(s).
            For separate train/test resources:
            ``(X_train, X_val, X_test, y_train, y_val, y_test, X, y)``. The
            ``*_val`` entries are None when no ``val_resource`` is set, and
            ``X``/``y`` cover all splits combined.
        """
        return (
            self._read_train_test_val()
            if self._using_train_test
            else self._read_single_resource()
        )

    @staticmethod
    def _read_csv_or_parquet(resource: str) -> pd.DataFrame:
        """Read data from a CSV or Parquet resource, chosen by file extension."""
        if resource.endswith(".csv"):
            return intake.open_csv(resource).read()
        if resource.endswith((".parquet", ".pq", ".pqt")):
            return intake.open_parquet(resource).read()
        raise ValueError(f"Unsupported resource type: {resource}")

    def _select_and_clean(
        self, data: pd.DataFrame, extra_cols: tuple[str, ...] = ()
    ) -> pd.DataFrame:
        """Select input/target(/extra) columns, optionally drop NaN rows, and log counts."""
        target_cols = (
            self.target_cols
            if isinstance(self.target_cols, list)
            else [self.target_cols]
        )
        combined = data[[self.input_col, *target_cols, *extra_cols]]
        n_before = len(combined)
        if self.dropna:
            combined = combined.dropna().reset_index(drop=True)
            logger.info(
                f"{n_before} total rows. {n_before - len(combined)} NaN rows were dropped."
            )
        else:
            logger.info(f"{n_before} total rows. 0 NaN rows were dropped.")
        return combined

    def _read_train_test_val(self) -> tuple:
        """Read data from separate train/test/validation resources."""
        if not self.train_resource or not self.test_resource:
            raise ValueError("Both train_resource and test_resource must be specified.")

        def read_split(resource: str, split_name: str) -> pd.DataFrame:
            # Tag each row with its split so the combined frame can be
            # NaN-filtered once and then separated again.
            if resource.endswith((".yaml", ".yml")):
                raise ValueError(
                    "YAML catalogs not supported with train/test resources."
                )
            data = self._read_csv_or_parquet(resource)
            if "_split" in data.columns:
                raise ValueError(
                    f"{split_name.capitalize()} data should not contain a '_split' column."
                )
            data["_split"] = split_name
            return data

        splits_to_read = [(self.train_resource, "train"), (self.test_resource, "test")]
        if self.val_resource:
            splits_to_read.append((self.val_resource, "val"))
        combined = pd.concat(
            [read_split(resource, split) for resource, split in splits_to_read]
        )
        combined = self._select_and_clean(combined, extra_cols=("_split",))

        # Split and return (X values, then Y values)
        train = combined[combined["_split"] == "train"]
        test = combined[combined["_split"] == "test"]
        val = combined[combined["_split"] == "val"] if self.val_resource else None
        X_train = train[self.input_col]
        X_test = test[self.input_col]
        X_val = val[self.input_col] if val is not None else None
        y_train = train[self.target_cols]
        y_test = test[self.target_cols]
        y_val = val[self.target_cols] if val is not None else None
        # Also return full X, y for reference.
        X = combined[self.input_col]
        y = combined[self.target_cols]
        return X_train, X_val, X_test, y_train, y_val, y_test, X, y

    def _read_single_resource(self) -> tuple:
        """Read data from a single resource (CSV, Parquet, or intake YAML catalog)."""
        if not self.resource:
            raise ValueError(
                "No data resource specified; set `resource` or `train_resource`/`test_resource`."
            )
        if self.resource.endswith((".yaml", ".yml")):
            if not self.cat_entry:
                raise ValueError("cat_entry must be specified for YAML resources.")
            self._catalog = intake.open_catalog(self.resource)
            if self.cat_entry not in self._catalog:
                raise ValueError(
                    f"cat_entry '{self.cat_entry}' not found in catalog '{self.resource}'."
                )
            data = self._catalog[self.cat_entry]().read()
        else:
            data = self._read_csv_or_parquet(self.resource)
        combined = self._select_and_clean(data)
        return combined[self.input_col], combined[self.target_cols]

    @property
    def catalog(self):
        """Get the intake catalog if the resource is a YAML file."""
        return self._catalog

    def to_yaml(self, path, **storage_options):
        """
        Write specification to YAML file.

        Parameters
        ----------
        path : str or PathLike
            The file path to write the YAML content to.
        storage_options : dict, optional
            Additional options to pass to the file system (e.g., for S3, GCS).
        """
        with fsspec.open(path, "w", **storage_options) as stream:
            yaml.safe_dump(self.model_dump(), stream)

    @classmethod
    def from_yaml(cls, path, **storage_options):
        """
        Load specification from YAML file.

        Parameters
        ----------
        path : str or PathLike
            The file path to read the YAML content from.
        storage_options : dict, optional
            Additional options to pass to the file system (e.g., for S3, GCS).

        Returns
        -------
        instance : DataSpec
            An instance of the DataSpec class populated with data from the YAML file.
        """
        with fsspec.open(path, "r", **storage_options) as stream:
            data = yaml.safe_load(stream)
        return cls(**data)
class SpecBase(BaseModel):
    """Base class for specifications, providing YAML round-tripping via fsspec."""

    def to_yaml(self, path, **storage_options):
        """
        Serialize this specification to a YAML file.

        Parameters
        ----------
        path : str or PathLike
            Destination of the YAML document (local path or fsspec URL).
        storage_options : dict, optional
            Extra keyword arguments forwarded to ``fsspec.open`` (e.g. S3/GCS credentials).
        """
        with fsspec.open(path, "w", **storage_options) as handle:
            yaml.safe_dump(self.model_dump(), handle)

    @classmethod
    def from_yaml(cls, path, **storage_options):
        """
        Build a specification from a YAML file.

        Parameters
        ----------
        path : str or PathLike
            Source of the YAML document (local path or fsspec URL).
        storage_options : dict, optional
            Extra keyword arguments forwarded to ``fsspec.open`` (e.g. S3/GCS credentials).

        Returns
        -------
        instance : SpecBase
            A new instance of ``cls`` constructed from the top-level YAML mapping.
        """
        with fsspec.open(path, "r", **storage_options) as handle:
            loaded = yaml.safe_load(handle)
        return cls(**loaded)
class AnvilSection(SpecBase):
    """
    Anvil specification section base class.

    Attributes
    ----------
    type : Optional[str]
        The type of the section, resolved to a class through the section's
        registry getter.
    params : dict
        Keyword arguments used to instantiate the resolved class.
    section_name : ClassVar[str]
        Key into ``_SECTION_CLASS_GETTERS`` selecting the registry lookup.
    """

    type: str | None = None
    # default_factory gives each instance its own dict instead of a shared literal default.
    params: dict = Field(default_factory=dict)
    section_name: ClassVar[str] = "INVALID"

    def to_class(self):
        """
        Convert the specification to the corresponding class instance.

        Returns
        -------
        instance : object
            An instance of the class corresponding to the section type.

        Raises
        ------
        TypeError
            If the section's getter resolves no class (e.g. the base
            ``AnvilSection`` whose section name is "INVALID").
        """
        getter = _SECTION_CLASS_GETTERS[self.section_name]
        target_cls = getter(self.type)
        if target_cls is None:
            # Previously this surfaced as "'NoneType' object is not callable";
            # keep TypeError but say what actually went wrong.
            raise TypeError(
                f"Section '{self.section_name}' could not resolve a class for type '{self.type}'."
            )
        return target_cls(**self.params)
class SplitSpec(AnvilSection):
    """
    Data split specification.

    Attributes
    ----------
    section_name : ClassVar[str]
        Registry key ("split") used by ``to_class`` to look up the splitter class.
    """

    section_name: ClassVar[str] = "split"
class FeatureSpec(AnvilSection):
    """
    Featurization specification.

    Attributes
    ----------
    section_name : ClassVar[str]
        Registry key ("feat") used by ``to_class`` to look up the featurizer class.
    type : Optional[str]
        Name of the featurizer to instantiate.
    params : dict
        Keyword arguments passed to the featurizer constructor.
    """

    section_name: ClassVar[str] = "feat"
class ModelSpec(AnvilSection):
    """
    Model specification.

    Attributes
    ----------
    section_name : ClassVar[str]
        Registry key ("model") used by ``to_class`` to look up the model class.
    param_path : Optional[str]
        The path to the model parameters file. Must be given together with
        ``serial_path``.
    serial_path : Optional[str]
        The path to the model serialization file. Must be given together with
        ``param_path``.
    freeze_weights : Optional[dict]
        A dictionary specifying which layers to freeze during training.
    """

    section_name: ClassVar[str] = "model"
    param_path: str | None = None
    serial_path: str | None = None
    freeze_weights: dict | None = None

    @model_validator(mode="after")
    def check_paths(self):
        """
        Ensure both param_path and serial_path are provided together.

        Returns
        -------
        self : ModelSpec
            The validated ModelSpec instance.

        Raises
        ------
        ValueError
            If exactly one of the two paths is set.
        """
        # Valid when both are set or both are unset.
        if bool(self.param_path) == bool(self.serial_path):
            return self
        raise ValueError(
            "Both `param_path` and `serial_path` must be provided together."
        )

    @model_validator(mode="after")
    def check_freeze_weights(self):
        """
        Ensure freeze weights is supplied for only applicable model types.

        Note that this instantiates and builds the model during validation in
        order to probe ``freeze_weights()``.

        Returns
        -------
        self : ModelSpec
            The validated ModelSpec instance.

        Raises
        ------
        ValueError
            If the model raises NotImplementedError while building or freezing.
        """
        if self.freeze_weights:
            try:
                model = self.to_class()
                model.build()
                model.freeze_weights()
            except NotImplementedError as err:
                # Chain the cause so the failing call remains visible in tracebacks.
                raise ValueError(
                    f"Weight freezing not implemented for {self.type}."
                ) from err
        return self

    def template_anvil_dir(self, anvil_dir: Path):
        """Template param_path and serial_path with ANVIL_DIR."""
        for attr in ["param_path", "serial_path"]:
            value = getattr(self, attr, None)
            if value:
                setattr(self, attr, jinja2.Template(value).render(ANVIL_DIR=anvil_dir))
class EnsembleSpec(AnvilSection):
    """
    Ensemble specification.

    Attributes
    ----------
    section_name : ClassVar[str]
        Registry key ("ensemble") used by ``to_class`` to look up the ensemble class.
    n_models : int
        The number of models in the ensemble (at least 2).
    calibration_method : Optional[str]
        One of "isotonic-regression", "scaling-factor", or None.
    use_bagging : bool
        Forwarded to the workflow as ``ensemble_kwargs["use_bagging"]``;
        presumably toggles bootstrap resampling per member (TODO confirm).
    param_paths : Optional[list[str]]
        The list of parameter file paths for the ensemble models.
    serial_paths : Optional[list[str]]
        The list of serialization file paths for the ensemble models.
    """

    section_name: ClassVar[str] = "ensemble"
    n_models: int
    calibration_method: str | None = None
    use_bagging: bool = False
    param_paths: list[str] | None = None
    serial_paths: list[str] | None = None

    @field_validator("calibration_method")
    @classmethod
    def check_method(cls, value):
        """Validate the calibration method."""
        allowed = ["isotonic-regression", "scaling-factor", None]
        if value not in allowed:
            raise ValueError(
                f"Invalid calibration method: {value}. Valid options are: {allowed}."
            )
        return value

    @field_validator("n_models")
    @classmethod
    def check_n_models(cls, value):
        """Ensure ensemble has more than one model."""
        if value < 2:
            raise ValueError("Ensemble must have more than one model.")
        return value

    @model_validator(mode="after")
    def check_paths(self):
        """
        Ensure both param_paths and serial_paths are provided together.

        When both are given they must have equal length and match ``n_models``.

        Raises
        ------
        ValueError
            If only one list is given, the lengths differ, or the count does
            not match ``n_models``.
        """
        has_params = bool(self.param_paths)
        has_serials = bool(self.serial_paths)
        if has_params != has_serials:
            raise ValueError(
                "Both `param_paths` and `serial_paths` must be provided together."
            )
        if has_params:
            if len(self.param_paths) != len(self.serial_paths):
                raise ValueError(
                    "Parameter and serial paths must have the same length."
                )
            if len(self.param_paths) != self.n_models:
                raise ValueError(
                    f"Number of parameter ({len(self.param_paths)}) and serial paths ({len(self.serial_paths)}) must "
                    f"match the number of models ({self.n_models})."
                )
        return self

    def template_anvil_dir(self, anvil_dir: Path):
        """Template param_paths and serial_paths with ANVIL_DIR."""
        for attr in ["param_paths", "serial_paths"]:
            values = getattr(self, attr, None)
            if values:
                setattr(
                    self,
                    attr,
                    [jinja2.Template(v).render(ANVIL_DIR=anvil_dir) for v in values],
                )
class TrainerSpec(AnvilSection):
    """
    Trainer specification.

    Attributes
    ----------
    section_name : ClassVar[str]
        Registry key ("train") used by ``to_class`` to look up the trainer class.
    """

    section_name: ClassVar[str] = "train"
class EvalSpec(AnvilSection):
    """
    Evaluation specification.

    Attributes
    ----------
    section_name : ClassVar[str]
        Registry key ("eval") used by ``to_class`` to look up the evaluator class.
    """

    section_name: ClassVar[str] = "eval"
class ProcedureSpec(SpecBase):
    """
    Procedure specification.

    Bundles the split, featurization, model, optional ensemble, training,
    and optional transform sections of a workflow.
    """

    section_name: ClassVar[str] = "procedure"
    split: SplitSpec
    feat: FeatureSpec
    model: ModelSpec
    ensemble: EnsembleSpec | None = None
    train: TrainerSpec
    transform: Optional[TransformSpec] = None  # Optional transform step

    def template_anvil_dir(self, anvil_dir: Path):
        """
        Template ANVIL_DIR in model and ensemble path fields.

        Parameters
        ----------
        anvil_dir : Path or str
            Directory (optionally protocol-prefixed) substituted for ``{{ ANVIL_DIR }}``.
        """
        # Model/ensemble paths are later opened with plain open(), not fsspec,
        # so drop any protocol prefix (e.g. file://) via url_to_fs first.
        stripped = fsspec.url_to_fs(str(anvil_dir))[1]
        local_dir = Path(stripped)
        self.model.template_anvil_dir(local_dir)
        if self.ensemble is None:
            return
        self.ensemble.template_anvil_dir(local_dir)
class ReportSpec(SpecBase):
    """
    Report specification.

    Attributes
    ----------
    section_name : ClassVar[str]
        Section label ("report").
    eval : list[EvalSpec]
        Evaluation sections; each is converted via ``to_class`` when building a workflow.
    """

    section_name: ClassVar[str] = "report"
    eval: list[EvalSpec]
class AnvilSpecification(BaseModel):
    """
    Full specification for Anvil workflow.

    Attributes
    ----------
    metadata : Metadata
        Workflow metadata.
    data : DataSpec
        Data source specification.
    procedure : ProcedureSpec
        Split/featurize/model/ensemble/train/transform sections.
    report : ReportSpec
        Evaluation sections.
    """

    metadata: Metadata
    data: DataSpec
    procedure: ProcedureSpec
    report: ReportSpec

    @classmethod
    def from_recipe(cls, yaml_path: PathLike, **storage_options):
        """
        Load specification from YAML recipe file.

        ``{{ ANVIL_DIR }}`` placeholders in data resources and model/ensemble
        paths are rendered with the directory containing the recipe.

        Parameters
        ----------
        yaml_path : str or PathLike
            Path or fsspec URL of the recipe.
        storage_options : dict, optional
            Additional options to pass to the file system (e.g., for S3, GCS).
        """
        of = fsspec.open(yaml_path, "r", **storage_options)
        with of as stream:
            data = yaml.safe_load(stream)
        # Parent directory of the recipe, re-prefixed with its protocol.
        parent = of.fs.unstrip_protocol(of.fs._parent(yaml_path))
        instance = cls(**data)
        instance.data.template_anvil_dir(parent)
        instance.procedure.template_anvil_dir(parent)
        return instance

    def to_recipe(self, path, **storage_options):
        """
        Write specification to YAML recipe file.

        Parameters
        ----------
        path : str or PathLike
            Destination path or fsspec URL.
        storage_options : dict, optional
            Additional options to pass to the file system (e.g., for S3, GCS).
        """
        with fsspec.open(path, "w", **storage_options) as stream:
            yaml.safe_dump(self.model_dump(), stream)

    @classmethod
    def from_multi_yaml(
        cls,
        metadata_yaml="metadata.yaml",
        procedure_yaml="procedure.yaml",
        data_yaml="data.yaml",
        report_yaml="eval.yaml",
        **storage_options,
    ):
        """
        Load specification from multiple YAML files, one per section.

        Parameters
        ----------
        metadata_yaml, procedure_yaml, data_yaml, report_yaml : str or PathLike, optional
            Paths of the per-section YAML files.
        storage_options : dict, optional
            Additional options to pass to the file system (e.g., for S3, GCS).
        """
        metadata = Metadata.from_yaml(metadata_yaml, **storage_options)
        data = DataSpec.from_yaml(data_yaml, **storage_options)
        procedure = ProcedureSpec.from_yaml(procedure_yaml, **storage_options)
        report = ReportSpec.from_yaml(report_yaml, **storage_options)
        return cls(metadata=metadata, data=data, procedure=procedure, report=report)

    def to_multi_yaml(
        self,
        metadata_yaml="metadata.yaml",
        procedure_yaml="procedure.yaml",
        data_yaml="data.yaml",
        report_yaml="eval.yaml",
        **storage_options,
    ):
        """
        Write specification to multiple YAML files.

        Parameters
        ----------
        metadata_yaml : str or PathLike, optional
            The file path for the metadata YAML file. Default is 'metadata.yaml'.
        procedure_yaml : str or PathLike, optional
            The file path for the procedure YAML file. Default is 'procedure.yaml'.
        data_yaml : str or PathLike, optional
            The file path for the data YAML file. Default is 'data.yaml'.
        report_yaml : str or PathLike, optional
            The file path for the report YAML file. Default is 'eval.yaml'.
        storage_options : dict, optional
            Additional options to pass to the file system (e.g., for S3, GCS)
        """
        self.metadata.to_yaml(metadata_yaml, **storage_options)
        self.data.to_yaml(data_yaml, **storage_options)
        self.procedure.to_yaml(procedure_yaml, **storage_options)
        self.report.to_yaml(report_yaml, **storage_options)

    def to_workflow(self):
        """
        Convert the specification to a workflow object.

        The workflow class is chosen from the trainer's ``_driver_type`` via
        ``_DRIVER_TO_CLASS``; every section is instantiated with ``to_class``.
        """
        logger.info("Making workflow from specification")
        # Import here to avoid circular import
        from openadmet.models.anvil.workflow import _DRIVER_TO_CLASS

        procedure = self.procedure
        # Instantiate the trainer once: its driver type selects the workflow
        # class, and the same instance is handed to that workflow (previously
        # a second, throwaway trainer was constructed).
        trainer = procedure.train.to_class()
        workflow_cls = _DRIVER_TO_CLASS[trainer._driver_type]

        model_kwargs = {
            "param_path": procedure.model.param_path,
            "serial_path": procedure.model.serial_path,
            "freeze_weights": procedure.model.freeze_weights,
        }
        ensemble_kwargs = (
            {
                "n_models": procedure.ensemble.n_models,
                "calibration_method": procedure.ensemble.calibration_method,
                "param_paths": procedure.ensemble.param_paths,
                "serial_paths": procedure.ensemble.serial_paths,
                "use_bagging": procedure.ensemble.use_bagging,
            }
            if procedure.ensemble
            else {}
        )
        return workflow_cls(
            metadata=self.metadata,
            data_spec=self.data,
            model=procedure.model.to_class(),
            ensemble=procedure.ensemble.to_class() if procedure.ensemble else None,
            transform=procedure.transform.to_class() if procedure.transform else None,
            split=procedure.split.to_class(),
            feat=procedure.feat.to_class(),
            trainer=trainer,
            evals=[eval_spec.to_class() for eval_spec in self.report.eval],
            model_kwargs=model_kwargs,
            ensemble_kwargs=ensemble_kwargs,
        )

    def run(
        self,
        output_dir: PathLike = "anvil_training",
        debug: bool = False,
        tag: Optional[str] = None,
    ):
        """
        Run the Anvil workflow from this specification.

        After the run, the specification (with ``tag`` applied to metadata, if
        given) is written to the output directory as ``anvil_recipe.yaml`` and
        as per-section files under ``recipe_components/`` for provenance.

        Returns
        -------
        object
            Whatever the workflow's ``run`` returns.
        """
        workflow = self.to_workflow()
        result = workflow.run(output_dir=output_dir, debug=debug, tag=tag)
        # Coerce to Path: ``mkdir`` and ``/`` below need a Path, and this guards
        # against the workflow reporting its directory as a plain string.
        resolved_output_dir = Path(workflow.resolved_output_dir or output_dir)
        resolved_output_dir.mkdir(parents=True, exist_ok=True)

        provenance_spec = self.model_copy(deep=True)
        if tag is not None:
            provenance_spec.metadata.tag = tag
        provenance_spec.to_recipe(resolved_output_dir / "anvil_recipe.yaml")

        recipe_components = resolved_output_dir / "recipe_components"
        recipe_components.mkdir(parents=True, exist_ok=True)
        provenance_spec.to_multi_yaml(
            metadata_yaml=recipe_components / "metadata.yaml",
            procedure_yaml=recipe_components / "procedure.yaml",
            data_yaml=recipe_components / "data.yaml",
            report_yaml=recipe_components / "eval.yaml",
        )
        return result