Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions opentinker/backend_patch/verl/trainer/ppo/world_model_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2025 OpenTinker
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
World Model SFT Loss β€” trains the policy to also predict observation tokens.

Joint loss = ppo_loss(action tokens) + wm_coeff * sft_loss(observation tokens)

Implementation:
The WM SFT loss is computed inside dp_actor.update_policy() (verl modification).
It is gated by config.world_model_coeff > 0 AND observation_mask being present
in the batch.

observation_mask is computed in http_training_server.py before update_actor:
obs_mask = attention_mask[:, -resp_len:] & ~response_mask

To enable, set in your config yaml:
actor_rollout_ref:
actor:
world_model_coeff: 0.1
"""

import torch


def compute_observation_mask(batch) -> torch.Tensor:
"""Compute observation_mask from attention_mask and response_mask.

observation tokens = real tokens in the response portion that are NOT
action (LLM-generated) tokens.

Args:
batch: DataProto with batch["attention_mask"] and batch["response_mask"]

Returns:
observation_mask: (batch_size, response_length) float tensor
"""
resp_len = batch.batch["response_mask"].shape[1]
attn_response = batch.batch["attention_mask"][:, -resp_len:]
return (attn_response.bool() & ~batch.batch["response_mask"].bool()).float()
2 changes: 1 addition & 1 deletion opentinker/client/client_config/alfworld_param.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ interaction:
env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port}
# If you run the ALFWorld env server in sharded mode (--shards N),
# set env_shards=N. The client will route each instance_id to a stable shard.
env_shards: 32
env_shards: 8
max_steps: 20 # ALFWorld episodes max steps
max_total_steps: 20 # Max environment step calls (controls rollout turns)
observation_template: "{observation}"
Expand Down
13 changes: 13 additions & 0 deletions opentinker/client/utils/http_training_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,19 @@ def set_config(self, args: DictConfig, env=None):
}
)

# Optional world model SFT coefficient for joint PPO + WM training.
world_model_coeff = args.get("world_model_coeff", None)
if world_model_coeff is not None:
server_cfg = OmegaConf.merge(
server_cfg,
OmegaConf.create(
{"algorithm": {"world_model_coeff": float(world_model_coeff)}}
),
)
print(
f"[ServiceClient] Forwarding algorithm.world_model_coeff={world_model_coeff}"
)

# Add multi_turn config if present in args
if hasattr(args, "multi_turn") and args.multi_turn:
multi_turn_cfg = OmegaConf.to_container(args.multi_turn, resolve=True)
Expand Down
33 changes: 3 additions & 30 deletions opentinker/scheduler/job_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def check_gpu_available(gpu_id: int) -> bool:
return True # Fail open

# Thresholds for considering a GPU "idle"
MAX_MEMORY_MB = 10 # Allow up to 100 MB (some baseline CUDA overhead)
MAX_MEMORY_MB = 2000 # Allow up to 2 GB (root processes may use ~1 GB)
MAX_UTILIZATION = 1000 # Allow up to 5% utilization

if memory_used_mb > MAX_MEMORY_MB or utilization_percent > MAX_UTILIZATION:
Expand All @@ -294,35 +294,8 @@ def check_gpu_available(gpu_id: int) -> bool:
)
return False

# Check 2: Look for running processes on this GPU
pmon_result = subprocess.run(
["nvidia-smi", "pmon", "-c", "1", "-s", "um"],
capture_output=True,
text=True,
timeout=5,
)

if pmon_result.returncode == 0:
# Parse pmon output to check for processes on this GPU
# Format: "# gpu pid type sm mem enc dec command"
# " 0 12345 C 50 500 0 0 python"
lines = pmon_result.stdout.strip().split("\n")
for line in lines:
if line.startswith("#") or not line.strip():
continue
parts = line.split()
if len(parts) >= 2:
try:
gpu_idx = int(parts[0].strip())
if gpu_idx == gpu_id and parts[1].strip() != "-":
# Found a process on this GPU
pid = parts[1].strip()
logger.warning(
f"GPU {gpu_id}: ⚠️ OCCUPIED - Process {pid} detected via pmon"
)
return False
except (ValueError, IndexError):
continue
# pmon check disabled β€” root/system processes cause false positives
# when sharing GPUs across users. Memory threshold check above is sufficient.

# All checks passed - GPU is idle
logger.debug(
Expand Down
42 changes: 41 additions & 1 deletion opentinker/server/http_training_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
)



# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -639,6 +640,7 @@ def __init__(
# Server state
self.is_initialized = False
self.global_steps = 0
self.wm_coeff = 0.0

# Generation config (can be overridden by client)
self.generation_config = {
Expand Down Expand Up @@ -677,6 +679,15 @@ def init_workers(self, total_steps: int) -> Dict[str, Any]:
try:
# optimizer needs parameter: total_steps
self.trainer.post_init(total_steps)

# Forward algorithm.world_model_coeff β†’ actor config so dp_actor can read it
algo_wm_coeff = self.config.algorithm.get("world_model_coeff", 0.0)
if algo_wm_coeff > 0:
from omegaconf import open_dict
with open_dict(self.config):
self.config.actor_rollout_ref.actor.world_model_coeff = algo_wm_coeff
logger.info(f"Forwarded world_model_coeff={algo_wm_coeff} to actor config")

logger.info("Initializing workers...")

# Check async rollout mode
Expand Down Expand Up @@ -707,6 +718,11 @@ def init_workers(self, total_steps: int) -> Dict[str, Any]:
if self.async_rollout_mode:
self.async_rollout_manager = self.trainer.async_rollout_manager

# World model SFT loss coeff (actual loss computed in dp_actor via config.world_model_coeff)
self.wm_coeff = self.config.actor_rollout_ref.actor.get("world_model_coeff", 0.0)
if self.wm_coeff > 0:
logger.info(f"World model SFT loss enabled with coeff={self.wm_coeff}")

self.is_initialized = True
logger.info("Workers initialized successfully")
return {"status": "success", "message": "Workers initialized"}
Expand Down Expand Up @@ -1170,6 +1186,14 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]:
)
metrics.update(critic_output_metrics)

# 10.5 Compute observation_mask for world model loss
if self.wm_coeff > 0:
resp_len = batch.batch["response_mask"].shape[1]
attn_response = batch.batch["attention_mask"][:, -resp_len:]
batch.batch["observation_mask"] = (
attn_response.bool() & ~batch.batch["response_mask"].bool()
).float()

# 11. Update actor (check critic warmup)
if self.config.trainer.critic_warmup <= self.global_steps:
with marked_timer("update_actor", timing_raw, color="red"):
Expand Down Expand Up @@ -2104,8 +2128,16 @@ def run_fastapi_server():
)
ray.init(
namespace=_server_cfg.ray.namespace,
num_gpus=_server_cfg.trainer.n_gpus_per_node, # Explicitly specify number of GPUs
num_gpus=_server_cfg.trainer.n_gpus_per_node,
ignore_reinit_error=True,
runtime_env={
"env_vars": {
"NCCL_CUMEM_ENABLE": "0",
"VLLM_DISABLE_SLEEP_MODE": "1",
"RAY_memory_usage_threshold": "0.99",
"VLLM_GPU_MEMORY_UTILIZATION": "0.15",
},
},
)
else:
# Connect to existing Ray cluster at specific address
Expand All @@ -2114,6 +2146,14 @@ def run_fastapi_server():
address=_server_cfg.ray.address,
namespace=_server_cfg.ray.namespace,
ignore_reinit_error=True,
runtime_env={
"env_vars": {
"NCCL_CUMEM_ENABLE": "0",
"VLLM_DISABLE_SLEEP_MODE": "1",
"RAY_memory_usage_threshold": "0.99",
"VLLM_GPU_MEMORY_UTILIZATION": "0.15",
},
},
)

# Verify GPU availability
Expand Down
106 changes: 106 additions & 0 deletions run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#!/usr/bin/env bash
# Usage: ./run.sh [config] [gpus] [scheduler_port] [env_port] [steps] [model] [mode] [wm_coeff]
# Example:
# ./run.sh # all defaults
# ./run.sh alfworld_param 4,5 8782 8110 150
# ./run.sh alfworld_param 4,5 8782 8110 150 Qwen/Qwen2.5-3B-Instruct grpo_wm 0.1
# Modes:
# grpo : standard GRPO
# grpo_wm : GRPO + world model SFT loss (adds +world_model_coeff=wm_coeff)
set -euo pipefail

CONFIG="${1:-alfworld_param}"
RAW_GPUS="${2:-4,6}"
SCHEDULER_PORT="${3:-8782}"
ENV_PORT="${4:-8120}"
STEPS="${5:-600}"
MODEL="${6:-Qwen/Qwen2.5-3B-Instruct}"
MODE="${7:-grpo}"
WM_COEFF="${8:-0.1}"

# Normalize GPU list so Hydra always receives a clean list override.
# Accepted input forms: "4,6" or "[4,6]".
GPUS="${RAW_GPUS// /}"
GPUS="${GPUS#[}"
GPUS="${GPUS%]}"
if [[ -z "$GPUS" || ! "$GPUS" =~ ^[0-9]+(,[0-9]+)*$ ]]; then
echo "Invalid GPU list: '$RAW_GPUS'"
echo "Expected format: 4,6 (or [4,6])"
exit 1
fi
NUM_GPUS=$(awk -F',' '{print NF}' <<< "$GPUS")
SCHEDULER_GPU_OVERRIDE="available_gpus=[${GPUS}]"

# conda.sh may reference PS1; in non-interactive shells PS1 can be unset.
set +u
source ~/anaconda3/etc/profile.d/conda.sh
conda activate opentinker
set -u
cd "$(dirname "$0")"

EXTRA_HYDRA_ARGS=()
MODE_TAG="grpo"
case "$MODE" in
grpo)
;;
grpo_wm|grpo+wm|wm|wm_sft)
EXTRA_HYDRA_ARGS+=("+world_model_coeff=${WM_COEFF}")
MODE_TAG="grpo_wm_${WM_COEFF}"
;;
*)
echo "Unsupported mode: $MODE"
echo "Supported modes: grpo | grpo_wm"
exit 1
;;
esac

# vLLM / NCCL fixes (cumem allocator crash)
export VLLM_DISABLE_SLEEP_MODE=1
export NCCL_CUMEM_ENABLE=0
export VLLM_GPU_MEMORY_UTILIZATION=0.25

# Step 1: Scheduler
echo "=== Step 1: Scheduler (GPUs=[$GPUS], port=$SCHEDULER_PORT) ==="
if command -v lsof >/dev/null 2>&1 && lsof -iTCP:"${SCHEDULER_PORT}" -sTCP:LISTEN -t >/dev/null 2>&1; then
echo "Scheduler port ${SCHEDULER_PORT} is already in use."
echo "Stop old scheduler first, or choose another port."
echo "Hint: lsof -iTCP:${SCHEDULER_PORT} -sTCP:LISTEN -P -n"
exit 1
fi
echo "Scheduler override: ${SCHEDULER_GPU_OVERRIDE}"
ROLLOUT_TRACE_DIR=./traces TORCH_CUDA_ARCH_LIST="9.0" FLASHINFER_HOMOGENEOUS_MS=1 \
nohup python opentinker/scheduler/launch_scheduler_kill.py \
"${SCHEDULER_GPU_OVERRIDE}" gpus_per_job="${NUM_GPUS}" \
port_range=null num_ports=200 scheduler_port="${SCHEDULER_PORT}" \
> /tmp/scheduler_${SCHEDULER_PORT}.log 2>&1 &
echo "PID: $!"
sleep 12
curl -sf http://0.0.0.0:${SCHEDULER_PORT}/ > /dev/null && echo "OK" || { echo "FAIL"; exit 1; }

# Step 2: ALFWorld server
echo "=== Step 2: ALFWorld server (port=$ENV_PORT, shards=8) ==="
# Kill any stale shard processes on our port range
for ((p=ENV_PORT; p<ENV_PORT+8; p++)); do
fuser -k "${p}/tcp" 2>/dev/null || true
done
sleep 1
nohup python opentinker/environment/alfworld/alfworld_server.py \
--port "${ENV_PORT}" --shards 8 \
> /tmp/alfworld_server_${ENV_PORT}.log 2>&1 &
echo "PID: $!"
sleep 30
curl -sf http://0.0.0.0:${ENV_PORT}/health > /dev/null && echo "OK" || { echo "FAIL"; exit 1; }

# Step 3: RL training
LOG="/tmp/${CONFIG}_${MODE_TAG}_p${SCHEDULER_PORT}.log"
echo "=== Step 3: Training (config=$CONFIG, mode=$MODE, gpus=$NUM_GPUS, steps=$STEPS) ==="
nohup python opentinker/client/alfworld_rl.py \
--config-name "${CONFIG}" \
tokenizer_path="${MODEL}" \
num_gpus="${NUM_GPUS}" num_steps="${STEPS}" \
scheduler_url="http://0.0.0.0:${SCHEDULER_PORT}" \
interaction.config.env_port="${ENV_PORT}" \
"${EXTRA_HYDRA_ARGS[@]}" \
> "$LOG" 2>&1 &
echo "PID: $! | Log: $LOG"
echo "=== Done. tail -f $LOG ==="
7 changes: 7 additions & 0 deletions run_grpo.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/usr/bin/env bash
set -euo pipefail
cd "$(dirname "$0")"

# GRPO baseline
# Usage: ./run_grpo.sh [gpus] [scheduler_port] [env_port] [steps]
exec ./run.sh alfworld_param "${1:-4,6}" "${2:-8782}" "${3:-8120}" "${4:-1000}" Qwen/Qwen2.5-3B-Instruct grpo
4 changes: 4 additions & 0 deletions run_grpo_wm.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/usr/bin/env bash
# GRPO + World Model SFT loss
# Usage: ./run_grpo_wm.sh [gpus] [scheduler_port] [env_port] [steps] [wm_coeff]
exec ./run.sh alfworld_param "${1:-2,7}" "${2:-8782}" "${3:-8120}" "${4:-1000}" Qwen/Qwen2.5-3B-Instruct grpo_wm "${5:-0.1}"
Loading