78 changes: 78 additions & 0 deletions docs/reference.md
@@ -739,6 +739,84 @@ Endpoints = dict[str, list[Endpoint]]

`Endpoints` maps an endpoint id to one or more endpoint variants. A single variant is represented as a one-item list.

### SFTConfig

```python
class SFTConfig(TrainingArguments):
use_liger: bool = True
use_lora: bool = True
lora_rank: int = 8
lora_alpha: int = 32
lora_dropout: float = 0.0
lora_target_modules: List[str] | str | None = None
lora_modules_to_save: Optional[List[str]] = None
lora_use_rslora: bool = False
lora_config: Optional[LoraConfig] = None

dataset_name: Optional[str] = None
dataset_split: str = "train"

batch_size: int = 512
micro_batch_size: int = 8
max_seq_len: int = 2048

max_steps: int = 500
num_train_epochs: int = 1
learning_rate: float = 1e-5
adam_beta1: float = 0.9
adam_beta2: float = 0.999
weight_decay: float = 0.0
max_grad_norm: float = 1.0

use_vllm: bool = False
vllm_sample_every_n_steps: int = 100
vllm_num_samples: int = 5
vllm_server_host: str = "0.0.0.0"
vllm_server_port: int = 8000
```

Configuration class for `SFTTrainer`. Extends `transformers.TrainingArguments` with additional fields for model loading, LoRA configuration, batch parameters, and optional vLLM integration.

**Key parameters:**
- `batch_size`: Total effective batch size for training; `gradient_accumulation_steps` is computed from it automatically (see the sketch after this list)
- `micro_batch_size`: Batch size per device per step
- `use_vllm`: Enable vLLM integration for sample generation during training (optional monitoring)
- `dataset_name` / `dataset_split`: Dataset specification (can be loaded externally and passed to trainer)
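
For intuition, the effective batch decomposes roughly as follows. This is a minimal sketch; the exact derivation inside `SFTConfig` (including how the device count is obtained) is an assumption:

```python
# Hedged sketch: how batch_size, micro_batch_size, and device count relate.
# The exact formula used by SFTConfig is an assumption.
world_size = 4            # number of training devices (example value)
batch_size = 512          # total effective batch size (SFTConfig default)
micro_batch_size = 8      # per-device batch size per step (SFTConfig default)

# Gradients are accumulated across micro-batches until the effective batch is reached:
gradient_accumulation_steps = batch_size // (micro_batch_size * world_size)
assert gradient_accumulation_steps == 16
```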

### SFTTrainer

```python
class SFTTrainer(Trainer):
def __init__(
self,
model: PreTrainedModel | str,
train_dataset: Dataset,
args: SFTConfig,
processing_class: Optional[PreTrainedTokenizerBase] = None,
eval_dataset: Optional[Dataset] = None,
**kwargs,
)
```

Supervised Fine-Tuning trainer that provides a consistent API with `RLTrainer`. Uses standard cross-entropy loss (no PPO, no advantages, no orchestrator). Supports optional vLLM integration for monitoring sample quality during training.

**Key methods:**
- `compute_loss()`: Computes cross-entropy loss (simpler than `RLTrainer`'s PPO loss; see the sketch after this list)
- `training_step()`: Standard training step without orchestrator
- `log()`: Logs metrics and samples (only when samples are available via vLLM or manual logging)
- `log_metrics()`: Tracks metrics for averaging
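
For reference, the core of a standard next-token cross-entropy loss looks roughly like this. This is a generic sketch, not the exact `compute_loss()` implementation:

```python
import torch
import torch.nn.functional as F

def sft_cross_entropy(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Generic next-token cross-entropy; positions labeled -100 are ignored."""
    # Shift so that the logits at position t predict the token at position t+1.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )
```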

**Usage:**
```python
from verifiers import SFTConfig, SFTTrainer
from datasets import load_dataset

dataset = load_dataset("willcb/V3-wordle", split="train")
config = SFTConfig(run_name="wordle-sft", max_steps=500)
trainer = SFTTrainer(model="Qwen/Qwen3-4B-Instruct", train_dataset=dataset, args=config)
trainer.train()
```

---

## Prime CLI Plugin
79 changes: 79 additions & 0 deletions docs/training.md
@@ -183,6 +183,85 @@ The best way to improve training is to ensure appropriate task difficulty for yo

`verifiers` is intended to be largely trainer-agnostic and is straightforward to integrate with any trainer that can expose an OpenAI-compatible inference client for rollouts.

### `vf.SFTTrainer`

Supervised Fine-Tuning (SFT) trainer for training models on static datasets. `SFTTrainer` provides a consistent API with `vf.RLTrainer`, making it easy to transition from SFT to RL workflows. It uses standard cross-entropy loss and supports optional vLLM integration for monitoring sample quality during training.

**Installation**: `pip install verifiers[rl]`

**Usage**:
```python
import verifiers as vf
from datasets import load_dataset

# Load dataset
dataset = load_dataset("willcb/V3-wordle", split="train")

# Create config
config = vf.SFTConfig(
run_name="wordle-sft",
max_steps=500,
learning_rate=2e-5,
batch_size=512,
micro_batch_size=8,
use_lora=True,
lora_rank=8,
)

# Create and train
trainer = vf.SFTTrainer(
model="Qwen/Qwen3-4B-Instruct",
train_dataset=dataset,
args=config,
)
trainer.train()
```

**CLI Usage**:
```bash
vf-sft @ path/to/config.toml
```

**Example TOML config**:
```toml
model = "Qwen/Qwen3-4B-Instruct"
dataset = "willcb/V3-wordle"

[sft]
run_name = "wordle-sft"
max_steps = 500
learning_rate = 2e-5
batch_size = 512
micro_batch_size = 8
use_lora = true
lora_rank = 8
use_liger = true
bf16 = true
gradient_checkpointing = false
report_to = "wandb"
```

**Key Features**:
- Same configuration structure as `RLConfig`
- Simple cross-entropy loss (no PPO, no orchestrator)
- Optional vLLM integration for sample generation (`use_vllm=True`; see the sketch after this list)
- Same logging and metrics patterns as `RLTrainer`
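
The vLLM monitoring fields shown in `SFTConfig` can be wired up like this. This sketch assumes a vLLM server is already running separately (e.g. started via the `vf-vllm` entrypoint):

```python
# Sketch: periodic sample generation against a running vLLM server.
# Field names come from SFTConfig; the server itself must be started separately.
config = vf.SFTConfig(
    run_name="wordle-sft",
    use_vllm=True,                  # enable sample generation during training
    vllm_sample_every_n_steps=100,  # generate samples every 100 training steps
    vllm_num_samples=5,             # samples generated per round
    vllm_server_host="0.0.0.0",
    vllm_server_port=8000,
)
```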

**SFT → RL Workflow**:
```python
# Phase 1: SFT
sft_trainer = vf.SFTTrainer(model="Qwen/Qwen3-4B-Instruct", ...)
sft_trainer.train()

# Phase 2: RL on SFT checkpoint
rl_trainer = vf.RLTrainer(
model="outputs/wordle-sft/checkpoint-500", # SFT checkpoint
env=env,
args=rl_config,
)
rl_trainer.train()
```

### `vf.RLTrainer` (Legacy)

The legacy `vf.RLTrainer` still exists for educational and experimental purposes via the optional `verifiers-rl` package and the legacy RL CLI entrypoint, but it is not actively maintained. It is a compact single-node async RL trainer with a narrower feature set than production trainers. Its core implementation (`trainer.py` and `orchestrator.py` under `packages/verifiers-rl/verifiers_rl/rl/trainer/`) remains intentionally lightweight for algorithm experimentation. For production training and current guidance, use [`prime-rl`](#training-with-prime-rl).
1 change: 1 addition & 0 deletions packages/verifiers-rl/pyproject.toml
@@ -32,6 +32,7 @@ flash-attn = { FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE" }
vf-rl = "verifiers_rl.scripts.rl:main"
vf-train = "verifiers_rl.scripts.train:main"
vf-vllm = "verifiers_rl.rl.inference.server:main"
vf-sft = "verifiers_rl.scripts.sft:main"

[tool.hatch.build.targets.wheel]
packages = ["verifiers_rl"]
4 changes: 4 additions & 0 deletions packages/verifiers-rl/verifiers_rl/__init__.py
@@ -3,6 +3,8 @@
GRPOTrainer,
RLConfig,
RLTrainer,
SFTConfig,
SFTTrainer,
get_model,
get_model_and_tokenizer,
grpo_defaults,
@@ -14,6 +16,8 @@
"get_model_and_tokenizer",
"RLConfig",
"RLTrainer",
"SFTConfig",
"SFTTrainer",
"GRPOTrainer",
"GRPOConfig",
"grpo_defaults",
6 changes: 4 additions & 2 deletions packages/verifiers-rl/verifiers_rl/rl/trainer/__init__.py
@@ -2,8 +2,8 @@

import torch._dynamo

-from .config import RLConfig
-from .trainer import RLTrainer
+from .config import RLConfig, SFTConfig
+from .trainer import RLTrainer, SFTTrainer
from .utils import get_model, get_model_and_tokenizer

torch._dynamo.config.suppress_errors = True
@@ -31,6 +31,8 @@ def lora_defaults(**kwargs):
__all__ = [
"RLConfig",
"RLTrainer",
"SFTConfig",
"SFTTrainer",
"GRPOTrainer",
"GRPOConfig",
"grpo_defaults",