Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ on:
branches: [main]

jobs:
# NOTE: ubuntu-latest is the only supported runner. torch-harmonics 0.9
# only ships manylinux_x86_64 wheels — no macOS, no Windows, no Linux
# ARM. Adding macos-14 / windows-latest to this matrix will fail at the
# `uv pip install` step. See RELEASE.md "Supported runtime matrix" for
# the full ABI rationale. Revisit when torch-harmonics ships broader
# wheels OR when torch-harmonics is moved to an extras_require.
test:
runs-on: ubuntu-latest
strategy:
Expand Down Expand Up @@ -42,3 +48,33 @@ jobs:
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: geometric-intelligence/bispectrum

typecheck:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"

- name: Install uv
uses: astral-sh/setup-uv@v4

- name: Install dependencies
run: uv pip install --system -e ".[dev]"

- name: Type-check with mypy
run: mypy src/bispectrum

lockfile:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v4

- name: Verify uv.lock is up to date
run: uv lock --check
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ dmypy.json
Thumbs.db

# uv
uv.lock
# uv.lock is tracked — see [tests.yml] lockfile job.

# LaTeX build artifacts
*.aux
Expand Down
11 changes: 2 additions & 9 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ repos:
args: [--fix, --unsafe-fixes, --exit-non-zero-on-fix, --ignore=T201]
- id: ruff-format
types: [python]
args: [--line-length=99]

# python upgrading syntax to newer version
- repo: https://github.com/asottile/pyupgrade
Expand Down Expand Up @@ -67,8 +66,8 @@ repos:
args: ["-s", "B101"]

# yaml formatting
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.1.0
- repo: https://github.com/rbubley/mirrors-prettier
rev: v3.6.2
hooks:
- id: prettier
types: [yaml]
Expand Down Expand Up @@ -101,9 +100,3 @@ repos:
args:
- --skip=logs/**,data/**,*.ipynb
- --ignore-words=.codespell-ignore-words

# jupyter notebook cell output clearing
- repo: https://github.com/kynan/nbstripout
rev: 0.8.2
hooks:
- id: nbstripout
50 changes: 39 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,30 @@ The G-bispectrum is a principled *complete* invariant of a signal — it retains

## Supported Groups

| Module | Group / Domain | Output mode | Complexity (selective) |
| -------------- | --------------------------------- | ---------------- | ---------------------- |
| `CnonCn` | C_n on C_n | selective + full | O(n) |
| `SO2onS1` | SO(2) on S¹ | selective + full | O(n) |
| `TorusOnTorus` | T^d | selective + full | O(n) |
| `DnonDn` | D_n on D_n | selective | O(n) |
| `SO2onDisk` | SO(2) on disk | selective | O(L) |
| `SO3onS2` | SO(3) on S² | selective + full | Θ(L²) |
| `OctaonOcta` | chiral octahedral O (24 elements) | selective | 172 coefficients |
| Module | Group / Domain | Output mode | Complexity (selective) | Inversion |
| -------------- | --------------------------------- | ---------------- | ---------------------- | --------- |
| `CnonCn` | C_n on C_n | selective + full | O(n) | yes |
| `SO2onS1` | SO(2) on S¹ | selective + full | O(n) | yes |
| `TorusOnTorus` | T^d | selective + full | O(n) | yes |
| `DnonDn` | D_n on D_n | selective | O(n) | yes |
| `SO2onDisk` | SO(2) on disk | selective | O(L) | yes |
| `SO3onS2` | SO(3) on S² | selective + full | Θ(L²) | no\* |
| `OctaonOcta` | chiral octahedral O (24 elements) | selective | 172 coefficients | yes |

`SO2onS1` is the continuous-n limit of `CnonCn` and shares its implementation.
\*SO(3)-on-S² inversion is an open mathematical problem; calling `invert(beta)`
on `SO3onS2` raises `NotImplementedError`.

## API

Every module exposes a uniform interface:
Every module exposes the same surface:

- **`forward(f)`** — selective (default) or full bispectral invariants
- **`fourier(f)`** — group Fourier coefficients
- **`invert(beta)`** — signal reconstruction up to group-action indeterminacy (where available)
- **`invert(beta)`** — signal reconstruction up to group-action indeterminacy.
Available where the table above says "yes"; check programmatically via the
class attribute `Module.supports_inversion: bool`. Selective-bispectrum
inversion always requires `selective=True`.

Modules default to O(|G|) selective coefficients; pass `selective=False` for the full O(|G|²) set. CG matrices, DFT kernels, and Bessel roots are precomputed as non-learnable buffers. Dependencies: PyTorch, NumPy, and `torch_harmonics` (for `SO3onS2`).

Expand All @@ -47,6 +52,29 @@ Median wall-clock on a single NVIDIA H100 80 GB GPU (batch=16, `torch.utils.benc
| `SO3onS2` | SO(3) | L=16 | 430 | — | 0.48 | — |
| `OctaonOcta` | O | 24 | 172 | — | 0.68 | — |

## Compatibility

- **Platform**: Linux x86_64 only. `torch_harmonics` 0.9 ships only
`manylinux_x86_64` wheels — no macOS, no Windows, no Linux ARM. The
rest of the package is platform-independent, but the install will fail
on anything else.
- **Python**: 3.12 only. `torch_harmonics` 0.9 ships no cp310/cp311/cp313
wheel either, so even though the bispectrum source itself is compatible
with 3.10+, the install will fail outside of 3.12.
- **PyTorch**: `>=2.10`. `torch_harmonics` 0.9 links against
`c10::TensorImpl::decref_pyobject`, a symbol that first appeared in
torch 2.10. Older torch raises `ImportError` at
`import torch_harmonics`.
- **CUDA**: drivers `<12.8` need to override the torch wheel:
```bash
pip install bispectrum --extra-index-url https://download.pytorch.org/whl/cu128
```
- **Cache directory**: precomputed CG and Bessel tables are persisted to
`~/.cache/bispectrum/`. Set `BISPECTRUM_CACHE_DIR=/path/to/cache` to
override (useful on read-only home directories).

See [`RELEASE.md`](RELEASE.md#supported-runtime-matrix) for the full ABI rationale.

## Installation

```bash
Expand Down
8 changes: 2 additions & 6 deletions benchmarks/bench_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,7 @@ def _write_latex_table(stats: dict, path: Path) -> None:
lines.append(r'\midrule')

for thresh_str, rate in stats['success_rates'].items():
lines.append(
rf'Success rate (error {thresh_str}) & \multicolumn{{2}}{{c}}{{{rate:.1%}}} \\'
)
lines.append(rf'Success rate (error {thresh_str}) & \multicolumn{{2}}{{c}}{{{rate:.1%}}} \\')

lines.extend(
[
Expand Down Expand Up @@ -353,9 +351,7 @@ def _plot_error_histogram(errors: np.ndarray, path: Path) -> None:
def main():
parser = argparse.ArgumentParser(description='Octahedral inversion benchmark')
parser.add_argument('--n_signals', type=int, default=1000)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--n_corrections', type=int, default=10)
parser.add_argument('--n_restarts', type=int, default=4)
parser.add_argument('--batch_size', type=int, default=64)
Expand Down
5 changes: 1 addition & 4 deletions benchmarks/bench_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,7 @@ def bench_forward_paths() -> None:
if gr is not None:
t_graph = f'{_time_fn(lambda bsp=bsp, fv=f, b=batch, e=entries: bsp._forward_cuda_graph(fv, b, e)):.3f}'

_print(
f'{lmax:>5d} {entries:>7d} {t_sparse:>10.3f} {t_graph:>10s} '
f'{t_init:>10.2f} {sp_mb:>8.1f}'
)
_print(f'{lmax:>5d} {entries:>7d} {t_sparse:>10.3f} {t_graph:>10s} {t_init:>10.2f} {sp_mb:>8.1f}')

del bsp, f
if device.type == 'cuda':
Expand Down
28 changes: 7 additions & 21 deletions benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,16 +284,12 @@ def bench_forward_pass(device: torch.device) -> plt.Figure:
f = _make_input(module_name, params, batch, device)
g = _group_order(module_name, params)

t = _time_fn(
'bsp(f)', {'bsp': bsp, 'f': f}, f'{module_name}_{g}_{sel_label}', device
)
t = _time_fn('bsp(f)', {'bsp': bsp, 'f': f}, f'{module_name}_{g}_{sel_label}', device)
group_orders.append(g)
times_ms.append(t)

if group_orders:
ax.plot(
group_orders, times_ms, marker=mk, linestyle=ls, label=sel_label, linewidth=2
)
ax.plot(group_orders, times_ms, marker=mk, linestyle=ls, label=sel_label, linewidth=2)
print(f'\n{module_name} {sel_label} ({dev_name}):')
print(f' {"|G|":>10s} {"time_ms":>10s}')
for g, t in zip(group_orders, times_ms, strict=False):
Expand Down Expand Up @@ -644,9 +640,7 @@ def paper_figures(device: torch.device) -> None:
all_gs_flat.extend(gs)

bsp_octa = OctaonOcta(selective=True)
ax1.scatter(
[24], [bsp_octa.output_size], marker='*', s=80, color=c[5], zorder=5, label=r'$O$ sel.'
)
ax1.scatter([24], [bsp_octa.output_size], marker='*', s=80, color=c[5], zorder=5, label=r'$O$ sel.')

x_lo, x_hi = 4, max(all_gs_flat) * 2
xs = [x_lo, x_hi]
Expand All @@ -660,19 +654,15 @@ def paper_figures(device: torch.device) -> None:
linewidth=0,
label=r'full $O(|G|^2)$ region',
)
ax1.plot(
xs, [x**2 for x in xs], color='0.55', linestyle='-.', linewidth=0.8, label=r'$O(|G|^2)$'
)
ax1.plot(xs, [x**2 for x in xs], color='0.55', linestyle='-.', linewidth=0.8, label=r'$O(|G|^2)$')

ax1.set_xscale('log')
ax1.set_yscale('log')
ax1.set_xlabel(r'$|G|$')
ylab = r'$\#$ coefficients' if plt.rcParams.get('text.usetex') else '# coefficients'
ax1.set_ylabel(ylab)
ax1.set_title('Bispectral coefficient count: selective vs. full')
ax1.legend(
fontsize=7, ncol=3, columnspacing=1.0, handlelength=1.5, loc='upper left', framealpha=0.9
)
ax1.legend(fontsize=7, ncol=3, columnspacing=1.0, handlelength=1.5, loc='upper left', framealpha=0.9)
ax1.set_ylim(bottom=5)
_grid(ax1)

Expand Down Expand Up @@ -781,9 +771,7 @@ def paper_figures(device: torch.device) -> None:
break
throughputs.append(bs / (t / 1e3))
valid_batches.append(bs)
ax_d.plot(
valid_batches, throughputs, marker=marker, color=color, label=rf'{label} ($|G|$={g})'
)
ax_d.plot(valid_batches, throughputs, marker=marker, color=color, label=rf'{label} ($|G|$={g})')
print(f' GPU scaling {mod_name}: max throughput = {max(throughputs):.0f} samples/s')

ax_d.set_xscale('log')
Expand Down Expand Up @@ -818,9 +806,7 @@ def paper_figures(device: torch.device) -> None:
g = _group_order(mod_name, params)
with torch.no_grad():
beta = bsp(f)
t = _time_fn(
'bsp.invert(beta)', {'bsp': bsp, 'beta': beta}, f'{mod_name}_inv_{g}', device
)
t = _time_fn('bsp.invert(beta)', {'bsp': bsp, 'beta': beta}, f'{mod_name}_inv_{g}', device)
gs.append(g)
inv_times.append(t)
ax_e.plot(gs, inv_times, marker=marker, color=color, label=label)
Expand Down
8 changes: 2 additions & 6 deletions benchmarks/verify_linear_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,8 @@ def verify_bootstrap(lmax: int) -> bool:


def main() -> None:
parser = argparse.ArgumentParser(
description='Verify linear bootstrap generic full-rank condition'
)
parser.add_argument(
'--lmax', type=int, default=100, help='Maximum degree to verify (default: 100)'
)
parser = argparse.ArgumentParser(description='Verify linear bootstrap generic full-rank condition')
parser.add_argument('--lmax', type=int, default=100, help='Maximum degree to verify (default: 100)')
args = parser.parse_args()

print(f'Verifying linear bootstrap for ell = 4 .. {args.lmax}')
Expand Down
18 changes: 5 additions & 13 deletions benchmarks/verify_seed_fibre.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,7 @@ def main() -> None:
targets_all_parts.append(cg_power_fast(Fs_w, *trip, cg_cache))
targets_all = np.array(targets_all_parts)

print(
f'\nTarget values ({len(targets_bisp_d3)} bisp + {len(targets_cgp_d3)} CGP at degree 3):'
)
print(f'\nTarget values ({len(targets_bisp_d3)} bisp + {len(targets_cgp_d3)} CGP at degree 3):')
for t, v in zip(BISP_D3_TRIPLES, targets_bisp_d3, strict=False):
print(f' Re β{t} = {v:.10f}')
for t, v in zip(CGP_D3_TRIPLES, targets_cgp_d3, strict=False):
Expand Down Expand Up @@ -268,15 +266,12 @@ def main() -> None:
if is_new:
solutions.append(found.copy())
hit_counts.append(1)
print(
f' Start {i:5d}: NEW solution #{len(solutions)} (cost={result.cost:.2e})'
)
print(f' Start {i:5d}: NEW solution #{len(solutions)} (cost={result.cost:.2e})')

if (i + 1) % 1000 == 0:
elapsed = time.time() - t0
print(
f' ... {i + 1}/{args.num_starts} '
f'({elapsed:.1f}s, {len(solutions)} solutions)',
f' ... {i + 1}/{args.num_starts} ({elapsed:.1f}s, {len(solutions)} solutions)',
flush=True,
)

Expand Down Expand Up @@ -328,15 +323,12 @@ def main() -> None:
if is_new:
full_solutions.append(found.copy())
full_hit_counts.append(1)
print(
f' Start {i:5d}: NEW solution #{len(full_solutions)} (cost={result.cost:.2e})'
)
print(f' Start {i:5d}: NEW solution #{len(full_solutions)} (cost={result.cost:.2e})')

if (i + 1) % 1000 == 0:
elapsed = time.time() - t0
print(
f' ... {i + 1}/{args.num_starts} '
f'({elapsed:.1f}s, {len(full_solutions)} solutions)',
f' ... {i + 1}/{args.num_starts} ({elapsed:.1f}s, {len(full_solutions)} solutions)',
flush=True,
)

Expand Down
24 changes: 24 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,30 @@ testpaths = ["tests"]
python_version = "3.12"
strict = true

# torch_harmonics has no py.typed marker; treat as untyped.
[[tool.mypy.overrides]]
module = "torch_harmonics.*"
ignore_missing_imports = true

# nn.Module.register_buffer registers attributes as Tensor | Module in the
# torch stubs, which forces a wall of false-positive index/operator/arg-type
# errors at every buffer access. torch.special / torch.* return Any in the
# stubs, triggering no-any-return everywhere. Disable those codes per-module
# rather than papering the source with type: ignore comments.
[[tool.mypy.overrides]]
module = [
"bispectrum.so3_on_s2",
"bispectrum.so2_on_disk",
"bispectrum.octa_on_octa",
"bispectrum.dn_on_dn",
"bispectrum.torus_on_torus",
"bispectrum.cn_on_cn",
"bispectrum.rotation",
"bispectrum._bessel",
"bispectrum._cg",
]
disable_error_code = ["index", "operator", "arg-type", "no-untyped-call", "no-any-return"]

[tool.coverage.run]
source = ["src/bispectrum"]

Expand Down
8 changes: 4 additions & 4 deletions src/bispectrum/_bessel.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ def bessel_jn(n: int, x: torch.Tensor) -> torch.Tensor:
def _jn_scalar(n: int, x: float) -> float:
"""Fast scalar evaluation of J_n(x) using raw math."""
if n == 0:
return torch.special.bessel_j0(torch.tensor(x, dtype=torch.float64)).item()
return float(torch.special.bessel_j0(torch.tensor(x, dtype=torch.float64)).item())
if n == 1:
return torch.special.bessel_j1(torch.tensor(x, dtype=torch.float64)).item()
return float(torch.special.bessel_j1(torch.tensor(x, dtype=torch.float64)).item())

xt = torch.tensor(x, dtype=torch.float64)
j_prev = torch.special.bessel_j0(xt).item()
j_curr = torch.special.bessel_j1(xt).item()
j_prev = float(torch.special.bessel_j0(xt).item())
j_curr = float(torch.special.bessel_j1(xt).item())

if x == 0:
return 0.0
Expand Down
Loading
Loading