Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions benchmarks/bench_utils/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,18 @@ def load_pickle(filepath: str, max_count: int = 0, seed: int | None = None) -> l
seed: Optional seed for the sampling RNG.

Returns:
List of parsed RDKit molecules.
List of parsed RDKit molecules. The list is always shuffled
(deterministic with ``seed``) so benches that consume a head slice
get a representative cross-section rather than file-order bias.
"""
with open(filepath, "rb") as fh:
binary_mols = pickle.load(fh)
rng = random.Random(seed)
if max_count > 0 and len(binary_mols) > max_count:
binary_mols = random.Random(seed).sample(binary_mols, max_count)
binary_mols = rng.sample(binary_mols, max_count)
else:
binary_mols = list(binary_mols)
rng.shuffle(binary_mols)
mols = process_map(
_mol_from_binary,
binary_mols,
Expand Down Expand Up @@ -95,6 +101,10 @@ def load_smiles(
does not have to fit in memory and only the sampled SMILES are parsed. A
10% buffer is read past ``max_count`` to absorb parse failures, after which
the result is trimmed back to ``max_count``.

The returned list is always shuffled (deterministic with ``seed``) so
benches that consume a head slice get a representative cross-section
rather than file-order bias (some upstream files are sorted by size).
"""
read_limit = int(max_count * 1.1) if max_count > 0 else 0
rng = random.Random(seed)
Expand Down Expand Up @@ -127,6 +137,8 @@ def load_smiles(

if max_count > 0 and len(mols) > max_count:
mols = rng.sample(mols, max_count)
else:
rng.shuffle(mols)

print(f" Loaded {len(mols)} molecules from {filepath}")
return mols
Expand Down Expand Up @@ -168,7 +180,12 @@ def load_sdf(
removeHs: bool = False,
sanitize: bool = True,
) -> list[Chem.Mol]:
"""Load molecules from an SDF file with optional reservoir sampling."""
"""Load molecules from an SDF file with optional reservoir sampling.

The returned list is always shuffled (deterministic with ``seed``) so
benches that consume a head slice get a representative cross-section
rather than file-order bias (some upstream files are sorted by size).
"""
supplier = Chem.SDMolSupplier(filepath, removeHs=removeHs, sanitize=sanitize)
read_limit = int(max_count * 1.1) if max_count > 0 else 0
rng = random.Random(seed)
Expand Down Expand Up @@ -199,6 +216,8 @@ def load_sdf(

if max_count > 0 and len(mols) > max_count:
mols = rng.sample(mols, max_count)
else:
rng.shuffle(mols)

if parse_failures > 0:
print(f" ({parse_failures} parse failures)")
Expand Down
9 changes: 4 additions & 5 deletions benchmarks/butina_clustering_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
import numpy as np
import pandas as pd
import torch
from bench_utils import load_smiles
from benchmark_timing import time_it
from rdkit import DataStructs
from rdkit.Chem import AllChem, MolFromSmiles
from rdkit.Chem import AllChem
from rdkit.DataStructs import BulkTanimotoSimilarity
from rdkit.ML.Cluster.Butina import ClusterData

Expand Down Expand Up @@ -163,6 +164,7 @@ def bench_nvmol_with_tanimoto(fps, threshold, neighborlist_max_size):
)
parser.add_argument("--cutoff", type=float, default=None, help="Run only this cutoff value")
parser.add_argument("--runs", type=int, default=3, help="Number of timed repetitions (default: 3)")
parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling SMILES (default: 42)")
parser.add_argument(
"-o", "--output", type=str, default="results.csv", help="Output CSV file path (default: results.csv)"
)
Expand Down Expand Up @@ -207,10 +209,7 @@ def bench_nvmol_with_tanimoto(fps, threshold, neighborlist_max_size):

max_size = max(e["size"] for e in run_plan)

with open(args.input_smiles_file, "r") as f:
smis = [line.strip() for line in f.readlines()]
mols = [MolFromSmiles(smi, sanitize=True) for smi in smis[: max_size + 100]]
mols = [mol for mol in mols if mol is not None]
mols = load_smiles(args.input_smiles_file, max_count=max_size + 100, sanitize=True, seed=args.seed)

if include_tanimoto and len(mols) < max_size:
print(
Expand Down
14 changes: 6 additions & 8 deletions benchmarks/cross_similarity_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pandas as pd
import pyperf
import torch
from rdkit.Chem import MolFromSmiles, rdFingerprintGenerator
from bench_utils import load_smiles
from rdkit.Chem import rdFingerprintGenerator
from rdkit.DataStructs import BulkCosineSimilarity, BulkTanimotoSimilarity

from nvmolkit.similarity import crossCosineSimilarity, crossTanimotoSimilarity
from nvmolkit.fingerprints import MorganFingerprintGenerator
from nvmolkit.similarity import crossCosineSimilarity, crossTanimotoSimilarity


SIZES = [2000, 4000, 6000, 8000, 10000, 12000, 14000, 16000, 20000, 24000, 28000, 32000]
Expand All @@ -45,20 +45,18 @@ def nvmolkit_sim_gpu_only(fps, sim_type):
runner = pyperf.Runner(min_time=0.01, values=3, processes=1, loops=3)
runner.metadata["description"] = "Cross Similarity benchmark"
runner.argparser.add_argument(
"--input", type=str, default="data/benchmark_smiles.csv", help="Path to input SMILES CSV file"
"--input", type=str, default="data/benchmark_smiles.csv", help="Path to input SMILES file (.smi/.csv/.cxsmiles)"
)
runner.argparser.add_argument("--cosine", action="store_true", help="Include cosine similarity benchmarks")
runner.argparser.add_argument("--seed", type=int, default=42, help="Random seed for sampling SMILES (default: 42)")
args = runner.parse_args()

sim_types = ("tanimoto", "cosine") if args.cosine else ("tanimoto",)
fpsize = 1024
max_size = max(SIZES)
default_values = runner.args.values

df = pd.read_csv(args.input)
smis = df.iloc[:, 0].to_list()
mols = [MolFromSmiles(smi) for smi in smis]
mols = [mol for mol in mols if mol is not None]
mols = load_smiles(args.input, max_count=max_size, seed=args.seed)
if not mols:
raise ValueError(f"No molecules parsed from {args.input}")
while len(mols) < max_size:
Expand Down
135 changes: 3 additions & 132 deletions benchmarks/substruct_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,14 @@

import argparse
import gc
import pickle
import random
import sys
from functools import partial
from multiprocessing import Pool
from typing import Callable, Iterator
from typing import Callable

import nvtx
import pandas as pd
from bench_utils import load_pickle, load_smarts, load_smiles
from benchmark_timing import time_it as _time_it
from nvmolkit import autotune as nv_autotune
from nvmolkit.substructure import (
Expand All @@ -71,9 +70,8 @@
getSubstructMatches,
hasSubstructMatch,
)
from rdkit import Chem, RDLogger
from rdkit import Chem
from rdkit.Chem import rdSubstructLibrary
from tqdm.contrib.concurrent import process_map

OPTUNA_AVAILABLE = nv_autotune.is_available()

Expand All @@ -84,133 +82,6 @@ def time_it(func: Callable, runs: int = 1, gpu_sync: bool = False) -> tuple[floa
return result.mean_ms, result.std_ms


def load_pickle(filepath: str, max_count: int = 0, seed: int | None = None) -> list[Chem.Mol]:
"""Load molecules from a pickled file containing binary mol data.

When ``max_count > 0``, a uniform random sample of binary mols is drawn.
"""
with open(filepath, "rb") as f:
binary_mols = pickle.load(f)
if max_count > 0 and len(binary_mols) > max_count:
binary_mols = random.Random(seed).sample(binary_mols, max_count)
mols = process_map(
_mol_from_binary,
binary_mols,
desc="Unpickling molecules",
chunksize=1000,
)
print(f" Loaded {len(mols)} molecules from {filepath}")
return mols


def _mol_from_binary(binary_mol: bytes) -> Chem.Mol:
"""Load a molecule from RDKit binary format."""
return Chem.Mol(binary_mol)


def _parse_smiles(smi: str, sanitize: bool) -> Chem.Mol | None:
"""Parse a single SMILES string."""
return Chem.MolFromSmiles(smi, sanitize=sanitize)


def _iter_smiles_tokens(filepath: str, sanitize: bool) -> Iterator[str]:
"""Yield SMILES tokens from a file, skipping blanks/comments and a parse-failing first line.

The first non-comment line is parsed quietly; if it fails to parse it is treated as a header
and dropped, matching the original loader's behavior.
"""
with open(filepath, "r") as f:
first_data_seen = False
for line in f:
stripped = line.strip()
if not stripped or stripped.startswith("#"):
continue
smi = stripped.split()[0]
if not first_data_seen:
first_data_seen = True
RDLogger.DisableLog("rdApp.*")
mol = Chem.MolFromSmiles(smi, sanitize=sanitize)
RDLogger.EnableLog("rdApp.*")
if mol is None:
continue
yield smi


def load_smiles(filepath: str, max_count: int = 0, sanitize: bool = True, seed: int | None = None) -> list[Chem.Mol]:
"""Load and parse molecules from a SMILES file.

When ``max_count > 0``, reservoir sampling draws a uniform random sample of lines in a single
streaming pass (with a 10% buffer to absorb parse failures) so the file isn't fully loaded into
memory and only the sampled SMILES are parsed.
"""
# Use a 10% buffer to account for potential parse failures
# "On parse failures continue down the file. Load 10% more molecules than needed"
read_limit = int(max_count * 1.1) if max_count > 0 else 0

if read_limit > 0:
rng = random.Random(seed)
reservoir: list[str] = []
for index, smi in enumerate(_iter_smiles_tokens(filepath, sanitize)):
if index < read_limit:
reservoir.append(smi)
else:
replace_index = rng.randint(0, index)
if replace_index < read_limit:
reservoir[replace_index] = smi
smiles_list = reservoir
else:
smiles_list = list(_iter_smiles_tokens(filepath, sanitize))

mols: list[Chem.Mol] = []
if smiles_list:
parse_func = partial(_parse_smiles, sanitize=sanitize)
parsed = process_map(parse_func, smiles_list, desc="Parsing molecules", chunksize=1000)

parse_failures = 0
for mol in parsed:
if mol is None:
parse_failures += 1
else:
mols.append(mol)

if parse_failures > 0:
print(f" ({parse_failures} parse failures)")

# Trim to exactly max_count if we have more than requested
if max_count > 0 and len(mols) > max_count:
mols = mols[:max_count]

print(f" Loaded {len(mols)} molecules from {filepath}")
return mols


def load_smarts(filepath: str, max_count: int = 0) -> tuple[list[Chem.Mol], list[str]]:
"""Load and parse query patterns from a SMARTS file."""
queries = []
smarts_list = []
parse_failures = 0

with open(filepath, "r") as f:
for line in f:
if max_count > 0 and len(queries) >= max_count:
break
line = line.strip()
if not line or line.startswith("#"):
continue
smarts = line.split()[0]
query = Chem.MolFromSmarts(smarts)
if query is None:
parse_failures += 1
continue
queries.append(query)
smarts_list.append(smarts)

print(f" Loaded {len(queries)} SMARTS patterns from {filepath}")
if parse_failures > 0:
print(f" ({parse_failures} parse failures)")
return queries, smarts_list


_worker_queries = None
_worker_params = None

Expand Down
Loading