Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 71 additions & 48 deletions serenityff/charge/gnn/utils/rdkit_helper.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,58 @@
from typing import List, Optional, Sequence

import torch
import numpy as np
import torch as pt
from rdkit import Chem

from serenityff.charge.gnn.utils import CustomData, MolGraphConvFeaturizer
from serenityff.charge.utils import Molecule


def mols_from_sdf(sdf_file: str, removeHs: Optional[bool] = False) -> Sequence[Molecule]:
    """Return a sequence of RDKit molecules read from an .sdf file.

    :param sdf_file: Path to the .sdf file.
    :param removeHs: Whether to remove hydrogens. Defaults to False.
    :return: A sequence of RDKit molecule objects (an ``SDMolSupplier``).
    """
    # NOTE(review): RDKit's SDMolSupplier is documented to yield None for records it
    # fails to parse, so callers may need to filter those out — confirm per RDKit version.
    return Chem.SDMolSupplier(sdf_file, removeHs=removeHs)


def get_mol_prop_as_np_array(prop_name: Optional[str], mol: Chem.Mol, dtype: type = float) -> np.ndarray:
    """Get atomic properties from an RDKit molecule object as an array.

    The property is expected to be a string of '|' separated numerical
    values, one for each atom in the molecule.

    :param prop_name: The name of the property to retrieve from the molecule.
    :param mol: The RDKit molecule object.
    :param dtype: Dtype of the returned array. Values are parsed as Python floats
        first and then cast to this type. Defaults to ``float``.
    :return: The atomic properties converted to a NumPy array.
    :raises ValueError: If ``prop_name`` is None or if the property is not found in the molecule.
    :raises TypeError: If any of the parsed property values are NaN or not convertible to float.
    """
    if prop_name is None:
        raise ValueError("Property name can not be None.")
    if not mol.HasProp(prop_name):
        raise ValueError(f"Property {prop_name} not found in molecule.")
    raw_value = mol.GetProp(prop_name)
    # Convert every token explicitly. np.fromstring(..., sep="|") stops at the first
    # token it cannot parse and returns the truncated prefix with only a warning,
    # which would silently leave fewer values than atoms.
    try:
        values = np.array([float(token) for token in raw_value.split("|")], dtype=np.float64)
    except ValueError as exc:
        raise TypeError(f"Non-numeric value found in {prop_name}.") from exc
    # Check for NaN before casting: integer dtypes cannot represent NaN.
    if np.isnan(values).any():
        raise TypeError(f"Nan found in {prop_name}.")
    return values.astype(dtype)


def get_mol_prop_as_pt_tensor(prop_name: Optional[str], mol: Chem.Mol) -> pt.Tensor:
    """Read a '|'-separated per-atom property of ``mol`` into a float32 tensor.

    The property string is parsed and validated by ``get_mol_prop_as_np_array``
    (requested with ``dtype=np.float32``), and the resulting array is wrapped
    with ``pt.from_numpy``.

    :param prop_name: Name of the molecule property holding one value per atom.
    :param mol: The RDKit molecule to read the property from.
    :return: A float32 PyTorch tensor with the parsed property values.
    :raises ValueError: If ``prop_name`` is None or the molecule lacks the property.
    :raises TypeError: If the parsed values contain NaN (see
        ``get_mol_prop_as_np_array`` for how malformed input is handled).
    """
    float32_values = get_mol_prop_as_np_array(prop_name=prop_name, mol=mol, dtype=np.float32)
    return pt.from_numpy(float32_values)


def get_graph_from_mol(
Expand All @@ -39,59 +73,48 @@ def get_graph_from_mol(
],
no_y: Optional[bool] = False,
) -> Optional[CustomData]:
"""
Creates an pytorch_geometric Graph from an rdkit molecule.

Returns None if the property is not found or contains NaN.
The graph contains following features:
> Node Features:
> Atom Type (as specified in allowable set)
> formal_charge
> hybridization
> H acceptor_donor
> aromaticity
> degree
> Edge Features:
> Bond type
> is in ring
> is conjugated
> stereo
Args:
mol (Molecule): rdkit molecule
sdf_property_name (Optional[str]): Name of the property in the sdf file to be used for training.
allowable_set (Optional[List[str]], optional): List of atoms to be \
included in the feature vector. Defaults to \
[ "C", "N", "O", "F", "P", "S", "Cl", "Br", "I", "H", ].

Returns:
CustomData: pytorch geometric Data with .smiles as an extra attribute.
"""
"""Create a PyTorch Geometric graph from an RDKit molecule.

def get_mol_prop_as_torch_tensor(prop_name: Optional[str], mol: Molecule) -> torch.Tensor:
if prop_name is None:
raise ValueError("Property name can not be None when no_y == False.")
if not mol.HasProp(prop_name):
raise ValueError(f"Property {prop_name} not found in molecule.") # noqa E713
tensor = torch.tensor([float(x) for x in mol.GetProp(prop_name).split("|")], dtype=torch.float)
if torch.isnan(tensor).any():
raise TypeError(f"Nan found in {prop_name}.")
return tensor
Returns ``None`` if the specified property is not found or contains NaN.

The graph contains the following features:

**Node features**
- Atom type (as specified in the `allowable_set`)
- Formal charge
- Hybridization
- H acceptor/donor
- Aromaticity
- Degree

**Edge features**
- Bond type
- Is in ring
- Is conjugated
- Stereo information

:param mol: The RDKit molecule.
:param sdf_property_name: Name of the property in the SDF file to be used for training.
:param allowable_set: List of atoms to include in the feature vector. Defaults to
``["C", "N", "O", "F", "P", "S", "Cl", "Br", "I", "H"]``.
:return: A PyTorch Geometric ``Data`` object with an additional ``.smiles`` attribute,
or ``None`` if the property is invalid.
"""
grapher = MolGraphConvFeaturizer(use_edges=True)
graph = grapher._featurize(mol, allowable_set).to_pyg_graph()
if no_y:
graph.y = torch.tensor(
graph.y = pt.tensor(
[0 for _ in mol.GetAtoms()],
dtype=torch.float,
dtype=pt.float,
)
else:
try:
graph.y = get_mol_prop_as_torch_tensor(sdf_property_name, mol)
graph.y = get_mol_prop_as_pt_tensor(sdf_property_name, mol)
except TypeError as exc:
print(exc)
return None

graph.batch = torch.tensor([0 for _ in mol.GetAtoms()], dtype=int)
graph.batch = pt.tensor([0 for _ in mol.GetAtoms()], dtype=int)
graph.molecule_charge = Chem.GetFormalCharge(mol)
graph.smiles = Chem.MolToSmiles(mol, canonical=True)
graph.sdf_idx = index
Expand Down
Empty file.
Empty file.
96 changes: 96 additions & 0 deletions tests/serenityff/charge/gnn/utils/test_rdkit_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import numpy as np
import pytest
import torch as pt
from rdkit import Chem

from serenityff.charge.gnn.utils.rdkit_helper import (
get_mol_prop_as_np_array,
get_mol_prop_as_pt_tensor,
)


@pytest.fixture
def sample_mol_with_prop():
    """Ethanol carrying a well-formed three-value ``test_prop`` property."""
    ethanol = Chem.MolFromSmiles("CCO")
    ethanol.SetProp("test_prop", "1.0|2.5|-3.0")
    return ethanol


@pytest.fixture
def sample_mol_with_nan_prop():
    """Ethanol whose ``test_prop_nan`` property contains a NaN entry."""
    ethanol = Chem.MolFromSmiles("CCO")
    ethanol.SetProp("test_prop_nan", "1.0|nan|3.0")
    return ethanol


@pytest.fixture
def sample_mol_missing_prop():
    """Plain ethanol with no custom properties set."""
    return Chem.MolFromSmiles("CCO")


def test_get_mol_prop_as_pt_tensor_success(sample_mol_with_prop):
    """Test successful retrieval of property as a float32 tensor."""
    expected = pt.tensor([1.0, 2.5, -3.0], dtype=pt.float)
    result = get_mol_prop_as_pt_tensor("test_prop", sample_mol_with_prop)
    assert isinstance(result, pt.Tensor)
    assert result.dtype == pt.float32
    assert pt.equal(result, expected)


def test_get_mol_prop_as_np_array_success(sample_mol_with_prop):
    """Test successful retrieval of property as a numpy array with the default dtype."""
    expected = np.array([1.0, 2.5, -3.0])
    result = get_mol_prop_as_np_array("test_prop", sample_mol_with_prop)
    assert isinstance(result, np.ndarray)
    assert result.dtype == np.float64
    np.testing.assert_array_equal(result, expected)


def test_get_mol_prop_as_np_array_respects_dtype(sample_mol_with_prop):
    """Test that the ``dtype`` argument controls the returned array's dtype."""
    result = get_mol_prop_as_np_array("test_prop", sample_mol_with_prop, dtype=np.float32)
    assert result.dtype == np.float32
    np.testing.assert_array_equal(result, np.array([1.0, 2.5, -3.0], dtype=np.float32))


# Both getters share the same validation contract, so each error case is run
# against both of them instead of being duplicated per function.
_each_getter = pytest.mark.parametrize(
    "getter",
    [get_mol_prop_as_pt_tensor, get_mol_prop_as_np_array],
    ids=["pt_tensor", "np_array"],
)


@_each_getter
def test_get_mol_prop_raises_value_error_on_none_prop(getter, sample_mol_missing_prop):
    """Test ValueError is raised when prop_name is None."""
    with pytest.raises(ValueError, match="Property name can not be None"):
        getter(None, sample_mol_missing_prop)


@_each_getter
def test_get_mol_prop_raises_value_error_on_missing_prop(getter, sample_mol_missing_prop):
    """Test ValueError is raised when the property is not found."""
    with pytest.raises(ValueError, match="Property missing_prop not found"):
        getter("missing_prop", sample_mol_missing_prop)


@_each_getter
def test_get_mol_prop_raises_type_error_on_nan(getter, sample_mol_with_nan_prop):
    """Test TypeError is raised when NaN is in the property string."""
    with pytest.raises(TypeError, match="Nan found in test_prop_nan"):
        getter("test_prop_nan", sample_mol_with_nan_prop)
Loading