
[stacked 3/3, multi-datapipes] darcy flow multi dataset#1507

Open
coreyjadams wants to merge 6 commits into NVIDIA:main from coreyjadams:darcy-flow-multi-dataset

Conversation

@coreyjadams
Collaborator

[stacked]

PR 3/3

Will be rebased after the first two PRs merge; includes the final Darcy example with multi-dataset support and a runnable training script.

  • Add multi datapipes and corresponding tests.
  • Update tests for the updated transforms; add Resize and Reshape transforms; add an in-memory numpy reader for small single-file datasets like Darcy flow.
  • Add a Darcy example with multiple datasets.
  • Add most of the multi-dataset support cleanly.

PhysicsNeMo Pull Request

Description

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI's assessment of merge readiness; it is not a qualitative judgment of your work, nor an
indication that the PR will be accepted or rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@coreyjadams coreyjadams changed the title darcy flow multi dataset [stacked] darcy flow multi dataset Mar 15, 2026
@coreyjadams coreyjadams changed the title [stacked] darcy flow multi dataset [stacked 3/3, multi-datapipes] darcy flow multi dataset Mar 15, 2026
@greptile-apps
Contributor

greptile-apps bot commented Mar 15, 2026

Greptile Summary

This PR introduces a MultiDataset class that composes multiple Dataset instances behind a single index space, enabling training on heterogeneous data sources through a unified DataLoader. It also adds Resize and Reshape transforms, a preload_to_cpu option for NumpyReader, and a complete Darcy flow multi-dataset training example.

  • MultiDataset (physicsnemo/datapipes/multi_dataset.py): Concatenates datasets with global-to-local index mapping, optional output key validation (output_strict), metadata enrichment with dataset_index, and full prefetch/close delegation. Well-tested with 15+ test cases.
  • Resize transform (physicsnemo/datapipes/transforms/spatial.py): Spatial resizing via F.interpolate for 2D/3D tensors, with automatic channel dimension handling.
  • Reshape transform (physicsnemo/datapipes/transforms/utility.py): Applies torch.reshape to specified TensorDict keys.
  • NumpyReader preload (physicsnemo/datapipes/readers/numpy.py): preload_to_cpu=True loads all arrays into RAM at init and closes the file, eliminating per-sample disk I/O. Also enforces float32 dtype uniformly.
  • Normalize fix (physicsnemo/datapipes/transforms/normalize.py): Uses collections.abc.Mapping instead of dict for isinstance check, fixing compatibility with Hydra's DictConfig.
  • DataLoader.__len__ fix (physicsnemo/datapipes/dataloader.py): Respects sampler length when available, necessary for correct batch count with SubsetRandomSampler in train/val splits.
  • Darcy example (examples/cfd/darcy-multidataset/): End-to-end training with Transolver on numpy + HDF5 PDEBench datasets via Hydra config. RelativeL2Loss has a division-by-zero risk when targets are all-zero (see inline comment).
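The global-to-local index mapping that the summary describes can be sketched as follows. This is an illustrative, hypothetical implementation (class and method names are assumptions, not the PR's actual code): it concatenates datasets behind one index space using cumulative lengths plus `bisect`, which also addresses the linear-search note in the file overview, and tags each sample's metadata with `dataset_index`.

```python
import bisect

class MultiDatasetSketch:
    """Illustrative only: maps a global sample index to (dataset_index, local_index).

    The real MultiDataset in this PR may differ; this shows the cumulative-length
    + bisect approach, which is O(log n) in the number of datasets instead of a
    linear scan.
    """

    def __init__(self, datasets):
        self.datasets = list(datasets)
        # Cumulative end offsets, e.g. lengths [3, 2] -> offsets [3, 5].
        self.offsets = []
        total = 0
        for ds in self.datasets:
            total += len(ds)
            self.offsets.append(total)

    def __len__(self):
        return self.offsets[-1] if self.offsets else 0

    def _index_to_dataset_and_local(self, index):
        if index < 0:  # support negative indexing, as the tests exercise
            index += len(self)
        if not 0 <= index < len(self):
            raise IndexError(index)
        ds_idx = bisect.bisect_right(self.offsets, index)
        start = self.offsets[ds_idx - 1] if ds_idx > 0 else 0
        return ds_idx, index - start

    def __getitem__(self, index):
        ds_idx, local = self._index_to_dataset_and_local(index)
        # Enrich the sample with its originating dataset, as the summary describes.
        return {"data": self.datasets[ds_idx][local], "dataset_index": ds_idx}
```

For example, with two toy datasets of lengths 3 and 2, global index 3 resolves to the first element of the second dataset.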

Important Files Changed

  • physicsnemo/datapipes/multi_dataset.py — New MultiDataset class composing multiple Dataset instances behind a single index space. Well-structured with proper delegation, output validation, prefetch support, and a context manager. The linear search in _index_to_dataset_and_local could use bisect for many datasets but is fine for typical use.
  • physicsnemo/datapipes/readers/numpy.py — Adds a preload_to_cpu option for the single-file NumpyReader, loading all arrays into RAM at init. Also adds float32 enforcement. Clean implementation with proper close/cleanup behavior; well tested.
  • physicsnemo/datapipes/transforms/spatial.py — New Resize transform using F.interpolate for 2D/3D tensors. Handles channel and batch dimensions correctly and supports multiple interpolation modes with proper align_corners handling.
  • physicsnemo/datapipes/transforms/utility.py — New Reshape transform applying torch.reshape to specified keys. Simple, correct implementation following existing patterns (extra_repr, registry); clones the TensorDict to avoid mutation.
  • physicsnemo/datapipes/transforms/normalize.py — Small but important fix: uses collections.abc.Mapping instead of dict for the isinstance check, allowing DictConfig and other Mapping types from Hydra to work correctly.
  • examples/cfd/darcy-multidataset/train.py — Training entrypoint for Darcy Transolver with multi-dataset support. RelativeL2Loss has a division-by-zero risk when targets are all zeros; the un-detached pred stored in train_sample holds an unnecessary computation graph in memory.
  • physicsnemo/datapipes/dataloader.py — Small fix to __len__ to respect sampler length when available (e.g. SubsetRandomSampler) instead of always using dataset length. Correct and necessary for train/val split workflows.
  • test/datapipes/core/test_multi_dataset.py — Comprehensive tests for MultiDataset: basic indexing, negative indexing, strict validation, prefetch delegation, DataLoader integration, error cases, and repr. Good coverage.
  • test/datapipes/readers/test_numpy_consolidated.py — Thorough tests for NumpyReader including preload_to_cpu, float32 conversion, coordinated subsampling, default values, and memory cleanup. Well structured.

Last reviewed commit: 0fc3987
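The Resize transform's use of F.interpolate can be illustrated with a small sketch (a hypothetical helper, not the PR's implementation): F.interpolate expects an (N, C, H, W) input, so a bare spatial tensor must be promoted with leading dimensions first, and align_corners is only meaningful for the linear/cubic modes.

```python
import torch
import torch.nn.functional as F

def resize_2d(tensor: torch.Tensor, size: tuple, mode: str = "bilinear") -> torch.Tensor:
    """Sketch of spatial resizing via F.interpolate (illustrative helper only).

    F.interpolate wants (N, C, H, W), so a bare (H, W) or (C, H, W) tensor is
    padded with leading singleton dims and the result is squeezed back.
    """
    added = 0
    while tensor.dim() < 4:  # promote to (N, C, H, W)
        tensor = tensor.unsqueeze(0)
        added += 1
    # align_corners must be left as None for modes like "nearest".
    align = False if mode in ("bilinear", "bicubic") else None
    out = F.interpolate(tensor, size=size, mode=mode, align_corners=align)
    for _ in range(added):  # restore the original rank
        out = out.squeeze(0)
    return out
```

A (4, 4) field resized to (8, 8) comes back as (8, 8); a (3, 4, 4) channel-first tensor resized to (2, 2) comes back as (3, 2, 2).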

Comment on lines +43 to +44
```python
ref = torch.norm(target.reshape(B, -1), dim=1)
return torch.mean(diff / ref)
```
Division by zero in RelativeL2Loss

ref is the per-sample L2 norm of the target. If any sample in the batch has an all-zero (or near-zero) target, ref[i] == 0 and diff[i] / ref[i] produces inf or nan, corrupting the loss and gradients for the entire batch.

Consider clamping the denominator:

Suggested change:

```diff
  ref = torch.norm(target.reshape(B, -1), dim=1)
- return torch.mean(diff / ref)
+ return torch.mean(diff / ref.clamp(min=1e-8))
```

```python
train_l2_err_sq += metrics["l2_err_sq"]
train_l2_ref_sq += metrics["l2_ref_sq"]
train_n += b
train_sample = (x, y, pred)
```

Stale computation graph held in memory

train_sample = (x, y, pred) stores a reference to pred which still carries the autograd graph from the current iteration. This keeps the full computation graph alive in memory until the next iteration overwrites it. For large models this doubles the peak GPU memory during training.

Consider detaching before storing:

Suggested change:

```diff
- train_sample = (x, y, pred)
+ train_sample = (x, y, pred.detach())
```
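Why detaching helps can be shown in a minimal standalone snippet (not the example's training loop): a tensor produced by differentiable ops carries a grad_fn and keeps the whole autograd graph reachable, while detach() returns a view cut off from the graph, safe to keep around for logging or visualization.

```python
import torch

# Minimal illustration of the review's point, not the PR's code.
x = torch.randn(4, requires_grad=True)
pred = x * 2  # differentiable op: pred carries the autograd graph

assert pred.requires_grad and pred.grad_fn is not None  # graph attached

stored = pred.detach()  # same values, but no graph reference
assert not stored.requires_grad and stored.grad_fn is None
```

Storing `stored` instead of `pred` lets the iteration's graph be freed as soon as backward() completes.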

Comment on lines +315 to +317
```python
val_loss_sum = _zero.clone()
val_l2_err_sq = _zero.clone()
val_l2_ref_sq = _zero.clone()
```

Validation accumulators reuse stale _zero from training

val_loss_sum = _zero.clone() (and the other accumulators) clone the _zero tensor created at the top of the epoch loop (line 277). Since _zero is torch.tensor(0.0, device=dist.device) with requires_grad=False, this works, but _zero was already used for train_loss_sum etc. and those were accumulated into during training. The .clone() call creates a fresh zero so this is technically correct, but it would be clearer to create fresh torch.tensor(0.0, ...) here instead of relying on .clone() of a tensor that has been previously assigned elsewhere.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
