Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 41 additions & 17 deletions examples/cfd/vortex_shedding_mgn/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,42 +12,66 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pathlib import Path
from typing import Tuple

from pydantic import BaseModel
from typing import Tuple, Optional


class Constants(BaseModel):
"""vortex shedding constants"""

# wb configs
wandb_mode: str = "online"
watch_model: bool = True

# Model name
model_name: str = "training"

# data configs
data_dir: str = "./raw_dataset/cylinder_flow/cylinder_flow"
data_dir: str = "/datasets/cylinder_flow/cylinder_flow"

# training configs
batch_size: int = 1
epochs: int = 25
num_training_samples: int = 400
num_training_time_steps: int = 300
epochs: int = 5 # 25
training_batch_size: int = 11
num_training_samples: int = 22 # 1000
num_training_time_steps: int = 100 # 600
training_noise_std: float = 0.02

valid_batch_size: int = 1 # Must be 1 for now
num_valid_samples: int = 4 # 100
num_valid_time_steps: int = 200 # 600

lr: float = 0.0001
lr_decay_rate: float = 0.9999991
ckpt_path: str = "/workspace/checkpoints_training_6"
ckpt_name: str = "model.pt"

# Mesh Graph Net Setup
num_input_features: int = 6
num_output_features: int = 3
num_edge_features: int = 3
ckpt_path: str = "checkpoints"
ckpt_name: str = "model.pt"
num_output_features: int = 3
processor_size: int = 15
num_layers_node_processor: int = 2
num_layers_edge_processor: int = 2
hidden_dim_processor: int = 128
hidden_dim_node_encoder: int = 128
num_layers_node_encoder: int = 2
hidden_dim_edge_encoder: int = 128
num_layers_edge_encoder: int = 2
hidden_dim_node_decoder: int = 128
num_layers_node_decoder: int = 2
aggregation: str = "sum"
do_concat_trick: bool = False
num_processor_checkpoint_segments: int = 0
activation_fn: str = "elu"

# performance configs
amp: bool = False
jit: bool = False

# test & visualization configs
num_test_samples: int = 10
num_test_time_steps: int = 300
num_test_samples: int = 100
num_test_time_steps: int = 600
viz_vars: Tuple[str, ...] = ("u", "v", "p")
frame_skip: int = 10
frame_interval: int = 1

# wb configs
wandb_mode: str = "disabled"
watch_model: bool = False
Loading