"""Scaffold-based data splitting implementations."""
import logging
from sklearn.model_selection import train_test_split
from splito import MaxDissimilaritySplit, PerimeterSplit, ScaffoldSplit
import numpy as np
import pandas as pd
from openadmet.models.split.split_base import SplitterBase, splitters
# NOTE(review): basicConfig() runs at import time and configures the root
# logger at INFO for the whole process (it is a no-op when the root logger
# already has handlers). Applications usually own this call -- confirm it is
# intended inside a library module.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
@splitters.register("ScaffoldSplitter")
class ScaffoldSplitter(SplitterBase):
    """Splits the data based on the scaffold of the molecules."""

    def split(self, X, y):
        """
        Split the data into train, validation, and test sets.

        The held-out set (the test set, or the validation set when
        ``test_size == 0``) is chosen by ``ScaffoldSplit``. When both a
        validation and a test set are requested, the remaining train+val pool
        is divided with ``train_test_split``, which is not scaffold-aware.
        Both hold-out sizes are computed as a fraction of the full dataset.

        Parameters
        ----------
        X : Iterable[str]
            List or iterable of SMILES strings to split.
        y : Iterable[float] or pd.Series
            List or iterable of target values corresponding to the SMILES strings.

        Returns
        -------
        tuple
            Seven-element tuple containing:
            - X_train: Training set SMILES strings.
            - X_val: Validation set SMILES strings (or None if val_size=0).
            - X_test: Test set SMILES strings (or None if test_size=0).
            - y_train: Training set target values.
            - y_val: Validation set target values (or None if val_size=0).
            - y_test: Test set target values (or None if test_size=0).
            - groups: Always None for this splitter.

        """
        logger.warning("ScaffoldSplitter is not available for cross-validation.")

        # len() sizes numpy arrays, pandas objects and plain sequences alike,
        # whereas ``.shape`` is missing on built-in lists.
        n_samples = len(X)
        # No group labels are produced; None is returned as the last tuple
        # element in every branch.
        groups = None

        def _scaffold_holdout(n_holdout):
            """Return (kept_idx, holdout_idx) holding out ``n_holdout`` samples."""
            splitter = ScaffoldSplit(
                smiles=X,
                n_jobs=-1,
                train_size=None,
                test_size=n_holdout,
                random_state=self.random_state,
            )
            # Only the first split yielded by the generator is used.
            return next(splitter.split(X=X))

        # No test set requested: hold out the validation set directly.
        if self.test_size == 0:
            train_idx, val_idx = _scaffold_holdout(int(self.val_size * n_samples))
            return (
                safe_index(X, train_idx),
                safe_index(X, val_idx),
                None,
                safe_index(y, train_idx),
                safe_index(y, val_idx),
                None,
                groups,
            )

        # Split into train+val and test.
        train_val_idx, test_idx = _scaffold_holdout(
            int(self.test_size * n_samples)
        )

        # No validation set requested: return train(+val) and test sets.
        if self.val_size == 0:
            return (
                safe_index(X, train_val_idx),
                None,
                safe_index(X, test_idx),
                safe_index(y, train_val_idx),
                None,
                safe_index(y, test_idx),
                groups,
            )

        # Split train+val into train and val sets. The validation count is
        # relative to the full dataset, not to the train+val pool.
        X_train, X_val, y_train, y_val = train_test_split(
            safe_index(X, train_val_idx),
            safe_index(y, train_val_idx),
            train_size=None,
            test_size=int(self.val_size * n_samples),
            random_state=self.random_state,
        )

        return (
            X_train,
            X_val,
            safe_index(X, test_idx),
            y_train,
            y_val,
            safe_index(y, test_idx),
            groups,
        )
@splitters.register("PerimeterSplitter")
class PerimeterSplitter(SplitterBase):
    """Splits the data using splito's ``PerimeterSplit``."""

    def split(self, X, y):
        """
        Split the data into train, validation, and test sets.

        The held-out set (the test set, or the validation set when
        ``test_size == 0``) is chosen by ``PerimeterSplit``. When both a
        validation and a test set are requested, the remaining train+val pool
        is divided with ``train_test_split``. Both hold-out sizes are computed
        as a fraction of the full dataset.

        Parameters
        ----------
        X : Iterable[str]
            List or iterable of SMILES strings to split.
        y : Iterable[float] or pd.Series
            List or iterable of target values corresponding to the SMILES strings.

        Returns
        -------
        tuple
            Seven-element tuple containing:
            - X_train: Training set SMILES strings.
            - X_val: Validation set SMILES strings (or None if val_size=0).
            - X_test: Test set SMILES strings (or None if test_size=0).
            - y_train: Training set target values.
            - y_val: Validation set target values (or None if val_size=0).
            - y_test: Test set target values (or None if test_size=0).
            - groups: Always None for this splitter.

        """
        logger.warning("PerimeterSplitter is not available for cross-validation.")

        # len() sizes numpy arrays, pandas objects and plain sequences alike,
        # whereas ``.shape`` is missing on built-in lists.
        n_samples = len(X)
        # No group labels are produced; None is returned as the last tuple
        # element in every branch.
        groups = None

        # No test set requested: hold out the validation set directly.
        if self.test_size == 0:
            # NOTE(review): this branch passes ``smiles=X`` to PerimeterSplit
            # while the train+val/test branch below does not. Confirm against
            # the splito API which form is intended and make them consistent.
            splitter = PerimeterSplit(
                smiles=X,
                n_jobs=-1,
                train_size=None,
                test_size=int(self.val_size * n_samples),
                random_state=self.random_state,
            )
            # Only the first split yielded by the generator is used.
            train_idx, val_idx = next(splitter.split(X=X))
            return (
                safe_index(X, train_idx),
                safe_index(X, val_idx),
                None,
                safe_index(y, train_idx),
                safe_index(y, val_idx),
                None,
                groups,
            )

        # Split into train+val and test.
        splitter = PerimeterSplit(
            n_jobs=-1,
            train_size=None,
            test_size=int(self.test_size * n_samples),
            random_state=self.random_state,
        )
        train_val_idx, test_idx = next(splitter.split(X=X))

        # No validation set requested: return train(+val) and test sets.
        if self.val_size == 0:
            return (
                safe_index(X, train_val_idx),
                None,
                safe_index(X, test_idx),
                safe_index(y, train_val_idx),
                None,
                safe_index(y, test_idx),
                groups,
            )

        # Split train+val into train and val sets using sklearn. The
        # validation count is relative to the full dataset, not to the
        # train+val pool.
        X_train, X_val, y_train, y_val = train_test_split(
            safe_index(X, train_val_idx),
            safe_index(y, train_val_idx),
            train_size=None,
            test_size=int(self.val_size * n_samples),
            random_state=self.random_state,
        )

        return (
            X_train,
            X_val,
            safe_index(X, test_idx),
            y_train,
            y_val,
            safe_index(y, test_idx),
            groups,
        )
@splitters.register("MaxDissimilaritySplitter")
class MaxDissimilaritySplitter(SplitterBase):
    """Splits the data based on maximum dissimilarity."""

    def split(self, X, y):
        """
        Split the data into train, validation, and test sets.

        The held-out set (the test set, or the validation set when
        ``test_size == 0``) is chosen by ``MaxDissimilaritySplit``. When both
        a validation and a test set are requested, the remaining train+val
        pool is divided with ``train_test_split``. Both hold-out sizes are
        computed as a fraction of the full dataset.

        Parameters
        ----------
        X : Iterable[str]
            List or iterable of SMILES strings to split.
        y : Iterable[float] or pd.Series
            List or iterable of target values corresponding to the SMILES strings.

        Returns
        -------
        tuple
            Seven-element tuple containing:
            - X_train: Training set SMILES strings.
            - X_val: Validation set SMILES strings (or None if val_size=0).
            - X_test: Test set SMILES strings (or None if test_size=0).
            - y_train: Training set target values.
            - y_val: Validation set target values (or None if val_size=0).
            - y_test: Test set target values (or None if test_size=0).
            - groups: Always None for this splitter.

        """
        logger.warning(
            "MaxDissimilaritySplitter is not available for cross-validation."
        )

        # len() sizes numpy arrays, pandas objects and plain sequences alike,
        # whereas ``.shape`` is missing on built-in lists.
        n_samples = len(X)
        # No group labels are produced; None is returned as the last tuple
        # element in every branch.
        groups = None

        # No test set requested: hold out the validation set directly.
        if self.test_size == 0:
            # NOTE(review): this branch passes ``smiles=X`` to
            # MaxDissimilaritySplit while the train+val/test branch below does
            # not. Confirm against the splito API which form is intended and
            # make them consistent.
            splitter = MaxDissimilaritySplit(
                smiles=X,
                n_jobs=-1,
                train_size=None,
                test_size=int(self.val_size * n_samples),
                random_state=self.random_state,
            )
            # Only the first split yielded by the generator is used.
            train_idx, val_idx = next(splitter.split(X=X))
            return (
                safe_index(X, train_idx),
                safe_index(X, val_idx),
                None,
                safe_index(y, train_idx),
                safe_index(y, val_idx),
                None,
                groups,
            )

        # Split into train+val and test.
        splitter = MaxDissimilaritySplit(
            n_jobs=-1,
            train_size=None,
            test_size=int(self.test_size * n_samples),
            random_state=self.random_state,
        )
        train_val_idx, test_idx = next(splitter.split(X=X))

        # No validation set requested: return train(+val) and test sets.
        if self.val_size == 0:
            return (
                safe_index(X, train_val_idx),
                None,
                safe_index(X, test_idx),
                safe_index(y, train_val_idx),
                None,
                safe_index(y, test_idx),
                groups,
            )

        # Split train+val into train and val sets using sklearn. The
        # validation count is relative to the full dataset, not to the
        # train+val pool.
        X_train, X_val, y_train, y_val = train_test_split(
            safe_index(X, train_val_idx),
            safe_index(y, train_val_idx),
            train_size=None,
            test_size=int(self.val_size * n_samples),
            random_state=self.random_state,
        )

        return (
            X_train,
            X_val,
            safe_index(X, test_idx),
            y_train,
            y_val,
            safe_index(y, test_idx),
            groups,
        )
def safe_index(data, idx):
    """
    Select entries of ``data`` by position, whatever container type it is.

    Parameters
    ----------
    data : np.ndarray, list, pd.Series, or pd.DataFrame
        X or y data.
    idx : array-like of int
        Positional indices to select.

    Returns
    -------
    np.ndarray, list, pd.Series, or pd.DataFrame
        The selected entries. For a sequence of indices the result has the
        same container type as ``data``.

    Raises
    ------
    TypeError
        If ``data`` is not one of the supported container types.

    """
    if isinstance(data, np.ndarray):
        return data[idx]
    if isinstance(data, list):
        # Built-in lists only accept a single integer or a slice as an index;
        # a list/array of positions has to be resolved element by element.
        if isinstance(idx, slice) or np.ndim(idx) == 0:
            return data[idx]
        return [data[i] for i in idx]
    if isinstance(data, (pd.Series, pd.DataFrame)):
        # .iloc keeps the selection positional regardless of the index labels.
        return data.iloc[idx]
    raise TypeError(f"Unsupported data type for indexing: {type(data)}")