Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions dev/run_yes_no_maybe_kl_advantage_tinker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Launch yes-no-maybe-kl-advantage-tinker training on SkyPilot (Kubernetes).

Usage:
uv run dev/run_yes_no_maybe_kl_advantage_tinker.py
uv run dev/run_yes_no_maybe_kl_advantage_tinker.py --fast
uv run dev/run_yes_no_maybe_kl_advantage_tinker.py --base-model Qwen/Qwen2.5-7B-Instruct
"""

import argparse
import os
import textwrap

from dotenv import load_dotenv
import sky
from sky import ClusterStatus

load_dotenv()

parser = argparse.ArgumentParser(
description="Launch yes-no-maybe KL advantage training (Tinker) on SkyPilot."
)
parser.add_argument(
"--fast", action="store_true", help="Skip setup (for re-runs on existing cluster)."
)
parser.add_argument(
"--base-model", type=str, default="meta-llama/Llama-3.1-8B-Instruct"
)
parser.add_argument("--num-steps", type=int, default=20)
parser.add_argument("--kl-penalty-coef", type=float, default=0.1)
parser.add_argument("--accelerator", type=str, default="H200:1")
parser.add_argument("--cluster-name", type=str, default=None)
parser.add_argument(
"--kl-ref-step",
type=int,
default=None,
help="Checkpoint step of training model to use as KL reference",
)
args = parser.parse_args()

cluster_name = args.cluster_name or f"ynm-tinker-kl-{args.kl_penalty_coef}"
cluster_prefix = os.environ.get("CLUSTER_PREFIX")
if cluster_prefix:
cluster_name = f"{cluster_prefix}-{cluster_name}"

setup_script = textwrap.dedent("""\
echo 'Setting up environment...'
apt install -y nvtop
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
""")

kl_ref_env = ""
if args.kl_ref_step is not None:
kl_ref_env = f"KL_REF_STEP={args.kl_ref_step} "

run_script = textwrap.dedent(f"""\
source $HOME/.local/bin/env
cd ~/sky_workdir
{kl_ref_env}BASE_MODEL={args.base_model} NUM_STEPS={args.num_steps} KL_PENALTY_COEF={args.kl_penalty_coef} uv run --python 3.11 --extra tinker dev/yes-no-maybe-kl-advantage-tinker.py
""")

task = sky.Task(
name="yes-no-maybe-kl-advantage-tinker",
setup=setup_script,
run=run_script,
workdir=".",
)
task.set_resources(
sky.Resources(accelerators=args.accelerator, cloud=sky.clouds.Kubernetes())
)
task.set_file_mounts(
{
"~/sky_workdir/.env": ".env",
}
)

print(f"Launching on cluster: {cluster_name}")
print(f" base_model: {args.base_model}")
print(f" accelerator: {args.accelerator}")
print(f" num_steps: {args.num_steps}")
print(f" kl_penalty_coef: {args.kl_penalty_coef}")
if args.kl_ref_step is not None:
print(f" kl_ref_step: {args.kl_ref_step}")

# Cancel any existing jobs on this cluster
cluster_status = sky.stream_and_get(sky.status(cluster_names=[cluster_name]))
if len(cluster_status) > 0 and cluster_status[0]["status"] == ClusterStatus.UP:
print(f"Cluster {cluster_name} is UP. Canceling any active jobs...")
sky.stream_and_get(sky.cancel(cluster_name, all=True))

job_id, _ = sky.stream_and_get(
sky.launch(
task,
cluster_name=cluster_name,
retry_until_up=True,
idle_minutes_to_autostop=60,
down=True,
fast=args.fast,
)
)

print(f"Job submitted (ID: {job_id}). Streaming logs...")
exit_code = sky.tail_logs(cluster_name=cluster_name, job_id=job_id, follow=True)
print(f"Job {job_id} finished with exit code {exit_code}.")
111 changes: 111 additions & 0 deletions dev/yes-no-maybe-kl-advantage-tinker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Yes-no-maybe training with KL-penalized advantage adjustment (Tinker backend).

Demonstrates the kl_penalty_coef feature: tokens where the policy has drifted
more from the reference model get reduced advantages, while tokens that have
drifted less get increased advantages.

Uses meta-llama/Llama-3.1-8B-Instruct as the base model (trained via Tinker).
"""

import asyncio
from itertools import permutations
import os
import random
import string

from dotenv import load_dotenv
import openai

import art
from art.tinker_native import TinkerNativeBackend


async def rollout(
client: openai.AsyncOpenAI, model: art.TrainableModel, prompt: str
) -> art.Trajectory:
messages: art.Messages = [
{
"role": "user",
"content": prompt,
}
]
chat_completion = await client.chat.completions.create(
messages=messages, model=model.get_inference_name(), max_tokens=100, timeout=100
)
choice = chat_completion.choices[0]
content = choice.message.content
assert isinstance(content, str)
if content == "yes":
reward = 0.5
elif content == "no":
reward = 0.75
elif content == "maybe":
reward = 1.0
else:
reward = 0.0
return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward)


def with_quotes(w: str) -> str:
return f"'{w}'"


async def main():
load_dotenv()

backend = TinkerNativeBackend()
base_model = os.environ.get("BASE_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
kl_penalty_coef = float(os.environ.get("KL_PENALTY_COEF", "0.1"))
random_suffix = "".join(random.choices(string.ascii_lowercase, k=4))
model = art.TrainableModel(
name=os.environ.get("MODEL_NAME", f"tinker-{random_suffix}-{kl_penalty_coef}"),
project="yes-no-maybe",
base_model=base_model,
)
await model.register(backend)

kl_penalty_reference_step: int | None = (
int(os.environ["KL_REF_STEP"])
if os.environ.get("KL_REF_STEP") is not None
else None
)

prompts = [
f"{prefix} with {', '.join([with_quotes(w) if use_quotes else w for w in words]) if len(words) == 3 else f'{words[0]}' + (f' or {words[1]}' if len(words) > 1 else '')}"
for prefix in ["respond", "just respond"]
for use_quotes in [True, False]
for words in (
list(p) for n in [3, 2] for p in permutations(["yes", "no", "maybe"], n)
)
]

openai_client = model.openai_client()
max_steps = int(os.environ.get("NUM_STEPS", "20"))
start_step = await model.get_step()
for step in range(start_step, start_step + max_steps):
train_groups = await art.gather_trajectory_groups(
(
art.TrajectoryGroup(
rollout(openai_client, model, prompt) for _ in range(32)
)
for prompt in prompts
)
)
result = await backend.train(
model,
train_groups,
learning_rate=1e-4,
kl_penalty_coef=kl_penalty_coef,
kl_penalty_reference_step=kl_penalty_reference_step,
)
await model.log(
train_groups,
metrics=result.metrics,
step=result.step,
split="train",
)
print(f"step {result.step}: {result.metrics}")


if __name__ == "__main__":
asyncio.run(main())
5 changes: 4 additions & 1 deletion dev/yes-no-maybe-kl-advantage.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import asyncio
from itertools import permutations
import os
import random
import string

from dotenv import load_dotenv
import openai
Expand Down Expand Up @@ -54,8 +56,9 @@ async def main():
backend = LocalBackend()
base_model = os.environ.get("BASE_MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
kl_penalty_coef = float(os.environ.get("KL_PENALTY_COEF", "0.1"))
random_suffix = "".join(random.choices(string.ascii_lowercase, k=4))
model = art.TrainableModel(
name=os.environ.get("MODEL_NAME", f"kl-{kl_penalty_coef}"),
name=os.environ.get("MODEL_NAME", f"local-{random_suffix}-{kl_penalty_coef}"),
project="yes-no-maybe",
base_model=base_model,
)
Expand Down
3 changes: 3 additions & 0 deletions src/art/_backend_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def build_rl_train_configs(
max_negative_advantage_importance_sampling_weight: float | None = None,
kimi_k2_tau: float | None = None,
kl_penalty_coef: float = 0.0,
kl_penalty_source: Literal["current_learner", "sample"] = "current_learner",
allow_training_without_logprobs: bool | None = None,
plot_tensors: bool | None = None,
truncated_importance_sampling: float | None = None,
Expand All @@ -40,11 +41,13 @@ def build_rl_train_configs(
config = TrainConfig(
learning_rate=learning_rate,
kl_penalty_coef=kl_penalty_coef,
kl_penalty_source=kl_penalty_source,
)
dev_config: dev.TrainConfig = {
"advantage_balance": advantage_balance,
"importance_sampling_level": importance_sampling_level,
"kl_penalty_coef": kl_penalty_coef,
"kl_penalty_source": kl_penalty_source,
"mask_prob_ratio": mask_prob_ratio,
"ppo": ppo,
"precalculate_logprobs": precalculate_logprobs,
Expand Down
2 changes: 2 additions & 0 deletions src/art/dev/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ class TrainConfig(TypedDict, total=False):
]
kimi_k2_tau: float | None
kl_penalty_coef: float
kl_penalty_reference_step: int | None
kl_penalty_source: Literal["current_learner", "sample"]
kl_ref_adapter_path: str | None
logprob_calculation_chunk_size: int
mask_prob_ratio: bool
Expand Down
8 changes: 8 additions & 0 deletions src/art/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ async def train( # type: ignore[override]
kl_penalty_coef: float = 0.0,
kl_penalty_reference_step: int | None = None,
kl_ref_adapter_path: str | None = None,
kl_penalty_source: Literal["current_learner", "sample"] = "current_learner",
epsilon: float | None = None,
epsilon_high: float | None = None,
# Advantage computation
Expand Down Expand Up @@ -705,6 +706,11 @@ async def train( # type: ignore[override]
kl_ref_adapter_path: Direct filesystem path to a LoRA adapter
checkpoint to use as the KL reference. Alternative to
kl_penalty_reference_step.
kl_penalty_source: Which policy's logprobs to compare against the
reference when building the centered KL penalty. Use
"current_learner" to match the original ART implementation, or
"sample" to shape from the rollout policy logprobs, which is
usually better for async/off-policy workloads.
epsilon: Clip epsilon for importance sampling. Defaults based on loss_fn.
epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
advantage_balance: Balance between negative and positive advantages
Expand Down Expand Up @@ -755,6 +761,7 @@ async def train( # type: ignore[override]
scale_rewards = False
if adam_params is not None:
raise ValueError("LocalBackend requires adam_params=None.")
assert kl_penalty_source in {"current_learner", "sample"}
if (
self._requires_explicit_packed_sequence_length
and packed_sequence_length is None
Expand Down Expand Up @@ -785,6 +792,7 @@ async def train( # type: ignore[override]
max_negative_advantage_importance_sampling_weight=max_negative_advantage_importance_sampling_weight,
kimi_k2_tau=kimi_k2_tau,
kl_penalty_coef=kl_penalty_coef,
kl_penalty_source=kl_penalty_source,
allow_training_without_logprobs=allow_training_without_logprobs,
plot_tensors=plot_tensors,
truncated_importance_sampling=truncated_importance_sampling,
Expand Down
9 changes: 8 additions & 1 deletion src/art/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,14 @@ def loss_fn(
kl_policy_ref: torch.Tensor | None = None
kl_penalty_coef = experimental_config.get("kl_penalty_coef", 0.0)
if kl_penalty_coef > 0 and ref_logprobs is not None:
kl_per_token = (new_logprobs - ref_logprobs).detach() * assistant_mask
match experimental_config.get("kl_penalty_source", "current_learner"):
case "sample":
kl_source_logprobs = old_logprobs.detach()
case "current_learner":
kl_source_logprobs = new_logprobs.detach()
case other:
raise AssertionError(other)
kl_per_token = (kl_source_logprobs - ref_logprobs).detach() * assistant_mask
avg_kl = kl_per_token.sum() / (assistant_mask.sum() + 1e-6)
kl_penalty = kl_penalty_coef * (avg_kl - kl_per_token) * assistant_mask
advantages = advantages + kl_penalty
Expand Down
Loading
Loading