
Bayesian Nonparametric Graph Pooling

This is a lightweight codebase to reproduce the experiments in the paper BNPool: Bayesian Nonparametric Pooling for Graph Neural Networks by Daniele Castellana and Filippo Maria Bianchi. All pooling layers, including the official BNPool implementation, are used directly from torch-geometric-pool (tgp).

BNPool

BNPool is a hierarchical graph pooling layer for graph classification that can also be used for node clustering. It uses a Bayesian nonparametric formulation to adapt the number of clusters to each graph instead of fixing it in advance.
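For intuition on how a nonparametric prior lets the number of clusters vary, here is a minimal sketch of a truncated stick-breaking construction, a standard way to build Dirichlet-process mixture weights. It illustrates the general idea only and is not the BNPool implementation in tgp; the helper names `stick_breaking_weights` and `effective_clusters` and the parameters `alpha`, `max_clusters`, and `threshold` are hypothetical.

```python
import random


def stick_breaking_weights(alpha: float, max_clusters: int, seed: int = 0) -> list[float]:
    """Sample truncated stick-breaking weights (illustration only, not tgp's BNPool)."""
    rng = random.Random(seed)
    weights = []
    remaining = 1.0  # length of the stick not yet assigned to any cluster
    for _ in range(max_clusters - 1):
        # Break off a Beta(1, alpha) fraction of the remaining stick.
        v = rng.betavariate(1.0, alpha)
        weights.append(remaining * v)
        remaining *= 1.0 - v
    # Give the leftover stick to the last cluster so the weights sum to 1.
    weights.append(remaining)
    return weights


def effective_clusters(weights: list[float], threshold: float = 0.01) -> int:
    """Count clusters whose weight exceeds a (hypothetical) threshold."""
    return sum(w > threshold for w in weights)


w = stick_breaking_weights(alpha=1.0, max_clusters=20)
print(len(w), round(sum(w), 6), effective_clusters(w))
```

A smaller `alpha` makes each break larger on average (the Beta(1, alpha) mean is 1/(1 + alpha)), so most mass lands on the first few clusters; a larger `alpha` spreads it out. The truncation `max_clusters` only caps how many clusters can be used.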

[Figure: BNPool training]

🛠️ Setup

Conda

Create the environment from the provided file:

conda env create -f environment.yml
conda activate bnpool

The checked-in environment.yml targets Linux with NVIDIA CUDA. For a CPU- or MPS-only Conda environment, comment out the two lines marked in environment.yml before creating the environment.
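With a CPU- or MPS-only environment, it can help to confirm which accelerator PyTorch actually sees before launching a run. The helper below is a minimal sketch that is not part of this repository; `pick_device` is a hypothetical name, and how the training scripts themselves select a device is not described in this README.

```python
def pick_device() -> str:
    """Return 'cuda', 'mps', or 'cpu' depending on what this PyTorch build can use."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed in the active environment.
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps only exists in newer PyTorch builds, so guard the lookup.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


print(pick_device())
```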

uv

Create and sync the environment with:

uv sync

If uv cannot find a compatible local Python, install one explicitly and retry:

uv python install 3.12
uv sync

Then either activate the virtual environment:

source .venv/bin/activate

or run commands directly through uv:

uv run python minimal_example.py
uv run python run_classification.py
uv run python run_clustering.py

⚡️ Quick start

The file minimal_example.py is a minimal end-to-end example that:

  • loads MUTAG
  • imports BNPool directly from tgp
  • trains a small graph-classification model

Run it with:

python minimal_example.py

🧪 Experiments

Graph classification

Run the default graph-classification configuration:

python run_classification.py

This uses Hydra and defaults to dataset=mutag.

Example override:

python run_classification.py dataset=bench-hard pooler=mincut epochs=100 optimizer.hparams.lr=1e-4
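Value overrides such as `epochs=100` and `optimizer.hparams.lr=1e-4` use dotted paths into the nested config. The sketch below mimics only that part in plain Python; it is a simplified illustration, not Hydra's actual override parser. The `apply_overrides` helper and the default values in `cfg` are hypothetical. It does not model `dataset=` and `pooler=`, which (given that the runner defaults to `dataset=mutag`) most likely select config-group files rather than set a single value.

```python
def apply_overrides(config: dict, overrides: list[str]) -> dict:
    """Apply dotted key=value overrides to a nested dict (simplified, Hydra-like)."""
    for item in overrides:
        key, raw = item.split("=", 1)
        # Tiny value parser: try int, then float, otherwise keep the string.
        try:
            value = int(raw)
        except ValueError:
            try:
                value = float(raw)
            except ValueError:
                value = raw
        # Walk (or create) the nested dicts down to the parent of the leaf key.
        *parents, leaf = key.split(".")
        node = config
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return config


# Hypothetical defaults; the real values live in the files under config/.
cfg = {"epochs": 500, "optimizer": {"hparams": {"lr": 1e-3}}}
apply_overrides(cfg, ["epochs=100", "optimizer.hparams.lr=1e-4"])
print(cfg)  # {'epochs': 100, 'optimizer': {'hparams': {'lr': 0.0001}}}
```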

Node clustering

Run the default node-clustering configuration:

python run_clustering.py

This defaults to dataset=community.
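Because BNPool adapts the number of clusters, a natural sanity check after a clustering run is how many clusters actually receive nodes. The sketch below assumes the model exposes a soft node-to-cluster assignment matrix (rows are nodes, columns are clusters up to some truncation). That output format, the matrix `S`, and the helpers `hard_assignments` and `used_clusters` are hypothetical and not taken from this repository's code.

```python
def hard_assignments(soft: list[list[float]]) -> list[int]:
    """Map each node's row of cluster probabilities to its most likely cluster index."""
    return [max(range(len(row)), key=row.__getitem__) for row in soft]


def used_clusters(labels: list[int]) -> int:
    """Number of distinct clusters that received at least one node."""
    return len(set(labels))


# Hypothetical soft assignment for 4 nodes over a truncation of 3 clusters.
S = [
    [0.9, 0.05, 0.05],
    [0.8, 0.15, 0.05],
    [0.1, 0.85, 0.05],
    [0.2, 0.7, 0.1],
]
labels = hard_assignments(S)
print(labels, used_clusters(labels))  # [0, 0, 1, 1] 2
```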

Smoke-test configs

Short validation runs for the available setups are provided through:

python run_classification.py --config-name test_classification -m
python run_clustering.py --config-name test_clustering -m
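The `-m` flag runs Hydra in multirun mode. With Hydra's default basic sweeper, comma-separated values given for several keys are expanded into their Cartesian product, one job per combination. The actual sweeps defined in test_classification and test_clustering are not shown in this README; the sketch below only illustrates the expansion, and the `expand_sweep` helper and the sweep values (including `bnpool` as a pooler name) are hypothetical.

```python
from itertools import product


def expand_sweep(params: dict[str, list[str]]) -> list[list[str]]:
    """Expand per-key sweep values into one override list per run (Cartesian product)."""
    keys = list(params)
    runs = []
    for combo in product(*(params[k] for k in keys)):
        runs.append([f"{k}={v}" for k, v in zip(keys, combo)])
    return runs


# Hypothetical sweep, similar in spirit to: dataset=mutag,bench-hard pooler=mincut,bnpool
sweep = {"dataset": ["mutag", "bench-hard"], "pooler": ["mincut", "bnpool"]}
for overrides in expand_sweep(sweep):
    print(overrides)
```

Two keys with two values each give four runs, which is why a multirun over many datasets and poolers can take a long time.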

Warning

These runs can take a while, and a few datasets require a GPU with more than 24GB of VRAM.
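Before launching the smoke tests on a GPU machine, you can check how much device memory is available. This is a minimal sketch that is not part of the repository; `has_enough_vram` is a hypothetical helper. It measures memory in GiB (1024³ bytes) while the warning above says GB, so treat the threshold as approximate.

```python
def has_enough_vram(min_gib: float = 24.0, device_index: int = 0) -> bool:
    """Return True if CUDA device `device_index` has more than `min_gib` GiB of memory."""
    try:
        import torch
    except ImportError:
        return False  # no PyTorch, so no CUDA device to query
    if not torch.cuda.is_available() or device_index >= torch.cuda.device_count():
        return False
    total_bytes = torch.cuda.get_device_properties(device_index).total_memory
    return total_bytes / 1024**3 > min_gib


print(has_enough_vram())
```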

📂 Project structure

.
├── config/                 # Hydra configs
├── source/
│   ├── data/               # Dataset loading and preprocessing
│   ├── models/             # Model definitions using tgp poolers
│   ├── pl_modules/         # PyTorch Lightning training modules
│   └── utils/              # Hydra, metrics, and training utilities
├── minimal_example.py      # Small BNPool example with tgp
├── run_classification.py   # Graph-classification runner
├── run_clustering.py       # Node-clustering runner
├── environment.yml         # Conda environment
└── LICENSE

📚 Citation

If you use this code, please cite:

@article{castellana2026bnpool,
  title={BNPool: Bayesian Nonparametric Pooling for Graph Neural Networks},
  author={Castellana, Daniele and Bianchi, Filippo Maria},
  journal={Transactions on Machine Learning Research},
  year={2026},
  url={https://openreview.net/forum?id=3B3Zr2xfkf}
}

🔐 License

This project is licensed under the MIT License. See LICENSE.