Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions examples/cfd/isotropic_eddyformer/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# EddyFormer for 3D Isotropic Turbulence

This example demonstrates how to use the EddyFormer model for simulating
a three-dimensional isotropic turbulence. This example runs on a single GPU.

## Problem Overview

This example focuses on **three-dimensional homogeneous isotropic turbulence (HIT)** sustained by large-scale forcing. The flow is governed by the incompressible Navier–Stokes equations with an external forcing term:

\[
\frac{\partial \mathbf{u}}{\partial t} + \mathbf{u} \cdot \nabla \mathbf{u}
= \nu \nabla^2 \mathbf{u} + \mathbf{f}(\mathbf{x})
\]

where:

- **\(\mathbf{u}(\mathbf{x}, t)\)** — velocity field in a 3D periodic domain
- **\(\nu = 0.01\)** — kinematic viscosity
- **\(\mathbf{f}(\mathbf{x})\)** — isotropic forcing applied at the largest scales

### Forcing Mechanism

To maintain statistically steady turbulence, a **constant-power forcing** is applied to the lowest Fourier modes (\(|\mathbf{k}| \le 1\)). The forcing injects a prescribed amount of energy \(P_{\text{in}} = 1.0\) into the system:

\[
\mathbf{f}(\mathbf{x}) =
\frac{P_{\text{in}}}{E_1}
\sum_{\substack{|\mathbf{k}| \le 1 \\ \mathbf{k} \neq 0}}
\hat{\mathbf{u}}_{\mathbf{k}} e^{i \mathbf{k} \cdot \mathbf{x}}
\]

where:

\[
E_1 = \frac{1}{2}
\sum_{|\mathbf{k}| \le 1}
\hat{\mathbf{u}}_{\mathbf{k}} \cdot \hat{\mathbf{u}}_{\mathbf{k}}^{*}
\]

is the kinetic energy contained in the forced low-wavenumber modes.

Under this forcing, the flow reaches a **statistically steady state** with a Taylor-scale Reynolds number of:

**\(\mathrm{Re}_\lambda \approx 94\)**

### Task Description

The objective of this example is to **predict the future velocity field** of the turbulent flow. Given \(\mathbf{u}(\mathbf{x}, t)\), the task is:

> **Predict the velocity field \(\mathbf{u}(\mathbf{x}, t + \Delta t)\) with \(\Delta t = 0.5\).**

This requires modeling nonlinear, chaotic, multi-scale turbulent dynamics, including:

- energy injection at large scales
- nonlinear transfer across the inertial range
- dissipation at the smallest scales

### Dataset Summary

- **DNS resolution:** \(384^3\) (used to generate the dataset)
- **Stored dataset resolution:** \(96^3\)
- **Kolmogorov scale resolution:** ~0.5 η
- **Forcing:** applied to modes with \(|\mathbf{k}| \le 1\)
- **Viscosity:** \(\nu = 0.01\)
- **Input power:** \(P_{\text{in}} = 1.0\)
- **Flow regime:** statistically steady HIT at \(\mathrm{Re}_\lambda \approx 94\)

## Prerequisites

Install the required dependencies by running below:

```bash
pip install -r requirements.txt
```

## Download the Dataset

The dataset is publicly available at [Huggingface](https://huggingface.co/datasets/ydu11/re94).
To download the dataset, run (you might need to install the Huggingface CLI):

```bash
bash download_dataset.sh
```

## Getting Started

To train the model, run

```bash
python train_ef_isotropic.py
```

## References

- [EddyFormer: Accelerated Neural Simulations of Three-Dimensional Turbulence at Scale](https://arxiv.org/abs/2510.24173)
29 changes: 29 additions & 0 deletions examples/cfd/isotropic_eddyformer/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
model:
idim: 3
odim: 3
hdim: 32
num_layers: 4
use_scale: true
layer_config:
basis: leg_elem
mesh: [8, 8, 8]
mode: [13, 13, 13]
mode_les: [5, 5, 5]
kernel_size: [2, 2, 2]
kernel_size_les: [2, 2, 2]
ffn_dim: 128
activation: GELU
num_heads: 4
heads_dim: 32

training:
dataset: data/ns3d-re94
result_dir: outputs/ef-leg-re94
t: 0.5
amp: false
compile: true
batch_size: 4
num_epochs: 1
learning_rate: 1e-3
test_every: 100
ckpt_every: 1000
1 change: 1 addition & 0 deletions examples/cfd/isotropic_eddyformer/download_dataset.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
hf download --repo-type dataset ydu11/re94 --local-dir ${1:-data/ns3d-re94}
2 changes: 2 additions & 0 deletions examples/cfd/isotropic_eddyformer/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
hydra-core>=1.2.0
termcolor>=2.1.1
200 changes: 200 additions & 0 deletions examples/cfd/isotropic_eddyformer/train_ef_isotropic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import hydra
from tqdm import tqdm

from typing import Tuple
from torch import Tensor
from omegaconf import DictConfig

import os
import collections
import numpy as np

import torch
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.nn.parallel import DistributedDataParallel

from physicsnemo.models.eddyformer import EddyFormer, EddyFormerConfig
from physicsnemo.distributed import DistributedManager
from physicsnemo.utils import StaticCaptureTraining, StaticCaptureEvaluateNoGrad
from physicsnemo.launch.utils import save_checkpoint
from physicsnemo.launch.logging import PythonLogger, LaunchLogger


def rel_l2(pred: Tensor, target: Tensor) -> Tensor:
return torch.linalg.norm(pred - target) / torch.linalg.norm(target)

class Re94(Dataset):

root: str
t: float

n: int = 50
dt: float = 0.1

def __init__(self, root: str, split: str, *, t: float = 0.5,
n: int = 50, dt: float = 0.1) -> None:
"""
"""
super().__init__()
self.root = root
self.t = t

self.n = n
self.dt = dt

self.file = []
for fname in sorted(os.listdir(root)):
if fname.startswith(split):
self.file.append(fname)

@property
def stride(self) -> int:
k = int(self.t / self.dt)
assert self.dt * k == self.t
return k

@property
def samples_per_file(self) -> int:
return self.n - self.stride + 1

def __len__(self) -> int:
return len(self.file) * self.samples_per_file

def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]:
file_idx, time_idx = divmod(idx, self.samples_per_file)

data = np.load(f"{self.root}/{self.file[file_idx]}", allow_pickle=True).item()
return torch.from_numpy(data["u"][time_idx]), torch.from_numpy(data["u"][time_idx + self.stride])

def metric(self, pred: Tensor, target: Tensor) -> dict[str, float]:
"""
"""
l2 = [rel_l2(pred[..., i], target[..., i]).item() for i in range(3)]
return { f"err_{ax}": value for ax, value in (zip("xyz", l2)) }

@hydra.main(version_base="1.3", config_path=".", config_name="config.yaml")
def isotropic_trainer(cfg: DictConfig) -> None:
"""
"""
DistributedManager.initialize() # Only call this once in the entire script!
dist = DistributedManager() # call if required elsewhere

# initialize monitoring
log = PythonLogger(name="re94_ef")
log.file_logging(f"{cfg.training.result_dir}/log.txt")
LaunchLogger.initialize() # PhysicsNeMo launch logger

# define model and optimizer
model = EddyFormer(
idim=cfg.model.idim,
odim=cfg.model.odim,
hdim=cfg.model.hdim,
num_layers=cfg.model.num_layers,
use_scale=cfg.model.use_scale,
cfg=EddyFormerConfig(
basis=cfg.model.layer_config.basis,
mesh=tuple(cfg.model.layer_config.mesh),
mode=tuple(cfg.model.layer_config.mode),
mode_les=tuple(cfg.model.layer_config.mode_les),
kernel_size=tuple(cfg.model.layer_config.kernel_size),
kernel_size_les=tuple(cfg.model.layer_config.kernel_size_les),
ffn_dim=cfg.model.layer_config.ffn_dim,
activation=cfg.model.layer_config.activation,
num_heads=cfg.model.layer_config.num_heads,
heads_dim=cfg.model.layer_config.heads_dim,
),
).to(dist.device)

if dist.distributed:
ddps = torch.cuda.Stream()
with torch.cuda.stream(ddps):
model = DistributedDataParallel(
model,
device_ids=[dist.local_rank],
output_device=dist.device,
broadcast_buffers=dist.broadcast_buffers,
find_unused_parameters=dist.find_unused_parameters,
)
torch.cuda.current_stream().wait_stream(ddps)
log.success("Initialized DDP training")

optimizer = Adam(model.parameters(), lr=cfg.training.learning_rate)

# define dataset and dataloader
dataset = Re94(root=cfg.training.dataset, split="train", t=cfg.training.t)
dataloader = DataLoader(dataset, cfg.training.batch_size, shuffle=True)

testset = Re94(root=cfg.training.dataset, split="test", t=cfg.training.t, n=40, dt=0.5)
testloader = DataLoader(testset, batch_size=None)

# define training step
@StaticCaptureTraining(
model=model,
optim=optimizer,
logger=log,
use_graphs=False,
use_amp=cfg.training.amp,
compile=cfg.training.compile
)
def training_step(input: Tensor, target: Tensor) -> Tensor:
pred = torch.vmap(model)(input)
loss = torch.vmap(rel_l2)(pred, target)
return torch.mean(loss)

# define evaluation step
@StaticCaptureEvaluateNoGrad(
model=model,
logger=log,
use_graphs=False,
use_amp=cfg.training.amp,
compile=cfg.training.compile
)
def forward_eval(input):
return model(input)

it = 0

model.train()
log.info("Training started")

for epoch in range(cfg.training.num_epochs):
for it, (input, target) in enumerate(dataloader, it):

input = input.to(dist.device)
target = target.to(dist.device)
loss = training_step(input, target)

with LaunchLogger("train", epoch=epoch) as logger:
logger.log_minibatch({"Training loss": loss.item()})

if it and it % cfg.training.ckpt_every == 0 and dist.rank == 0:
save_checkpoint(f"{cfg.training.result_dir}/ckpt.pt", model, optimizer, epoch=it)

if it and it % cfg.training.test_every == 0:

model.eval()
metrics = collections.defaultdict(float)

for input, target in tqdm(testloader, desc="Test"):

input = input.to(dist.device)
target = target.to(dist.device)

pred = forward_eval(input)
metric = testset.metric(pred, target)

for key, value in metric.items():
metrics[key] += value / len(testset)

with LaunchLogger("test", epoch=epoch) as logger:
logger.log_minibatch(metrics)

model.train()

log.success("Training completed")
save_checkpoint(f"{cfg.training.result_dir}/ckpt.pt", model, optimizer)


if __name__ == "__main__":
isotropic_trainer()
3 changes: 3 additions & 0 deletions physicsnemo/models/eddyformer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ._basis import Legendre
from ._datatype import SEM
from .eddyformer import EddyFormer, EddyFormerConfig
Loading