Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ Common flags:
Required environment (see `src/lib/config.py`):
- `PSC_PLOT_DATA_DIR` — directory containing the data files (must be set; `set_data_dir.sh <dir>` is a convenience script that exports it)
- `PSC_PLOT_FFMPEG_BIN` — optional, falls back to `which ffmpeg`; needed for saving animations
- `PSC_PLOT_DASK_NUM_WORKERS` — optional, defaults to 1
- `PSC_PLOT_DASK_CHUNK_SIZE` — optional, rows per dask partition for particle loads (default 1_000_000); reduce to bound peak memory on large files
- `PSC_PLOT_DASK_NUM_WORKERS` — optional, defaults to `os.cpu_count()`
- `PSC_PLOT_DASK_CHUNK_SIZE` — optional, rows per dask partition for particle loads (default 1_000_000); reduce to bound peak memory on large files
- `PSC_PLOT_DASK_SCHEDULER` — optional; if set to `"processes"`, uses dask's processes scheduler; if `"distributed"`, spins up a `dask.distributed.LocalCluster` with `n_workers=dask_num_workers, threads_per_worker=1, processes=True`. Unset = dask default (threads).

`PSC_PLOT_DATA_DIR` is read at module-import time (`CONFIG = PscPlotConfig._load()` in `src/lib/config.py`), so it must be set in the environment before any `lib.*` import. In tests, `tests/conftest.py` sets it before importing `lib`.

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies = [
"xrft>=1.0",
"pandas>=2.0",
"dask>=2024.0",
"distributed>=2024.0",
"h5py>=3.0",
"tables>=3.10",
"scipy>=1.10",
Expand Down
7 changes: 7 additions & 0 deletions src/lib/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ def _resolve_save_format(args: Args) -> SaveFormat | None:

def main():
dask.config.set(num_workers=CONFIG.dask_num_workers)
if CONFIG.dask_scheduler == "distributed":
from dask.distributed import Client, LocalCluster

cluster = LocalCluster(n_workers=CONFIG.dask_num_workers, threads_per_worker=1, processes=True)
Client(cluster)
elif CONFIG.dask_scheduler:
dask.config.set(scheduler=CONFIG.dask_scheduler)

args = parsing.get_parsed_args()
# resolve format BEFORE applying pipeline in order to fail early
Expand Down
11 changes: 6 additions & 5 deletions src/lib/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import shutil
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Self
Expand All @@ -9,6 +8,7 @@
_FFMPEG_BIN_KEY = "PSC_PLOT_FFMPEG_BIN"
_DASK_NUM_WORKERS_KEY = "PSC_PLOT_DASK_NUM_WORKERS"
_DASK_CHUNK_SIZE_KEY = "PSC_PLOT_DASK_CHUNK_SIZE"
_DASK_SCHEDULER_KEY = "PSC_PLOT_DASK_SCHEDULER"


def parse_optional[T](s: str | None, parser: Callable[[str], T]) -> T | None:
Expand All @@ -23,6 +23,7 @@ class PscPlotConfig:
ffmpeg_bin: Path | None
dask_num_workers: int
dask_chunk_size: int
dask_scheduler: str | None

@classmethod
def _load(cls) -> Self:
Expand All @@ -35,15 +36,15 @@ def _load(cls) -> Self:

dask_num_workers = parse_optional(os.environ.get(_DASK_NUM_WORKERS_KEY), int)
if not dask_num_workers:
dask_num_workers = 1
message = f"Number of dask workers not specified; defaulting to {dask_num_workers}. Set {_DASK_NUM_WORKERS_KEY} to specify."
warnings.warn(message)
dask_num_workers = os.cpu_count() or 1

dask_chunk_size = parse_optional(os.environ.get(_DASK_CHUNK_SIZE_KEY), int)
if not dask_chunk_size:
dask_chunk_size = 1_000_000

return cls(data_dir, ffmpeg_bin, dask_num_workers, dask_chunk_size)
dask_scheduler = os.environ.get(_DASK_SCHEDULER_KEY) or None

return cls(data_dir, ffmpeg_bin, dask_num_workers, dask_chunk_size, dask_scheduler)


CONFIG = PscPlotConfig._load()
7 changes: 6 additions & 1 deletion src/lib/data/loaders/field_bp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ def _get_path(prefix: str, step: int) -> Path:
return CONFIG.data_dir / f"{prefix}.{step:09}.bp"


def _decode_psc(ds):
return pscpy.decode_psc(ds, ["e", "i"])


@loader
class FieldLoaderBp(Loader):
@classmethod
Expand All @@ -34,7 +38,8 @@ def get_data(self) -> Field:
paths=[_get_path(self.prefix, step) for step in self.steps],
combine="nested",
concat_dim="t",
preprocess=lambda ds: pscpy.decode_psc(ds, ["e", "i"]),
preprocess=_decode_psc,
parallel=True,
)
if self.active_key is not None:
derive_field_variable(ds, self.active_key, self.prefix)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pandas as pd

from lib import var_info_registry
from lib.data.data_with_attrs import List

__all__ = ["derived_particle_variable", "derive_particle_variable", "DERIVED_PARTICLE_VARIABLES"]
Expand All @@ -25,7 +26,10 @@ def __init__(

def assign_to(self, data: List) -> List:
df = data.data
return data.assign_data(df.assign(**{self.name: self.derive(*(df[base_var_name] for base_var_name in self.base_var_names))}))

info = var_info_registry.lookup("prt", self.name)
new_var_infos = {**data.metadata.var_infos, self.name: info}
return data.assign_data(df.assign(**{self.name: self.derive(*(df[base_var_name] for base_var_name in self.base_var_names))})).assign_metadata(var_infos=new_var_infos)

def __repr__(self) -> str:
return f"{self.__class__.__name__}(({', '.join(self.base_var_names)}) -> {self.name}: {self.derive!r})"
Expand Down
Binary file added tests/baseline/test_hamscan.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 6 additions & 0 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,12 @@ def test_static_scatter_bp():
return make_plot("prt.e -i t=-1 -v y z time= --grid y=0.0625 z=0.0625".split(), data_dir="test-3d")


@pytest.mark.mpl_image_compare(**MPL_KWARGS)
def test_hamscan():
"""Archetypal scan for hammerhead distributions ("hams")."""
return make_plot("prt.e -i t=-1 --derive pzx --bin y py=20 pzx=20 t= -v py pzx time=y --compute".split(), data_dir="test-2d")


# --- Particle moments ---


Expand Down
Loading