Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 152 additions & 70 deletions opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def compute_s_star(
if count > 0:
p_k = torch.exp(log_p)
s_token = p_k * (H + log_p)
# Handle potential NaN due to 0 * -inf when p_k is 0 and log_p is -inf
s_token = torch.nan_to_num(s_token, nan=0.0)
s_t = (s_token * mask).sum() / count
else:
s_t = torch.tensor(0.0, device=device)
Expand Down Expand Up @@ -155,60 +157,67 @@ def compute_h_wm(
def compute_dynamic_mask(
s_star_per_sample: List[List[torch.Tensor]],
h_wm_per_sample: List[List[torch.Tensor]],
mu_base: float = 1.0,
lambda_wm: float = 1.0,
mu_base: float,
mu_exp: float,
eta_wm: float,
lambda_wm: float,
s_bar: float,
sigma: float,
clipping_method: str = "mask",
) -> List[List[float]]:
"""Compute per-turn dynamic entropy clipping mask.

m_t = 1 if |S_*^t - S_bar| <= mu_base * (1 + lambda * H_WM^t) * sigma
0 otherwise
"""Compute per-turn dynamic entropy clipping mask or coefficient.

Logic:
- WM confident (H_WM→0): threshold tightens → overconfident policy blocked
(prevents overfitting in well-understood regions)
- WM uncertain (H_WM large): threshold widens → exploration encouraged
(allows agent to gather data in unknown regions to train the WM)

H_WM is detached — it only acts as a scalar gate, never participates in
policy gradient backpropagation.
- WM uncertainty signal: f(H_WM) = eta_wm * exp(-lambda_wm * H_WM)
- Threshold: threshold = mu * f(H_WM) * sigma
- Masking: m_t = 0.0 if violation, 1.0 otherwise
- Clipping: collapsing side m_t = 1 / (1 + 0.5 * min(1, deviation / threshold)), i.e. in [2/3, 1]; exploration side m_t = 1.0 (never clipped)

Args:
s_star_per_sample: per-sample, per-turn S_* tensors
h_wm_per_sample: per-sample, per-turn H_WM tensors
mu_base: base clipping coefficient
lambda_wm: WM uncertainty weight
mu_base: clipping coefficient for collapsing side
mu_exp: clipping coefficient for exploration side
eta_wm: base multiplier for WM uncertainty signal
lambda_wm: exponential decay factor for WM uncertainty
s_bar: mean of S_* (batch or global)
sigma: std of S_* (batch or global)
clipping_method: "mask" or "clip"

Returns:
List of lists of floats (0.0 or 1.0), one mask per turn per sample.
List of lists of floats, one mask/coeff per turn per sample.
"""
# Flatten all S_* for batch statistics
all_s = []
for turns in s_star_per_sample:
for s in turns:
all_s.append(s.detach())

if len(all_s) == 0:
return [[] for _ in s_star_per_sample]

all_s_tensor = torch.stack(all_s)
s_bar = all_s_tensor.mean()

# Guard for single-element: std is 0, threshold = mu_base * (1 + lambda * h_wm) * 0
# → everything would be masked. Use 1.0 as default sigma for single element.
if len(all_s) <= 1:
sigma = torch.tensor(1.0, device=all_s_tensor.device)
else:
sigma = all_s_tensor.std(unbiased=False) + 1e-8

mask_per_sample = []
for i in range(len(s_star_per_sample)):
masks = []
for t in range(len(s_star_per_sample[i])):
s_t = s_star_per_sample[i][t].detach()
h_t = h_wm_per_sample[i][t].detach()

threshold = mu_base * (1.0 + lambda_wm * h_t) * sigma
m_t = 1.0 if torch.abs(s_t - s_bar) <= threshold else 0.0
s_t = s_star_per_sample[i][t].detach().item()
h_t = h_wm_per_sample[i][t].detach().item()

# WM uncertainty signal: f(H_WM) = eta_wm * exp(-lambda_wm * H_WM)
h_factor = eta_wm * np.exp(-lambda_wm * h_t)

# Asymmetric threshold calculation
if s_t > s_bar:
# Collapsing side
threshold = mu_base * h_factor * sigma
diff = s_t - s_bar
if clipping_method == "mask":
m_t = 1.0 if diff <= threshold else 0.0
else: # PPO-style clipping
# Soft scaling: u_t = diff / threshold, saturating at 1 once diff reaches the threshold,
# so the advantage coefficient m_t = 1 / (1 + 0.5 * u_t) lies between 1.0 and 2/3
u_t = min(1.0, diff / (threshold + 1e-8))
m_t = 1.0 / (1.0 + 0.5 * u_t)
else:
# Exploration side
threshold = mu_exp * h_factor * sigma
diff = s_bar - s_t
if clipping_method == "mask":
m_t = 1.0 if diff <= threshold else 0.0
else: # PPO-style clipping
m_t = 1.0

masks.append(m_t)
mask_per_sample.append(masks)

Expand All @@ -219,77 +228,150 @@ def apply_wmc_erc(
batch,
entropys: torch.Tensor,
wmc_erc_config,
running_stats: Dict[str, float],
) -> Tuple[object, Dict[str, float]]:
"""Apply WMC-ERC dynamic entropy clipping to batch advantages.

Pipeline:
1. Compute turn boundaries from response_mask
2. Compute S_* (policy blind confidence) per turn
3. Compute H_WM (world model uncertainty) per turn from env token entropys
4. Compute dynamic mask m_t per turn
5. Apply mask to advantages: A_masked = A * m_t (broadcast to tokens)
6. Return metrics for logging

Args:
batch: DataProto or compatible object with batch dict containing
advantages, response_mask, old_log_probs, attention_mask
entropys: (batch_size, response_length) stored before pop in train_step
wmc_erc_config: OmegaConf DictConfig or dict with mu_base, lambda_wm, enable
batch: DataProto or compatible object
entropys: (batch_size, response_length)
wmc_erc_config: OmegaConf DictConfig or dict
running_stats: Dictionary for global running statistics

Returns:
(batch, metrics) where batch has masked advantages and metrics dict
(batch, metrics)
"""
enable = wmc_erc_config.get("enable", True) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'enable', True)
if not enable:
return batch, {}

clipping_type = wmc_erc_config.get("clipping_type", "batch") if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'clipping_type', "batch")
clipping_method = wmc_erc_config.get("clipping_method", "mask") if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'clipping_method', "mask")
clip_positive_only = wmc_erc_config.get("clip_positive_only", False) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'clip_positive_only', False)
inverse_sft_mask = wmc_erc_config.get("inverse_sft_mask", False) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'inverse_sft_mask', False)

response_mask = batch.batch["response_mask"]
old_log_probs = batch.batch["old_log_probs"]
advantages = batch.batch["advantages"]

# Compute attention mask for response region
response_length = advantages.shape[1]
attention_mask = batch.batch["attention_mask"]
attention_mask_response = attention_mask[:, -response_length:]

# 1. Turn boundaries
turn_boundaries = compute_turn_boundaries(response_mask)

# 2. S_* per turn
# 2. Compute S_* and H_WM per turn
s_star = compute_s_star(old_log_probs, entropys, response_mask, turn_boundaries)

# 3. H_WM per turn
h_wm = compute_h_wm(entropys, response_mask, attention_mask_response, turn_boundaries)

# 4. Dynamic mask
# Calculate batch statistics
all_s = [s.item() for turns in s_star for s in turns]
all_h = [h.item() for turns in h_wm for h in turns]

if not all_s:
return batch, {}

batch_s_bar = np.mean(all_s)
batch_s_std = np.std(all_s) + 1e-8
batch_h_bar = np.mean(all_h) + 1e-8

# Update global statistics
momentum = wmc_erc_config.get("momentum", 0.9) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'momentum', 0.9)
if len(running_stats.keys()) == 0:
running_stats["s_bar"] = batch_s_bar
running_stats["s_std"] = batch_s_std
running_stats["h_bar"] = batch_h_bar
else:
running_stats["s_bar"] = (1 - momentum) * batch_s_bar + momentum * running_stats["s_bar"]
running_stats["s_std"] = (1 - momentum) * batch_s_std + momentum * running_stats["s_std"]
running_stats["h_bar"] = (1 - momentum) * batch_h_bar + momentum * running_stats["h_bar"]

# Select statistics for masking
if clipping_type == "global":
use_s_bar = running_stats["s_bar"]
use_s_std = running_stats["s_std"]
use_h_bar = running_stats["h_bar"]
else:
use_s_bar = batch_s_bar
use_s_std = batch_s_std
use_h_bar = batch_h_bar

# 4. Dynamic mask/clip
mu_base = float(wmc_erc_config.get("mu_base", 1.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'mu_base', 1.0))
mu_exp = float(wmc_erc_config.get("mu_exp", 2.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'mu_exp', 2.0))
eta_wm = float(wmc_erc_config.get("eta_wm", 1.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'eta_wm', 1.0))
lambda_wm = float(wmc_erc_config.get("lambda_wm", 1.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'lambda_wm', 1.0))
mask = compute_dynamic_mask(s_star, h_wm, mu_base, lambda_wm)

# 5. Apply mask to advantages (in-place)

mask = compute_dynamic_mask(
s_star, h_wm, mu_base, mu_exp, eta_wm, lambda_wm,
s_bar=use_s_bar,
sigma=use_s_std,
clipping_method=clipping_method
)

# 5. Apply mask/coeff to advantages
batch_size = advantages.shape[0]

if inverse_sft_mask:
sft_weights = torch.zeros_like(advantages)
env_mask = attention_mask_response * (1.0 - response_mask)

for i in range(batch_size):
for t, (start, end) in enumerate(turn_boundaries[i]):
if t < len(mask[i]):
advantages[i, start:end] *= mask[i][t]
m_t = mask[i][t]

if inverse_sft_mask:
# Env tokens after this turn: [end, next_turn_start) or [end, seq_len)
if t + 1 < len(turn_boundaries[i]):
env_end = turn_boundaries[i][t + 1][0]
else:
env_end = response_length

sft_weight = 1.0 - m_t
region_mask = env_mask[i, end:env_end]
sft_weights[i, end:env_end] = region_mask * sft_weight

if m_t < 1.0:
if clip_positive_only:
# Only apply scaling where advantages > 0
turn_adv = advantages[i, start:end]
advantages[i, start:end] = torch.where(turn_adv > 0, turn_adv * m_t, turn_adv)
else:
advantages[i, start:end] *= m_t
batch.batch["advantages"] = advantages

if inverse_sft_mask:
batch.batch["sft_weights"] = sft_weights

# 6. Metrics
all_s = [s.item() for turns in s_star for s in turns]
all_h = [h.item() for turns in h_wm for h in turns]
all_m = [m for turns in mask for m in turns]

num_collapsing_violated = 0
num_exploration_violated = 0
for i in range(len(s_star)):
for t in range(len(s_star[i])):
if mask[i][t] < 1.0:
if s_star[i][t].item() > use_s_bar:
num_collapsing_violated += 1
else:
num_exploration_violated += 1

# WM NLL (monitoring only — not in backward pass for this prototype)
env_mask = attention_mask_response * (1.0 - response_mask)
env_count = env_mask.sum()
wm_nll = (-(old_log_probs * env_mask).sum() / (env_count + 1e-8)).item() if env_count > 0 else 0.0

metrics = {
"wmc_erc/s_star_mean": float(np.mean(all_s)) if all_s else 0.0,
"wmc_erc/s_star_std": float(np.std(all_s)) if all_s else 0.0,
"wmc_erc/h_wm_mean": float(np.mean(all_h)) if all_h else 0.0,
"wmc_erc/batch_s_bar": float(batch_s_bar),
"wmc_erc/batch_s_std": float(batch_s_std),
"wmc_erc/batch_h_bar": float(batch_h_bar),
"wmc_erc/running_s_bar": float(running_stats["s_bar"]),
"wmc_erc/running_s_std": float(running_stats["s_std"]),
"wmc_erc/running_h_bar": float(running_stats["h_bar"]),
"wmc_erc/mask_ratio": float(np.mean(all_m)) if all_m else 1.0,
"wmc_erc/num_masked_turns": sum(1 for m in all_m if m == 0.0),
"wmc_erc/num_violated_turns": sum(1 for m in all_m if m < 1.0),
"wmc_erc/num_collapsing_violated": num_collapsing_violated,
"wmc_erc/num_exploration_violated": num_exploration_violated,
"wmc_erc/total_turns": len(all_m),
"wmc_erc/wm_nll": wm_nll,
}
Expand Down
50 changes: 50 additions & 0 deletions opentinker/backend_patch/verl/trainer/ppo/world_model_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2025 OpenTinker
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
World Model SFT Loss — trains the policy to also predict observation tokens.

Joint loss = ppo_loss(action tokens) + wm_coeff * sft_loss(observation tokens)

Implementation:
The WM SFT loss is computed inside dp_actor.update_policy() (verl modification).
It is gated by config.world_model_coeff > 0 AND observation_mask being present
in the batch.

observation_mask is computed in http_training_server.py before update_actor:
obs_mask = attention_mask[:, -resp_len:] & ~response_mask

To enable, set in your config yaml:
actor_rollout_ref:
actor:
world_model_coeff: 0.1
"""

import torch


def compute_observation_mask(batch) -> torch.Tensor:
    """Compute observation_mask from attention_mask and response_mask.

    Observation tokens are the real (attended) tokens in the response
    portion that are NOT action (LLM-generated) tokens, i.e.
    ``attention_mask[response window] & ~response_mask``.

    Args:
        batch: DataProto-like object whose ``batch`` mapping holds
            ``"attention_mask"`` and ``"response_mask"``. The response
            window is taken to be the last ``response_mask.shape[1]``
            columns of ``attention_mask``.

    Returns:
        observation_mask: float tensor with the same shape as
        ``response_mask`` (1.0 on observation tokens, 0.0 elsewhere).
    """
    response_mask = batch.batch["response_mask"]
    attention_mask = batch.batch["attention_mask"]
    resp_len = response_mask.shape[1]

    # Slice from an explicit start index instead of ``[:, -resp_len:]``:
    # with resp_len == 0 the negative form becomes ``[:, -0:]`` and returns
    # the whole attention mask rather than an empty response window.
    start = max(attention_mask.shape[1] - resp_len, 0)
    attn_response = attention_mask[:, start:]

    return (attn_response.bool() & ~response_mask.bool()).float()
Loading