Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions opentinker/backend_patch/verl/trainer/ppo/world_model_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2025 OpenTinker
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
World Model SFT Loss — trains the policy to also predict observation tokens.

Joint loss = ppo_loss(action tokens) + wm_coeff * sft_loss(observation tokens)

Implementation:
The WM SFT loss is computed inside dp_actor.update_policy() (verl modification).
It is gated by config.world_model_coeff > 0 AND observation_mask being present
in the batch.

observation_mask is computed in http_training_server.py before update_actor:
obs_mask = attention_mask[:, -resp_len:] & ~response_mask

To enable, set in your config yaml:
    actor_rollout_ref:
      actor:
        world_model_coeff: 0.1
"""

import torch


def compute_observation_mask(batch) -> torch.Tensor:
    """Compute observation_mask from attention_mask and response_mask.

    Observation tokens are the real (attended) tokens in the response portion
    that are NOT action (LLM-generated) tokens.

    Args:
        batch: DataProto-like object whose ``batch`` mapping holds
            ``"attention_mask"`` and ``"response_mask"``. The response portion
            is taken to be the trailing ``response_mask.shape[1]`` columns of
            ``attention_mask``.

    Returns:
        observation_mask: (batch_size, response_length) float tensor, 1.0 at
        observation positions and 0.0 elsewhere.
    """
    response_mask = batch.batch["response_mask"]
    attention_mask = batch.batch["attention_mask"]
    resp_len = response_mask.shape[1]

    # Slice from an explicit start index rather than ``-resp_len:``. With
    # resp_len == 0 the expression ``[:, -0:]`` is ``[:, 0:]`` and would select
    # the whole sequence instead of an empty tail. The clamp keeps the
    # "attention_mask shorter than response" case identical to the old slice.
    start = max(attention_mask.shape[1] - resp_len, 0)
    attn_response = attention_mask[:, start:]

    return (attn_response.bool() & ~response_mask.bool()).float()
2 changes: 1 addition & 1 deletion opentinker/client/client_config/alfworld_param.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ interaction:
env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port}
# If you run the ALFWorld env server in sharded mode (--shards N),
# set env_shards=N. The client will route each instance_id to a stable shard.
env_shards: 32
env_shards: 8
max_steps: 20 # ALFWorld episodes max steps
max_total_steps: 20 # Max environment step calls (controls rollout turns)
observation_template: "{observation}"
Expand Down
32 changes: 31 additions & 1 deletion opentinker/client/utils/http_training_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,32 @@ def set_config(self, args: DictConfig, env=None):
}
)

# Optional world model SFT coefficient for joint PPO + WM training.
world_model_coeff = args.get("world_model_coeff", None)
if world_model_coeff is not None:
server_cfg = OmegaConf.merge(
server_cfg,
OmegaConf.create(
{"algorithm": {"world_model_coeff": float(world_model_coeff)}}
),
)
print(
f"[ServiceClient] Forwarding algorithm.world_model_coeff={world_model_coeff}"
)

# Optional WM loss sparsification ratio
wm_loss_top_ratio = args.get("wm_loss_top_ratio", None)
if wm_loss_top_ratio is not None:
server_cfg = OmegaConf.merge(
server_cfg,
OmegaConf.create(
{"algorithm": {"wm_loss_top_ratio": float(wm_loss_top_ratio)}}
),
)
print(
f"[ServiceClient] Forwarding algorithm.wm_loss_top_ratio={wm_loss_top_ratio}"
)

# Add multi_turn config if present in args
if hasattr(args, "multi_turn") and args.multi_turn:
multi_turn_cfg = OmegaConf.to_container(args.multi_turn, resolve=True)
Expand All @@ -644,7 +670,11 @@ def set_config(self, args: DictConfig, env=None):
server_cfg = OmegaConf.merge(
server_cfg,
OmegaConf.create(
{"actor_rollout_ref": {"rollout": {"agent": {"num_workers": agent_num_workers}}}}
{
"actor_rollout_ref": {
"rollout": {"agent": {"num_workers": agent_num_workers}}
}
}
),
)
print(
Expand Down
1 change: 1 addition & 0 deletions opentinker/environment/alfworld/alfworld_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ def step(self, action: str) -> StepResult:
# Note: Don't include "step" here as gym_environment_interaction.py
# already passes it explicitly to observation_template.format()
"raw_reward": float(reward),
"raw_obs": obs, # raw env feedback (before _format_observation)
"action_taken": parsed_action,
"task": self._task_desc,
"won": won_flag,
Expand Down
33 changes: 3 additions & 30 deletions opentinker/scheduler/job_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def check_gpu_available(gpu_id: int) -> bool:
return True # Fail open

# Thresholds for considering a GPU "idle"
MAX_MEMORY_MB = 10 # Allow up to 100 MB (some baseline CUDA overhead)
MAX_MEMORY_MB = 2000 # Allow up to 2 GB (root processes may use ~1 GB)
MAX_UTILIZATION = 1000 # Allow up to 5% utilization

if memory_used_mb > MAX_MEMORY_MB or utilization_percent > MAX_UTILIZATION:
Expand All @@ -294,35 +294,8 @@ def check_gpu_available(gpu_id: int) -> bool:
)
return False

# Check 2: Look for running processes on this GPU
pmon_result = subprocess.run(
["nvidia-smi", "pmon", "-c", "1", "-s", "um"],
capture_output=True,
text=True,
timeout=5,
)

if pmon_result.returncode == 0:
# Parse pmon output to check for processes on this GPU
# Format: "# gpu pid type sm mem enc dec command"
# " 0 12345 C 50 500 0 0 python"
lines = pmon_result.stdout.strip().split("\n")
for line in lines:
if line.startswith("#") or not line.strip():
continue
parts = line.split()
if len(parts) >= 2:
try:
gpu_idx = int(parts[0].strip())
if gpu_idx == gpu_id and parts[1].strip() != "-":
# Found a process on this GPU
pid = parts[1].strip()
logger.warning(
f"GPU {gpu_id}: ⚠️ OCCUPIED - Process {pid} detected via pmon"
)
return False
except (ValueError, IndexError):
continue
# pmon check disabled — root/system processes cause false positives
# when sharing GPUs across users. Memory threshold check above is sufficient.

# All checks passed - GPU is idle
logger.debug(
Expand Down
83 changes: 82 additions & 1 deletion opentinker/server/generic_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ def __init__(
self.prompt_ids: list[int] = []
self.response_ids: list[int] = []
self.response_mask: list[int] = []
self.observation_mask: list[
int
] = [] # 1 for pure env feedback tokens, 0 otherwise
self.response_logprobs: list[float] = []

# Turn tracking
Expand Down Expand Up @@ -224,7 +227,9 @@ def init_class(cls, config, tokenizer, processor, **kwargs):
# Create per-job subdirectory to isolate traces from different client tasks
job_id = os.environ.get("ROLLOUT_TRACE_JOB_ID", None)
if job_id:
cls._trace_output_dir = str(Path(cls._trace_output_dir) / f"job_{job_id}")
cls._trace_output_dir = str(
Path(cls._trace_output_dir) / f"job_{job_id}"
)
cls._save_traces = True
cls._process_id = os.getpid() # Store process ID for unique trace naming
Path(cls._trace_output_dir).mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -465,6 +470,9 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu
# Ensure env_info exists for all samples (even if empty) for consistent DataProto.concat
output.extra_fields["env_info"] = agent_data.extra_fields.get("env_info", [])
output.extra_fields["turn_scores"] = agent_data.turn_scores
output.extra_fields["observation_mask"] = agent_data.observation_mask[
: self.response_length
]
# Add any other extra fields (except the ones we already set)
for key, value in agent_data.extra_fields.items():
if key not in output.extra_fields:
Expand Down Expand Up @@ -581,6 +589,9 @@ async def _handle_generating_state(
agent_data.response_mask += [1] * len(
agent_data.response_ids
) # mask=1 for LLM tokens
agent_data.observation_mask += [0] * len(
agent_data.response_ids
) # observation_mask=0 for action tokens

if response_log_probs:
agent_data.response_logprobs += response_log_probs
Expand Down Expand Up @@ -684,6 +695,14 @@ async def _handle_interacting_state(
agent_data.prompt_ids += response_ids
agent_data.response_mask += [0] * len(response_ids)

# Build observation_mask: 1 only for pure env feedback tokens
# (excludes chat template, "=== Current State ===", "=== Available Actions ===" etc.)
raw_obs = info.get("raw_obs", None) if info else None
obs_mask = self._build_env_feedback_mask(
response_ids, observation, raw_obs, self.tokenizer
)
agent_data.observation_mask += obs_mask

if agent_data.response_logprobs:
# Pad logprobs with 0.0 for observation tokens
agent_data.response_logprobs += [0.0] * len(response_ids)
Expand All @@ -693,6 +712,68 @@ async def _handle_interacting_state(
else:
return GenericAgentState.GENERATING

@staticmethod
def _build_env_feedback_mask(
response_ids: list[int],
observation: str,
raw_obs: str | None,
tokenizer,
) -> list[int]:
"""Build a per-token mask marking only pure environment feedback tokens.

Uses character-offset mapping to avoid BPE boundary mismatch:
1. Tokenize the full observation string with offset_mapping
2. Find raw_obs character range within observation
3. Map character range β†’ token indices in observation encoding
4. Find observation tokens as contiguous subsequence in response_ids
(response_ids = chat_template_prefix + obs_tokens + chat_template_suffix)
5. Transfer the per-token mask to response_ids positions

Falls back to all-1 mask if any step fails.
"""
n = len(response_ids)
if not raw_obs or not observation:
return [1] * n

# Step 1-2: find raw_obs character range in the formatted observation
char_start = observation.find(raw_obs)
if char_start < 0:
return [1] * n
char_end = char_start + len(raw_obs)

# Step 3: tokenize observation with offset mapping
try:
enc = tokenizer(
observation, add_special_tokens=False, return_offsets_mapping=True
)
obs_ids = enc["input_ids"]
offsets = enc["offset_mapping"]
except Exception:
# Tokenizer doesn't support offset_mapping; fall back
return [1] * n

# Mark which observation-level tokens overlap with raw_obs char range
obs_level_mask = []
for s, e in offsets:
if e > char_start and s < char_end:
obs_level_mask.append(1)
else:
obs_level_mask.append(0)

# Step 4: find obs_ids as contiguous subsequence in response_ids
# (chat template special tokens don't merge with content tokens)
m = len(obs_ids)
for i in range(n - m + 1):
if response_ids[i : i + m] == obs_ids:
# Step 5: transfer mask
mask = [0] * n
for j in range(m):
mask[i + j] = obs_level_mask[j]
return mask

# obs_ids not found in response_ids β€” fall back
return [1] * n

async def _save_debug_images(self, image_data: list, request_id: str):
"""Save debug images to disk when SAVE_DEBUG_IMAGES env var is set.

Expand Down
Loading
Loading