Source code for openadmet.models.features.chemprop

"""ChemProp featurizer implementation."""

from __future__ import annotations

from collections.abc import Iterable
from typing import Any, Union

import numpy as np
import pandas as pd

from openadmet.models.features.feature_base import DeepLearningFeaturizer, featurizers


# We vendor this from chemprop so that a custom sampler can be passed to the DataLoader.
# Adapted from https://github.com/chemprop/chemprop/blob/main/chemprop/data/dataloader.py
def _vendor_build_dataloader(
    dataset,
    batch_size: int = 64,
    num_workers: int = 0,
    class_balance: bool = False,
    sampler: Any = None,
    seed: int | None = None,
    shuffle: bool = True,
    **kwargs,
):
    r"""
    Return a :obj:`~torch.utils.data.DataLoader` for :class:`MolGraphDataset`.

    Parameters
    ----------
    dataset : MoleculeDataset | ReactionDataset | MulticomponentDataset
        The dataset containing the molecules or reactions to load.
    batch_size : int, default=64
        the batch size to load.
    num_workers : int, default=0
        the number of workers used to build batches.
    class_balance : bool, default=False
        Whether to perform class balancing (i.e., use an equal number of positive and negative
        molecules). Class balance is only available for single task classification datasets. Set
        shuffle to True in order to get a random subset of the larger class.
    sampler : torch.utils.data.Sampler, optional
        Custom sampler to use for loading data (default is None). If this is specified, it
        overrides class_balance and shuffle.
    seed : int, optional
        Random seed for shuffling and class balancing (default is None).
    shuffle : bool, default=True
        Whether to shuffle the data at every epoch. If a sampler is specified, this is ignored
        (i.e., the sampler determines the shuffling). If class_balance is True, this is also ignored
        (i.e., class balancing determines the shuffling).
    **kwargs
        Additional keyword arguments passed to the DataLoader.

    Returns
    -------
    DataLoader
        A PyTorch DataLoader for the given MoleculeDataset, ReactionDataset, or MulticomponentDataset.

    Notes
    -----
    The final batch is dropped when it would contain exactly one sample
    (``len(dataset) % batch_size == 1``) to avoid issues with batch normalization.

    """
    from chemprop.data import MulticomponentDataset
    from chemprop.data.collate import collate_batch, collate_multicomponent
    from chemprop.data.samplers import ClassBalanceSampler, SeededSampler
    from torch.utils.data import DataLoader

    # Only derive a sampler when the caller did not supply one: a user-provided
    # sampler takes precedence over both ``class_balance`` and ``shuffle``.
    if sampler is None:
        if class_balance:
            sampler = ClassBalanceSampler(dataset.Y, seed, shuffle)
        elif shuffle and seed is not None:
            sampler = SeededSampler(len(dataset), seed)

    if isinstance(dataset, MulticomponentDataset):
        collate_fn = collate_multicomponent
    else:
        collate_fn = collate_batch

    # Drop last batch of size 1 to avoid issues with batch normalization
    drop_last = len(dataset) % batch_size == 1

    return DataLoader(
        dataset,
        batch_size=batch_size,
        # DataLoader treats ``sampler`` and ``shuffle`` as mutually exclusive, so
        # shuffling is only requested when no sampler is in play.
        shuffle=sampler is None and shuffle,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=collate_fn,
        drop_last=drop_last,
        **kwargs,
    )


@featurizers.register("ChemPropFeaturizer")
class ChemPropFeaturizer(DeepLearningFeaturizer):
    """
    ChemPropFeaturizer featurizer for molecules, relies on chemprop.

    Parameters
    ----------
    normalize_targets : bool, optional
        Whether to normalize the targets using StandardScaler, by default True
    n_jobs : int, optional
        Number of parallel workers to use, by default 4
    batch_size : int, optional
        Batch size for the DataLoader, by default 128
    shuffle : bool, optional
        Whether to shuffle the data in the DataLoader, by default False

    """

    normalize_targets: bool = True
    n_jobs: int = 4
    batch_size: int = 128
    shuffle: bool = False

    def _prepare(self):
        """Prepare the featurizer."""

    def featurize(
        self, smiles: Iterable[str], y: Iterable[Any] | None = None
    ) -> tuple[
        DataLoader,
        np.ndarray,
        StandardScaler,
        MoleculeDataset | ReactionDataset | MulticomponentDataset,
    ]:
        """
        Featurize a list of SMILES strings.

        Parameters
        ----------
        smiles : Iterable[str]
            List or iterable of SMILES strings to featurize. One-shot iterables
            (e.g. generators) are accepted; they are materialized internally.
        y : Iterable[Any], optional
            Target values corresponding to the SMILES strings. May be a pandas
            DataFrame/Series, a numpy array, or any other iterable of targets.
            One-dimensional targets are reshaped to a single column.

        Returns
        -------
        tuple
            Tuple containing:
            - DataLoader: PyTorch DataLoader for the dataset.
            - np.ndarray: Array of indices corresponding to the original input.
            - StandardScaler: Scaler used for any scaling during featurization
              (None when no targets are given or ``normalize_targets`` is False).
            - Union[MoleculeDataset, ReactionDataset, MulticomponentDataset]: PyTorch Dataset
              containing the features and targets.

        Raises
        ------
        ValueError
            If ``y`` is given and does not have one entry per SMILES string.

        """
        from chemprop.data import MoleculeDatapoint, MoleculeDataset

        # Materialize once: ``smiles`` may be a one-shot iterator and its length
        # is needed after the datapoints have been built.
        smiles = list(smiles)

        scaler = None
        if y is None:
            dataset = MoleculeDataset(
                [MoleculeDatapoint.from_smi(smi) for smi in smiles]
            )
        else:
            # Coerce every supported target container to a numpy array
            if isinstance(y, (pd.DataFrame, pd.Series)):
                y = y.to_numpy()
            elif not isinstance(y, np.ndarray):
                y = np.asarray(list(y))

            # A flat target vector becomes a single-task column
            if y.ndim == 1:
                y = y.reshape(-1, 1)

            # zip() would silently truncate to the shorter input and leave the
            # returned indices out of sync with the dataset, so fail loudly.
            if len(y) != len(smiles):
                raise ValueError(
                    f"Number of targets ({len(y)}) does not match "
                    f"number of SMILES ({len(smiles)})"
                )

            dataset = MoleculeDataset(
                [MoleculeDatapoint.from_smi(smi, y_) for smi, y_ in zip(smiles, y)]
            )
            if self.normalize_targets:
                scaler = dataset.normalize_targets()

        dataloader = self.dataset_to_dataloader(
            dataset,
            num_workers=self.n_jobs,
            shuffle=self.shuffle,
            batch_size=self.batch_size,
        )

        # Need to also return an index of the original input for which the features were computed
        indices = np.arange(len(smiles))

        return dataloader, indices, scaler, dataset

    @staticmethod
    def dataset_to_dataloader(
        dataset: MoleculeDataset,
        batch_size: int = 128,
        shuffle: bool = False,
        sampler=None,
        **kwargs,
    ) -> DataLoader:
        """
        Convert a MoleculeDataset to a PyTorch DataLoader.

        Parameters
        ----------
        dataset : MoleculeDataset
            The dataset containing the molecules to load.
        batch_size : int, optional
            Number of samples per batch to load (default is 128).
        shuffle : bool, optional
            Whether to shuffle the data at every epoch (default is False).
        sampler : torch.utils.data.Sampler, optional
            Custom sampler to use for loading data (default is None).
        **kwargs
            Additional keyword arguments forwarded to ``_vendor_build_dataloader``
            (e.g. ``num_workers``).

        Returns
        -------
        DataLoader
            A PyTorch DataLoader for the given MoleculeDataset.

        """
        return _vendor_build_dataloader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            **kwargs,
        )

    def make_new(self) -> ChemPropFeaturizer:
        """Copy parameters to a new ChemPropFeaturizer instance."""
        return self.__class__(**self.dict())