Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions .envrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Auto-select the pkasolver devShell based on available GPU hardware.
# Requires direnv (https://direnv.net) and nix-direnv.
#
# Install nix-direnv once:
# nix profile install nixpkgs#nix-direnv
# echo 'source $HOME/.nix-profile/share/nix-direnv/direnvrc' >> ~/.direnvrc
#
# Then allow this file once:
# direnv allow

_detect_gpu() {
# CUDA: check /dev/dxg (WSL2) or /dev/nvidia0 (bare metal)
if [ -e /dev/dxg ] || [ -e /dev/nvidia0 ]; then
echo "cuda"
return
fi
# CUDA: fallback — nvidia-smi reachable and functional
# Add WSL lib path so nvidia-smi can find libnvidia-ml.so
if LD_LIBRARY_PATH="/usr/lib/wsl/lib:$LD_LIBRARY_PATH" \
command -v nvidia-smi &>/dev/null && \
LD_LIBRARY_PATH="/usr/lib/wsl/lib:$LD_LIBRARY_PATH" \
nvidia-smi &>/dev/null; then
echo "cuda"
return
fi
# ROCm: /dev/kfd is the AMD GPU compute device node
if [ -e /dev/kfd ]; then
echo "rocm"
return
fi
echo "cpu"
}

PKASOLVER_SHELL="$(_detect_gpu)"
echo "pkasolver: detected shell → $PKASOLVER_SHELL"
use flake ".#$PKASOLVER_SHELL"

# Set up the Python venv after the flake shell is loaded.
# uv is now on PATH from the flake. Re-sync when deps change.
watch_file pyproject.toml uv.lock
if [ ! -f .venv/.direnv-sync-done ] || \
[ pyproject.toml -nt .venv/.direnv-sync-done ] || \
[ uv.lock -nt .venv/.direnv-sync-done ]; then
echo "pkasolver: running uv sync..."
uv sync --extra dev --extra cpu --quiet && touch .venv/.direnv-sync-done
fi

# direnv cannot use `source .venv/bin/activate` — add venv to PATH directly.
PATH_add .venv/bin
export VIRTUAL_ENV="$PWD/.venv"
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,10 @@ ENV/
# profraw files from LLVM? Unclear exactly what triggers this
# There are reports this comes from LLVM profiling, but also Xcode 9.
*profraw

# direnv
.direnv/

# legacy versioneer artifact
pkasolver/_version.py
.venv/
7 changes: 3 additions & 4 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
include LICENSE
include MANIFEST.in
include versioneer.py
include LICENSE.md
include README.md

graft pkasolver
global-exclude *.py[cod] __pycache__ *.so
global-exclude *.py[cod] __pycache__ *.so
55 changes: 28 additions & 27 deletions devtools/conda-envs/test_env.yaml
Original file line number Diff line number Diff line change
@@ -1,30 +1,31 @@
name: pka_prediction
channels:
- pytorch
- conda-forge
- defaults
- pytorch
- conda-forge
- defaults
dependencies:
- rdkit
- matplotlib
- pytorch
- cudatoolkit
- torchvision
- torchaudio
- python
# Testing
- pytest
- pytest-cov
- codecov
- numpy
- scipy
- tqdm
- svgutils
- cairosvg
- ipython
- pip
- pip:
- torch-geometric==2.0.1
- torch-sparse
- torch-scatter
- molvs
- chembl_webresource_client
- python>=3.10
- rdkit>=2023.3
- matplotlib>=3.7
- numpy>=1.24
- scipy>=1.10
- tqdm>=4.65
- svgutils>=0.3.4
- cairosvg>=2.7
- ipython>=8.0
- jupyter
# Testing
- pytest>=7.4
- pytest-cov>=4.1
# PyTorch - install CPU build by default; for GPU see README
- pytorch>=2.0
- cpuonly # remove this line and install pytorch-cuda=12.1 for CUDA
- pip
- pip:
- torch-geometric>=2.3
- torch-scatter
- torch-sparse
- molvs>=0.1.1
- click>=8.1
- setuptools-scm>=8
- pkasolver # installs this package
61 changes: 61 additions & 0 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

110 changes: 110 additions & 0 deletions flake.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
{
description = "pkasolver — microstate pKa prediction via Graph Neural Networks";

inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-25.05";
flake-parts.url = "github:hercules-ci/flake-parts";
};

outputs = inputs@{ flake-parts, ... }:
flake-parts.lib.mkFlake { inherit inputs; } {
systems = [ "x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin" ];

perSystem = { pkgs, system, ... }:
let
python = pkgs.python311;

# Separate nixpkgs instance with allowUnfree for CUDA packages.
pkgs-unfree = import inputs.nixpkgs {
inherit system;
config.allowUnfree = true;
};

# System libraries required by rdkit, cairosvg, svgutils, and the
# X11 drawing backend.
sharedLibDeps = with pkgs; [
cairo
pango
glib
zlib
stdenv.cc.cc.lib
xorg.libXrender
xorg.libX11
xorg.libXext
];

# Shell hook shared across all variants.
# Sets LD_LIBRARY_PATH; runs uv sync only in interactive shells
# (not during direnv's environment export phase).
commonHook = ''
export LD_LIBRARY_PATH="${pkgs.lib.makeLibraryPath sharedLibDeps}:$LD_LIBRARY_PATH"
export UV_PYTHON="${python}/bin/python"
# Only run uv sync in interactive shells, not during direnv export.
# direnv sets IN_NIX_SHELL=impure during its eval phase but PS1 is unset.
if [ -n "''${PS1:-}" ]; then
uv sync --extra dev --extra cpu --quiet
source .venv/bin/activate
echo "Ready. Python: $(python --version), pkasolver: $(python -c 'import pkasolver; print(pkasolver.__version__)' 2>/dev/null || echo 'not installed')"
fi
'';

in
{
# ── Dev shells ──────────────────────────────────────────────────

devShells = {

# CPU (default) ───────────────────────────────────────────────
default = pkgs.mkShell {
name = "pkasolver-cpu";
packages = [ python pkgs.uv pkgs.git ] ++ sharedLibDeps;
shellHook = commonHook;
};

cpu = pkgs.mkShell {
name = "pkasolver-cpu";
packages = [ python pkgs.uv pkgs.git ] ++ sharedLibDeps;
shellHook = commonHook;
};

# CUDA ────────────────────────────────────────────────────────
# Nix provides the CUDA toolkit; PyTorch CUDA wheels come via uv.
# Uses a separate pkgs instance with allowUnfree = true since the
# CUDA toolkit is licensed under the CUDA EULA.
cuda =
let cudaPkgs = pkgs-unfree.cudaPackages; in
pkgs.mkShell {
name = "pkasolver-cuda";
packages = [
python pkgs.uv pkgs.git
cudaPkgs.cudatoolkit
cudaPkgs.cudnn
] ++ sharedLibDeps;
shellHook = commonHook + ''
export CUDA_HOME="${cudaPkgs.cudatoolkit}"
echo "CUDA shell active. To install CUDA-enabled PyTorch:"
echo " uv pip install torch --index-url https://download.pytorch.org/whl/cu121"
'';
};

# ROCm ────────────────────────────────────────────────────────
# Nix provides ROCm runtime; PyTorch ROCm wheels come via uv.
rocm = pkgs.mkShell {
name = "pkasolver-rocm";
packages = [
python pkgs.uv pkgs.git
pkgs.rocmPackages.rocm-runtime
pkgs.rocmPackages.rocm-smi
] ++ sharedLibDeps;
shellHook = commonHook + ''
echo "ROCm shell active. To install ROCm-enabled PyTorch:"
echo " uv pip install torch --index-url https://download.pytorch.org/whl/rocm6.0"
'';
};
};

# ── Formatter ───────────────────────────────────────────────────
formatter = pkgs.nixpkgs-fmt;
};
};
}
31 changes: 16 additions & 15 deletions pkasolver/__init__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
"""
pkasolver
toolset for predicting the pka values of small molecules
Toolkit for predicting microstate pKa values of small molecules
via Graph Isomorphism Networks (GINs).
"""

# Add imports here
from .pkasolver import *
from .dimorphite_dl.dimorphite_dl import run_with_mol_list
import logging

# Handle versioneer
from ._version import get_versions
from .dimorphite_dl.dimorphite_dl import run_with_mol_list # noqa: F401

versions = get_versions()
__version__ = versions["version"]
__git_revision__ = versions["full-revisionid"]
del get_versions, versions
try:
from ._version import __version__
except ImportError:
# Package not installed (e.g. running from source without build)
__version__ = "0.0.0+unknown"

import logging
__all__ = [
"__version__",
"run_with_mol_list",
]

# format logging message
FORMAT = "[%(filename)s:%(lineno)s - %(funcName)1s()] %(message)s"
# set logging level
logging.basicConfig(format=FORMAT, datefmt="%d-%m-%Y:%H:%M", level=logging.INFO)
# Configure logging
FORMAT = "[%(filename)s:%(lineno)s - %(funcName)s()] %(message)s"
logging.basicConfig(format=FORMAT, datefmt="%d-%m-%Y:%H:%M", level=logging.WARNING)
Loading