Skip to content

Conversation

@wdyab
Copy link

@wdyab wdyab commented Nov 26, 2025

PhysicsNeMo Pull Request

Description

This PR adds a comprehensive deep learning framework for neural operator learning in reservoir simulation to PhysicsNemo. Implements multiple architectures (U-FNO, Conv-FNO, UNet) with physics-informed losses, distributed training, and extensive experiment tracking capabilities.

Closes #1255

Checklist

  • [ x] I am familiar with the Contributing Guidelines.
  • [ x] New or existing tests cover these changes.
  • [ x] The documentation is up to date with these changes.
  • [ x] The CHANGELOG.md is up to date with these changes.
  • [ x] An #1255 is linked to this pull request.

Dependencies

This contribution uses existing PhysicsNemo dependencies and adds no new requirements beyond what's already in the main repository. All dependencies are standard PyTorch ecosystem packages:

  • PyTorch (core framework)
  • Hydra (configuration management) - already in PhysicsNemo
  • MLFlow (experiment tracking) - optional, commonly available
  • TensorBoard (visualization) - standard PyTorch tool

The example is self-contained in examples/reservoir_simulation/DeepONet/ and includes its own requirements.txt for reference.

Testing

  • Trained and validated on production datasets (500 realizations)
  • Multi-GPU training verified (DDP across 8 GPUs)
  • All architectures tested (pressure + saturation)
  • Pre-commit hooks passing
  • Compatible with PhysicsNemo utilities (DistributedManager, checkpoint utils, logging)

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

This contribution adds a comprehensive deep learning framework for CO2
sequestration reservoir simulation, featuring:

- Multiple neural operator architectures (U-FNO, Conv-FNO, Standalone UNet)
- Flexible configuration system for model selection and hyperparameters
- Automatic model checkpoint naming to prevent overwriting
- Dynamic data validation utilities
- Comprehensive evaluation scripts with multiple metrics
- Apache 2.0 licensed, pre-commit compliant

The DeepONet framework enables efficient spatiotemporal prediction of
pressure and saturation fields in CO2 sequestration scenarios.

Location: examples/reservoir_simulation/DeepONet/

Signed-off-by: wdyab <wdyab@nvidia.com>
Add entry for DeepONet framework contribution to version 1.4.0a0.

References: NVIDIA#1255

Signed-off-by: wdyab <wdyab@nvidia.com>
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Nov 26, 2025

Greptile Overview

Greptile Summary

This PR adds a comprehensive deep learning framework for neural operator learning in reservoir simulation (CO2 sequestration modeling). The implementation includes multiple architectures (U-FNO, Conv-FNO, UNet), physics-informed losses, distributed training support, and integration with PhysicsNeMo utilities.

Key Changes:

  • Added ~4,600 lines of code across 17 source files in examples/reservoir_simulation/DeepONet/
  • Implements U-FNO architecture using PhysicsNeMo's SpectralConv3d, UNet, and other core modules
  • Provides flexible Hydra-based configuration with support for multiple model variants
  • Includes evaluation scripts with comprehensive metrics (MRE, MPE, MAE, R², relative L2)
  • Supports distributed training via DDP with proper sampler and normalization synchronization

Issues Found:

  • Config path bug: train_fno3d.py:343 references cfg.arch.activation_fn which doesn't exist in the config structure (should be cfg.arch.ufno.activation_fn)
  • UNet4D uses non-existent PyTorch ops (nn.Conv4d, nn.BatchNorm4d) - already flagged in previous threads
  • Grid spacing formula incorrect in losses.py:266 - already flagged in previous threads
  • Hardcoded user path in config - already flagged in previous threads

Minor Items:

  • Uses deprecated torch.cuda.amp.autocast import (should use torch.amp.autocast)
  • Bare except: clauses in dataset.py (style issue)
  • torch.load calls without weights_only=True (PyTorch security recommendation)

Important Files Changed

File Analysis

Filename Score Overview
examples/reservoir_simulation/DeepONet/train_fno3d.py 3/5 Main training script with bug: references cfg.arch.activation_fn which doesn't exist in config (should be cfg.arch.ufno.activation_fn). Uses deprecated torch.cuda.amp.autocast import.
examples/reservoir_simulation/DeepONet/losses.py 2/5 Loss functions with incorrect grid spacing formula in _extract_grid_spacing - doesn't compute central finite difference spacing correctly (already flagged in previous review threads).
examples/reservoir_simulation/DeepONet/unet3d.py 2/5 UNet4D class uses non-existent PyTorch operations (nn.Conv4d, nn.BatchNorm4d, nn.ConvTranspose4d) - will fail at runtime if used (already flagged in previous review threads).
examples/reservoir_simulation/DeepONet/conf/training_config.yaml 3/5 Contains hardcoded user-specific data path /home/wdyab/physicsnemo/data_lustre (already flagged). Otherwise clean configuration.
examples/reservoir_simulation/DeepONet/ufno.py 4/5 Well-structured U-FNO/Conv-FNO implementation using PhysicsNemo modules. Clean architecture with proper padding handling.
examples/reservoir_simulation/DeepONet/dataset.py 4/5 Clean dataset implementation with proper distributed data loading support. Uses bare except: clauses (style issue) but functionally correct.

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

18 files reviewed, 4 comments

Edit Code Review Agent Settings | Greptile

Comment on lines 372 to 535
class UNet4D(nn.Module):
"""4D U-Net for higher-dimensional spatiotemporal data (H × W × T × D)."""

def __init__(
self,
input_channels: int,
output_channels: int,
kernel_size: int = 3,
dropout_rate: float = 0.0,
):
super().__init__()

self.input_channels = input_channels
self.output_channels = output_channels
self.kernel_size = kernel_size
self.dropout_rate = dropout_rate

# Encoder
self.conv1 = self._conv_block(
input_channels,
output_channels,
kernel_size=kernel_size,
stride=2,
dropout_rate=dropout_rate,
)
self.conv2 = self._conv_block(
input_channels,
output_channels,
kernel_size=kernel_size,
stride=2,
dropout_rate=dropout_rate,
)
self.conv2_1 = self._conv_block(
input_channels,
output_channels,
kernel_size=kernel_size,
stride=1,
dropout_rate=dropout_rate,
)
self.conv3 = self._conv_block(
input_channels,
output_channels,
kernel_size=kernel_size,
stride=2,
dropout_rate=dropout_rate,
)
self.conv3_1 = self._conv_block(
input_channels,
output_channels,
kernel_size=kernel_size,
stride=1,
dropout_rate=dropout_rate,
)

# Decoder
self.deconv2 = self._deconv_block(input_channels, output_channels)
self.deconv1 = self._deconv_block(input_channels * 2, output_channels)
self.deconv0 = self._deconv_block(input_channels * 2, output_channels)

# Output
self.output_layer = self._output_block(
input_channels * 2,
output_channels,
kernel_size=kernel_size,
stride=1,
dropout_rate=dropout_rate,
)

def _conv_block(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int,
dropout_rate: float,
) -> nn.Module:
"""4D convolutional block."""
return nn.Sequential(
nn.Conv4d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
bias=False,
),
nn.BatchNorm4d(out_channels),
nn.LeakyReLU(0.1, inplace=True),
nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity(),
)

def _deconv_block(self, in_channels: int, out_channels: int) -> nn.Module:
"""4D transposed convolutional block."""
return nn.Sequential(
nn.ConvTranspose4d(
in_channels, out_channels, kernel_size=4, stride=2, padding=1
),
nn.LeakyReLU(0.1, inplace=True),
)

def _output_block(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int,
dropout_rate: float,
) -> nn.Module:
"""Output layer."""
return nn.Conv4d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
)

def forward(self, x: Tensor) -> Tensor:
"""Forward pass. Input: (batch, channels, H, W, T, D)"""
# Validate dimensions (must be divisible by 8)
dims = x.shape[2:]
if any(d % 8 != 0 for d in dims):
raise ValueError(
f"Input dimensions {dims} must be divisible by 8. Got shape: {x.shape}"
)

# Encoder
out_conv1 = self.conv1(x)
out_conv2 = self.conv2_1(self.conv2(out_conv1))
out_conv3 = self.conv3_1(self.conv3(out_conv2))

# Decoder with skip connections
out_deconv2 = self.deconv2(out_conv3)
if out_deconv2.shape[2:] != out_conv2.shape[2:]:
out_deconv2 = F.interpolate(
out_deconv2, size=out_conv2.shape[2:], mode="nearest"
)
concat2 = torch.cat((out_conv2, out_deconv2), dim=1)

out_deconv1 = self.deconv1(concat2)
if out_deconv1.shape[2:] != out_conv1.shape[2:]:
out_deconv1 = F.interpolate(
out_deconv1, size=out_conv1.shape[2:], mode="nearest"
)
concat1 = torch.cat((out_conv1, out_deconv1), dim=1)

out_deconv0 = self.deconv0(concat1)
if out_deconv0.shape[2:] != x.shape[2:]:
out_deconv0 = F.interpolate(out_deconv0, size=x.shape[2:], mode="nearest")
concat0 = torch.cat((x, out_deconv0), dim=1)

out = self.output_layer(concat0)

return out

def count_params(self) -> int:
"""Count total number of trainable parameters."""
return sum(p.numel() for p in self.parameters() if p.requires_grad)


# Aliases for backward compatibility with U-FNO
UNetModule2D = UNet2D
UNetModule3D = UNet3D
UNetModule4D = UNet4D
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: PyTorch doesn't have nn.Conv4d, nn.ConvTranspose4d, or nn.BatchNorm4d operations. This entire UNet4D class will fail at runtime.

If 4D convolutions are needed, consider using a workaround like reshaping to 3D, or remove this class if it's not currently being used (the PR description mentions 2D+T support is completed, 3D+T is partial).

Comment on lines +265 to +266
grid_x = inputs[0, 0, :, 0, -3] # (W=200,) - extract grid_x channel values
grid_dx = grid_x[1:-1] + grid_x[:-2] / 2 + grid_x[2:] / 2 # (W-2=198,)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: incorrect grid spacing formula - adds terms instead of computing differences

The current formula grid_dx = grid_x[1:-1] + grid_x[:-2] / 2 + grid_x[2:] / 2 doesn't compute spacing correctly. For central finite differences, grid spacing should be the distance between adjacent points.

Should likely be:

Suggested change
grid_x = inputs[0, 0, :, 0, -3] # (W=200,) - extract grid_x channel values
grid_dx = grid_x[1:-1] + grid_x[:-2] / 2 + grid_x[2:] / 2 # (W-2=198,)
grid_x = inputs[0, 0, :, 0, -3] # (W=200,) - extract grid_x channel values
grid_dx = (grid_x[2:] - grid_x[:-2]) / 2 # (W-2=198,) - central difference spacing

This computes the average spacing for central finite differences: (x[i+1] - x[i-1]) / 2.

Comment on lines 32 to 34
data_path: /home/wdyab/physicsnemo/data_lustre
variable: saturation # 'pressure' or 'saturation'
normalize: false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: hardcoded user-specific data path

The path /home/wdyab/physicsnemo/data_lustre is specific to the author's system and won't work for other users. Consider using a relative path or adding a clear comment that users must update this.

data:
data_path: /home/wdyab/physicsnemo/data_lustre
variable: saturation # 'pressure' or 'saturation'
normalize: false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: normalize: false may not be intended for production use - verify this is intentional for the saturation variable or if it should be true by default

Copy link
Collaborator

@coreyjadams coreyjadams left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @wdyab - thanks for submitting this PR. I will let some others with more experience in reservoir simulation comment on the physics and accuracy for this PR, but I have a few overall comments.

  1. The biggest issue I see with this PR is the implementation of standalone models, here, in the examples. Are these models not possible with the FNO components in physicsnemo? Should they be integrated or perhaps deployed elsewhere? I am not in favor of large models living in examples moving forward - it makes the repository quite challenging to maintain, and they will get deprecated / dropped much faster than the core models.
  2. The AI review points out, and I agree because it's true, that there is no such thing as conv4d in mainstream torch.
  3. The conv4d issue highlights another problem: you're introducing full models here but there are no tests on this code. You also can not add tests to physicsnemo proper that import from examples. We're removing that from other examples and it won't be permitted going forward - it just makes the repository very challenging to maintain.
  4. Your readme lacks sufficient detail. Can you consider adding some motivation, references, example training runs, etc? It looks like it was written by AI - and that's fine to start, if it was - but it still takes a human eye to review and ensure the README has sufficient detail to be useful.

Thanks!

@wdyab wdyab closed this Dec 10, 2025
@wdyab
Copy link
Author

wdyab commented Dec 10, 2025

Hi @coreyjadams, thank you for the detailed review! I've addressed your feedback:

  1. Models in examples:
    I want to clarify that the U-FNO implementation uses PhysicsNeMo's core components:
  • SpectralConv3d from physicsnemo.models.layers
  • ConvNdKernel1Layer from physicsnemo.models.layers
  • UNet from physicsnemo.models.unet
  • FullyConnected from physicsnemo.models.mlp

The example demonstrates how to compose these existing components into a FNO family of neural operator architectures for reservoir simulation. I'm planning to meet with the PhysicsNeMo product owner to discuss the best integration approach and whether any components should be moved to the core library. Additionally, I'm planning to extend the implementation (over the next two weeks) to show how a similar implementation can be achieved for the DeepONet family of neural operators. This should make easy for reservoir simulation engineers to adopt PhysicsNeMo and all the great things it can offer.

  1. UNet4D issue:
    Fixed - I've removed the non-functional UNet4D class that used non-existent PyTorch operations (nn.Conv4d, etc.). It was part of testing code that should not have been pushed in the first place.

  2. Tests:
    Added comprehensive unit tests in tests/ directory covering:
    UNet models (forward pass, gradient flow, parameter counting)
    Loss functions (MSE, L1, Relative L2, masking, derivatives)
    U-FNO model (different configurations, lifting/decoder types, UNet types)
    Note that the implementation was used to reproduce (exactly) all the results in the paper.

  3. README:
    Expanded the README with:

  • Physics background on reservoir simulation and neural operators
  • Architecture diagram
  • Data format documentation
  • Testing instructions
  • Proper citations to the U-FNO paper (Wen et al. 2022)

I've also fixed the hardcoded data path in the configuration.

Note: This contribution is still a work in progress. The framework has been developed and tested on 2D CO2 sequestration datasets, and I plan to continue extending it for reservoir simulation problems.

Please let me know if you'd like any additional changes!

@wdyab wdyab reopened this Dec 10, 2025
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (1)

  1. examples/reservoir_simulation/DeepONet/train_fno3d.py, line 343 (link)

    logic: incorrect config path - cfg.arch.activation_fn doesn't exist; should be cfg.arch.ufno.activation_fn

18 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

- Remove non-functional UNet4D class (uses non-existent nn.Conv4d)
- Add comprehensive unit tests in tests/ directory
  - test_unet.py: Tests for UNet2D and UNet3D models
  - test_losses.py: Tests for loss functions (MSE, L1, Relative L2)
  - test_ufno.py: Tests for U-FNO model (requires physicsnemo)
- Expand README with physics background, references, and testing instructions
- Fix hardcoded data path in training_config.yaml
- Generalize documentation for reservoir simulation

Test results: 41 passed, 1 skipped (U-FNO tests skipped when physicsnemo not installed)

Signed-off-by: wdyab <wdyab@nvidia.com>
@wdyab
Copy link
Author

wdyab commented Dec 10, 2025

Good catch! Fixed in commit 2d520ef.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add a configurable framework for neural operator learning based on the DeepONet and its variants

2 participants