Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion benchmarks/bench_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@
"""

from bench_utils.loaders import load_pickle, load_sdf, load_smarts, load_smiles
from bench_utils.molprep import clone_mols_with_conformers, prep_mols
from bench_utils.molprep import clone_mols_with_conformers, embed_and_jitter, perturb_conformer, prep_mols
from bench_utils.timing import TimingResult, time_it

__all__ = [
"TimingResult",
"clone_mols_with_conformers",
"embed_and_jitter",
"load_pickle",
"load_sdf",
"load_smarts",
"load_smiles",
"perturb_conformer",
"prep_mols",
"time_it",
]
114 changes: 114 additions & 0 deletions benchmarks/bench_utils/molprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,17 @@

"""Molecule preparation helpers shared across nvMolKit benchmarks."""

import random
from functools import partial

from rdkit import Chem
from rdkit.Chem import rdDistGeom
from rdkit.Geometry import Point3D
from tqdm.contrib.concurrent import process_map

# Manually tuned so the per-conformer jitter recreates an ETKDGv3-like pairwise RMSD spread
JITTER_CENTER = 1.3
JITTER_SPREAD = 0.6


def prep_mols(
Expand Down Expand Up @@ -63,3 +73,107 @@ def clone_mols_with_conformers(mols: list[Chem.Mol]) -> list[Chem.RWMol]:
pristine input.
"""
return [Chem.RWMol(mol) for mol in mols]


def perturb_conformer(
conf: Chem.Conformer,
seed: int,
center: float = JITTER_CENTER,
spread: float = JITTER_SPREAD,
) -> None:
"""Apply per-atom uniform jitter to a conformer in place.

A single half-width is drawn for the conformer as ``center * (1 + spread *
U(-1, 1))`` and every x/y/z coordinate is then shifted by ``U(-half_width,
half_width)``. Drawing a distinct half-width per conformer (each call uses
a distinct ``seed``) gives a jittered ensemble a range of pairwise RMSDs
rather than a single structure-independent value.
"""
rng = random.Random(seed)
half_width = max(0.0, center * (1.0 + spread * rng.uniform(-1.0, 1.0)))
for atom_idx in range(conf.GetNumAtoms()):
pos = conf.GetAtomPosition(atom_idx)
conf.SetAtomPosition(
atom_idx,
Point3D(
pos.x + rng.uniform(-half_width, half_width),
pos.y + rng.uniform(-half_width, half_width),
pos.z + rng.uniform(-half_width, half_width),
),
)


def _embed_one(args_tuple: tuple[int, bytes], seed: int, add_hs: bool, min_atoms: int) -> bytes | None:
"""Embed a single ETKDGv3 conformer for one mol payload (multiprocessing worker)."""
idx, mol_bytes = args_tuple
mol = Chem.Mol(mol_bytes)
if mol.GetNumAtoms() < min_atoms:
return None
if add_hs:
mol = Chem.AddHs(mol)
params = rdDistGeom.ETKDGv3()
params.useRandomCoords = True
params.randomSeed = seed + idx
try:
conf_id = rdDistGeom.EmbedMolecule(mol, params=params)
except Exception:
return None
if conf_id < 0 or mol.GetNumConformers() == 0:
return None
if add_hs:
mol = Chem.RemoveHs(mol)
return mol.ToBinary()


def embed_and_jitter(
mols: list[Chem.Mol],
confs_per_mol: int,
seed: int,
num_workers: int = 1,
add_hs: bool = False,
min_atoms: int = 1,
desc: str = "Embedding base conformers",
) -> list[Chem.Mol]:
"""Embed one ETKDGv3 base conformer per mol in parallel, then jitter to ``confs_per_mol``.

The embed step runs across mols via ``process_map``; the jitter step is
in-process and serial (cheap). Mols whose base embedding fails are
dropped with a printed count. When ``add_hs`` is true, hydrogens are
added before embedding and stripped from the returned mol.
"""
if not mols:
return []
if confs_per_mol < 1:
raise ValueError(f"confs_per_mol must be >= 1, got {confs_per_mol}")

workers = max(1, num_workers)
binaries = [(i, mol.ToBinary()) for i, mol in enumerate(mols)]
embedded_binaries = process_map(
partial(_embed_one, seed=seed, add_hs=add_hs, min_atoms=min_atoms),
binaries,
max_workers=workers,
chunksize=max(1, len(binaries) // (workers * 8) or 1),
Comment thread
scal444 marked this conversation as resolved.
desc=desc,
)

out: list[Chem.Mol] = []
drop_count = 0
for raw in embedded_binaries:
if raw is None:
drop_count += 1
continue
out.append(Chem.Mol(raw))
if drop_count > 0:
print(f" Dropped {drop_count} molecules during embedding (no conformer generated)")

if confs_per_mol > 1:
for mol_idx, mol in enumerate(out):
base_conf_id = mol.GetConformer().GetId()
base_conf = mol.GetConformer(base_conf_id)
for conf_idx in range(1, confs_per_mol):
new_conf = Chem.Conformer(base_conf)
perturb_conformer(new_conf, seed=seed + mol_idx * confs_per_mol + conf_idx)
mol.AddConformer(new_conf, assignId=True)
perturb_conformer(mol.GetConformer(base_conf_id), seed=seed + mol_idx * confs_per_mol)

return out
16 changes: 12 additions & 4 deletions benchmarks/conformer_rmsd_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import numpy as np
import torch
from bench_utils import perturb_conformer
from benchmark_timing import time_it
from rdkit import Chem
from rdkit.Chem import AllChem, rdDistGeom
Expand Down Expand Up @@ -87,12 +88,19 @@ def run_benchmark(smiles, num_confs_list, seed=42):
params = rdDistGeom.ETKDGv3()
params.randomSeed = seed
params.useRandomCoords = True
rdDistGeom.EmbedMultipleConfs(mol, numConfs=num_confs, params=params)
actual_confs = mol.GetNumConformers()

if actual_confs < 2:
if rdDistGeom.EmbedMolecule(mol, params=params) < 0:
print(f"{num_confs:>8} {'skipped (embedding failed)':>50}")
continue
Comment thread
scal444 marked this conversation as resolved.
if num_confs < 2:
print(f"{num_confs:>8} {'skipped (need >= 2 confs for RMSD)':>50}")
continue
base_conf_id = mol.GetConformer().GetId()
for conf_idx in range(1, num_confs):
new_conf = Chem.Conformer(mol.GetConformer(base_conf_id))
perturb_conformer(new_conf, seed=seed + conf_idx)
mol.AddConformer(new_conf, assignId=True)
perturb_conformer(mol.GetConformer(base_conf_id), seed=seed)
actual_confs = mol.GetNumConformers()

no_h = Chem.RemoveHs(mol)
n_pairs = actual_confs * (actual_confs - 1) // 2
Expand Down
29 changes: 2 additions & 27 deletions benchmarks/ff_optimize_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import torch
from bench_utils import (
clone_mols_with_conformers,
embed_and_jitter,
load_pickle,
load_sdf,
load_smiles,
Expand All @@ -53,32 +54,6 @@
OPTUNA_AVAILABLE = nv_autotune.is_available()


def _embed_conformers(mols: list[Chem.Mol], confs_per_mol: int, seed: int) -> list[Chem.Mol]:
"""Generate ``confs_per_mol`` conformers per molecule using RDKit ETKDGv3.

Molecules where embedding fails to produce at least one conformer are
dropped; a count is printed.
"""
params = rdDistGeom.ETKDGv3()
params.useRandomCoords = True
params.randomSeed = seed

embedded: list[Chem.Mol] = []
drop_count = 0
for mol in mols:
try:
conf_ids = rdDistGeom.EmbedMultipleConfs(mol, numConfs=confs_per_mol, params=params)
if not conf_ids:
drop_count += 1
continue
embedded.append(mol)
except Exception:
drop_count += 1
if drop_count > 0:
print(f" Dropped {drop_count} molecules during embedding (no conformer generated)")
return embedded


def _flatten_energies(per_mol: list[list[float]]) -> list[float]:
"""Flatten ``[[e0, e1, ...], [e0, e1, ...], ...]`` returned by nvmolkit."""
flat: list[float] = []
Expand Down Expand Up @@ -398,7 +373,7 @@ def main() -> None:
print(f" {len(mols)} molecules ready")

print(f"\nEmbedding {args.confs_per_mol} conformer(s) per molecule with RDKit ETKDGv3...")
mols = _embed_conformers(mols, args.confs_per_mol, args.seed)
mols = embed_and_jitter(mols, args.confs_per_mol, seed=args.seed, num_workers=args.rdkit_threads)
if not mols:
print("Error: No molecules retained after embedding")
sys.exit(1)
Expand Down
Loading
Loading