8 changes: 8 additions & 0 deletions deepmd/dpmodel/utils/__init__.py
@@ -36,6 +36,11 @@
save_dp_model,
traverse_model_dict,
)
from .training_utils import (
compute_total_numb_batch,
resolve_model_prob,
resolve_model_prob_from_epochs,
)

__all__ = [
"AtomExcludeMask",
@@ -49,6 +54,7 @@
"aggregate",
"build_multiple_neighbor_list",
"build_neighbor_list",
"compute_total_numb_batch",
"extend_coord_with_ghosts",
"get_graph_index",
"get_multiple_nlist_key",
@@ -60,6 +66,8 @@
"nlist_distinguish_types",
"normalize_coord",
"phys2inter",
"resolve_model_prob",
"resolve_model_prob_from_epochs",
"save_dp_model",
"to_face_distance",
"traverse_model_dict",
188 changes: 188 additions & 0 deletions deepmd/dpmodel/utils/training_utils.py
@@ -0,0 +1,188 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from collections.abc import (
Iterable,
)

import numpy as np

log = logging.getLogger(__name__)


def compute_total_numb_batch(
numb_batches: Iterable[int],
sampler_weights: np.ndarray,
) -> int:
"""Compute total number of batches considering sampler weights.

Parameters
----------
numb_batches : Iterable[int]
Number of batches for each data system.
sampler_weights : np.ndarray
Sampling weights for each data system.

Returns
-------
int
Total number of batches.

Raises
------
ValueError
If input validation fails.
"""
weights = np.asarray(sampler_weights, dtype=np.float64)
if weights.ndim != 1:
raise ValueError("Sampler weights must be 1D.")
if weights.size == 0:
raise ValueError("Sampler weights are empty.")
if not np.all(np.isfinite(weights)):
raise ValueError("Sampler weights must be finite.")
if np.any(weights < 0.0):
raise ValueError("Sampler weights must be non-negative.")
weight_sum = float(np.sum(weights))
if weight_sum <= 0.0:
raise ValueError("Sampler weights must sum to a positive value.")
probs = weights / weight_sum
    nbatches = np.fromiter(numb_batches, dtype=np.float64)
if nbatches.ndim != 1:
raise ValueError("Number of batches must be 1D.")
if nbatches.size == 0:
raise ValueError("Number of batches is empty.")
if not np.all(np.isfinite(nbatches)):
raise ValueError("Number of batches must be finite.")
if np.any(nbatches < 0.0):
raise ValueError("Number of batches must be non-negative.")
if nbatches.shape[0] != probs.shape[0]:
raise ValueError("Number of batches and sampler weights must match.")
valid = probs > 0.0
if not np.any(valid):
raise ValueError(
"Sampler probabilities must contain at least one positive entry."
)
return int(np.ceil(np.max(nbatches[valid] / probs[valid])))


def resolve_model_prob(
model_keys: list[str],
model_prob_config: dict[str, float] | None,
model_training_data: dict[str, object],
rank: int = 0,
) -> np.ndarray:
"""Resolve model training probability for multi-task training.

Parameters
----------
model_keys : list[str]
List of model keys.
    model_prob_config : dict[str, float] | None
        User-specified model probabilities. If None or empty, fall back to
        the number of training systems per task.
model_training_data : dict[str, object]
Training data for each model.
rank : int, optional
Process rank for distributed training, by default 0.

Returns
-------
np.ndarray
Normalized model probabilities.

Raises
------
ValueError
If input validation fails.
"""
model_prob = np.zeros(len(model_keys), dtype=np.float64)
if model_prob_config:
missing = [k for k in model_keys if k not in model_prob_config]
if missing:
raise ValueError(
f"training.model_prob must specify all tasks; missing: {missing}"
)
        for ii, model_key in enumerate(model_keys):
            model_prob[ii] = float(model_prob_config[model_key])
else:
if rank == 0:
log.info(
"training.model_prob is not set or empty; defaulting to the "
"number of systems per task."
)
for ii, model_key in enumerate(model_keys):
model_prob[ii] = float(len(model_training_data[model_key]))
if not np.all(np.isfinite(model_prob)):
raise ValueError("Model prob must be finite.")
if np.any(model_prob < 0.0):
raise ValueError("Model prob must be non-negative.")
sum_prob = float(np.sum(model_prob))
if sum_prob <= 0.0:
raise ValueError("Sum of model prob must be larger than 0!")
return model_prob / sum_prob


def resolve_model_prob_from_epochs(
model_keys: list[str],
num_epoch_dict_config: dict[str, float],
per_task_total: np.ndarray,
) -> tuple[np.ndarray, int, dict[str, float]]:
"""Resolve model probability and training steps from epoch configuration.

Parameters
----------
model_keys : list[str]
List of model keys.
num_epoch_dict_config : dict[str, float]
Target epochs for each task.
per_task_total : np.ndarray
Total batches per task.

Returns
-------
tuple[np.ndarray, int, dict[str, float]]
Model probabilities, total training steps, and per-task steps.

Raises
------
ValueError
If input validation fails.
"""
if not num_epoch_dict_config:
raise ValueError("training.num_epoch_dict must be set for multi-task epochs.")
missing = [k for k in model_keys if k not in num_epoch_dict_config]
if missing:
raise ValueError(
f"training.num_epoch_dict must specify all tasks; missing: {missing}"
)
epoch_targets = np.zeros(len(model_keys), dtype=np.float64)
for ii, model_key in enumerate(model_keys):
epoch_value = num_epoch_dict_config[model_key]
        if epoch_value is None:
            raise ValueError(
                f"training.num_epoch_dict['{model_key}'] must be set, got None."
            )
epoch_value = float(epoch_value)
if not np.isfinite(epoch_value) or epoch_value <= 0.0:
raise ValueError(
f"training.num_epoch_dict['{model_key}'] must be positive, got {epoch_value}."
)
epoch_targets[ii] = epoch_value
per_task_total = np.asarray(per_task_total, dtype=np.float64)
if per_task_total.ndim != 1:
raise ValueError("Per-task total batches must be 1D.")
if per_task_total.shape[0] != epoch_targets.shape[0]:
raise ValueError("Per-task totals and epoch targets must match.")
if not np.all(np.isfinite(per_task_total)):
raise ValueError("Per-task total batches must be finite.")
if np.any(per_task_total <= 0.0):
raise ValueError("Per-task total batches must be positive.")
per_task_steps = per_task_total * epoch_targets
total_target_steps = float(np.sum(per_task_steps))
if total_target_steps <= 0.0:
raise ValueError("Sum of target steps must be positive.")
model_prob = per_task_steps / total_target_steps
num_steps = int(np.ceil(total_target_steps))
per_task_steps_map = {
model_key: float(per_task_steps[ii]) for ii, model_key in enumerate(model_keys)
}
return model_prob, num_steps, per_task_steps_map
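A standalone sanity check of the three helpers (a hypothetical sketch, not part of the diff; all inputs are made up):

import numpy as np

from deepmd.dpmodel.utils import (
    compute_total_numb_batch,
    resolve_model_prob,
    resolve_model_prob_from_epochs,
)

# compute_total_numb_batch: two data systems with 100 and 300 batches,
# sampled 1:3. After T draws, system i is seen T * p_i times in
# expectation, so full coverage needs T = ceil(max(n_i / p_i))
# = ceil(max(100 / 0.25, 300 / 0.75)) = 400.
assert compute_total_numb_batch([100, 300], np.array([1.0, 3.0])) == 400

# resolve_model_prob: with no user config, probabilities default to the
# number of training systems per task: 1 vs 3 systems -> [0.25, 0.75].
prob = resolve_model_prob(["a", "b"], None, {"a": [0], "b": [0, 1, 2]})
np.testing.assert_allclose(prob, [0.25, 0.75])

# resolve_model_prob_from_epochs: task "a" should see its 400 batches
# twice and task "b" its 200 batches once, giving per-task targets
# [800, 200] -> 1000 total steps sampled with probabilities [0.8, 0.2].
prob, num_steps, per_task = resolve_model_prob_from_epochs(
    ["a", "b"], {"a": 2.0, "b": 1.0}, np.array([400.0, 200.0])
)
assert num_steps == 1000
np.testing.assert_allclose(prob, [0.8, 0.2])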
103 changes: 87 additions & 16 deletions deepmd/pd/train/training.py
@@ -30,6 +30,11 @@
from deepmd.common import (
symlink_prefix_files,
)
from deepmd.dpmodel.utils import (
compute_total_numb_batch,
resolve_model_prob,
resolve_model_prob_from_epochs,
)
from deepmd.dpmodel.utils.learning_rate import (
BaseLR,
)
@@ -130,9 +135,12 @@ def __init__(
else 1
)
self.num_model = len(self.model_keys)
self.model_prob = None

# Iteration config
self.num_steps = training_params["numb_steps"]
self.num_steps = training_params.get("numb_steps")
self.num_epoch = training_params.get("num_epoch")
self.num_epoch_dict = training_params.get("num_epoch_dict")
self.acc_freq: int = training_params.get(
"acc_freq", 1
) # gradient accumulation steps
@@ -390,6 +398,75 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
),
)

per_task_total = []
if not self.multi_task:
sampler_weights = to_numpy_array(
self.training_dataloader.batch_sampler.sampler.weights
)
total_numb_batch = compute_total_numb_batch(
training_data.index,
sampler_weights,
)
if self.num_steps is None:
if self.num_epoch is None:
raise ValueError(
"Either training.numb_steps or training.num_epoch must be set."
)
if self.num_epoch <= 0:
raise ValueError("training.num_epoch must be positive.")
if total_numb_batch <= 0:
raise ValueError(
"Total number of training batches must be positive."
)
self.num_steps = int(np.ceil(self.num_epoch * total_numb_batch))
log.info(
"Computed num_steps=%d from num_epoch=%s and total_numb_batch=%d.",
self.num_steps,
self.num_epoch,
total_numb_batch,
)
else:
for model_key in self.model_keys:
sampler_weights = to_numpy_array(
self.training_dataloader[model_key].batch_sampler.sampler.weights
)
per_task_total.append(
compute_total_numb_batch(
training_data[model_key].index,
sampler_weights,
)
)
if self.num_epoch_dict:
(
self.model_prob,
self.num_steps,
per_task_steps,
) = resolve_model_prob_from_epochs(
self.model_keys,
self.num_epoch_dict,
np.asarray(per_task_total, dtype=np.float64),
)
log.info(
"Computed model_prob=%s and num_steps=%d from num_epoch_dict=%s "
"with per-task target steps: %s.",
self.model_prob,
self.num_steps,
self.num_epoch_dict,
{k: int(np.ceil(v)) for k, v in per_task_steps.items()},
)
else:
if self.num_steps is None:
                raise ValueError(
                    "For multi-task training, either training.numb_steps or "
                    "training.num_epoch_dict must be set."
                )
self.model_prob = resolve_model_prob(
self.model_keys,
training_params.get("model_prob"),
training_data,
rank=self.rank,
)

# Learning rate
self.warmup_steps = training_params.get("warmup_steps", 0)
self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
@@ -575,6 +652,15 @@ def single_model_finetune(
frz_model = paddle.jit.load(init_frz_model)
self.model.set_state_dict(frz_model.state_dict())

# Get model prob for multi-task
if self.multi_task and self.model_prob is None:
self.model_prob = resolve_model_prob(
self.model_keys,
training_params.get("model_prob"),
training_data,
rank=self.rank,
)

# Multi-task share params
if shared_links is not None:
self.wrapper.share_params(
@@ -682,21 +768,6 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
)
self.optimizer = fleet.distributed_optimizer(self.optimizer)

# Get model prob for multi-task
if self.multi_task:
self.model_prob = np.array([0.0 for key in self.model_keys])
if training_params.get("model_prob", None) is not None:
model_prob = training_params["model_prob"]
for ii, model_key in enumerate(self.model_keys):
if model_key in model_prob:
self.model_prob[ii] += float(model_prob[model_key])
else:
for ii, model_key in enumerate(self.model_keys):
self.model_prob[ii] += float(len(self.training_data[model_key]))
sum_prob = np.sum(self.model_prob)
assert sum_prob > 0.0, "Sum of model prob must be larger than 0!"
self.model_prob = self.model_prob / sum_prob

# Tensorboard
self.enable_tensorboard = training_params.get("tensorboard", False)
self.tensorboard_log_dir = training_params.get("tensorboard_log_dir", "log")
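For context, this is roughly how the resolved model_prob is consumed in the multi-task loop: one task is drawn per optimizer step, so each task receives about model_prob[i] * num_steps updates on average. A hedged sketch only; the trainer's actual step logic is outside this diff, and the model_keys and model_prob values below are placeholders:

import numpy as np

# Placeholder trainer state; names follow the diff above.
model_keys = ["model_a", "model_b"]
model_prob = np.array([0.8, 0.2])

rng = np.random.default_rng()
# Draw one task for this optimizer step, proportionally to model_prob.
task_key = model_keys[rng.choice(len(model_keys), p=model_prob)]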