Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ The ROLL framework currently supports the following trackers:
1. **TensorBoard** - Visualization tool developed by Google
2. **Weights & Biases (WandB)** - Powerful machine learning experiment tracking platform
3. **SwanLab** - Next-generation AI experiment tracking tool
4. **Stdout** - Direct output to standard output
4. **Trackio** - Local-first experiment tracking library from Hugging Face
5. **Stdout** - Direct output to standard output

## Configuring Trackers

Expand Down Expand Up @@ -44,10 +45,19 @@ tracker_kwargs:
- tag1
- tag2

# Using Trackio
track_with: trackio
trackio_max_traces_per_step: 32
tracker_kwargs:
project: roll-experiments
name: experiment_name

# Using Stdout
track_with: stdout
```

When `track_with: trackio` is enabled, ROLL logs rollout generations as Trackio traces in addition to scalar metrics. RLVR rollouts are logged under `rollout/rlvr`, and agentic trajectories are logged under `train/agentic_rollouts`. Use `trackio_max_traces_per_step` to cap the number of traces recorded per logging step, or set it to `0` to disable trace logging.

## SwanLab Usage Details

### Configuring SwanLab
Expand Down Expand Up @@ -198,4 +208,4 @@ In the following time and memory metrics, `{metric_infix}` will be replaced with
- memory/cpu/`{metric_infix}`/start/rss: Actual physical memory (Resident Set Size) occupied by the process at the start of the operation.
- memory/cpu/`{metric_infix}`/start/vms: Virtual memory (Virtual Memory Size) occupied by the process at the start of the operation.
- memory/cpu/`{metric_infix}`/end/rss: Actual physical memory occupied by the process at the end of the operation.
- memory/cpu/`{metric_infix}`/end/vms: Virtual memory occupied by the process at the end of the operation.
- memory/cpu/`{metric_infix}`/end/vms: Virtual memory occupied by the process at the end of the operation.
6 changes: 5 additions & 1 deletion roll/configs/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,16 @@ class BaseConfig(ScheduleConfig):
)
track_with: str = field(
default="tensorboard",
metadata={"help": "The type of tracker to be used for tracking, one of ['wandb', 'tensorboard', 'stdout', 'swanlab']."}
metadata={"help": "The type of tracker to be used for tracking, one of ['wandb', 'tensorboard', 'stdout', 'swanlab', 'trackio']."}
)
tracker_kwargs: dict = field(
default_factory=dict,
metadata={"help": "Additional keyword arguments to pass to the Tracker class."}
)
trackio_max_traces_per_step: int = field(
default=32,
metadata={"help": "Maximum rollout traces to log to Trackio per step. Set to 0 to disable trace logging."},
)
max_steps: int = field(
default=500,
metadata={"help": "If > 0: set total number of pipeline steps"},
Expand Down
66 changes: 66 additions & 0 deletions roll/pipeline/agentic/agentic_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,71 @@ def __init__(self, pipeline_config: AgenticConfig):
else:
self.partial_gpu_mode = False

def _trackio_traces_enabled(self) -> bool:
return (
self.pipeline_config.track_with == "trackio"
and getattr(self.pipeline_config, "trackio_max_traces_per_step", 0) > 0
)

@staticmethod
def _trace_metadata_value(value):
if isinstance(value, torch.Tensor):
if value.numel() == 1:
return value.detach().cpu().item()
return value.detach().cpu().tolist()
if isinstance(value, np.ndarray):
if value.size == 1:
return value.item()
return value.tolist()
if isinstance(value, np.generic):
return value.item()
return value

def _log_trackio_rollout_traces(self, batch: DataProto, global_step: int, split: str = "train"):
if not self._trackio_traces_enabled() or "traj_id" not in batch.non_tensor_batch:
return

traces = []
max_traces = getattr(self.pipeline_config, "trackio_max_traces_per_step", 0)
batch_grouped = batch.group_by(keys="traj_id")
for trajectory_index, (group_name, group_batch) in enumerate(batch_grouped.items()):
if len(traces) >= max_traces:
break

if "step" in group_batch.non_tensor_batch.keys():
indices = torch.argsort(torch.from_numpy(group_batch.non_tensor_batch["step"].astype(np.int64)))
group_batch.reorder(indices)

prompt_mask = group_batch.batch["prompt_mask"]
non_prompt_mask = torch.logical_not(group_batch.batch["prompt_mask"]) * group_batch.batch["attention_mask"]
input_ids = group_batch.batch["input_ids"]
prompt_ids_list = [input_ids[i][mask.bool()] for i, mask in enumerate(prompt_mask)]
response_ids_list = [input_ids[i][mask.bool()] for i, mask in enumerate(non_prompt_mask)]
prompts = self.tokenizer.batch_decode(prompt_ids_list, skip_special_tokens=False)
responses = self.tokenizer.batch_decode(response_ids_list, skip_special_tokens=False)

messages = []
for prompt, response in zip(prompts, responses):
messages.append({"role": "user", "content": prompt})
messages.append({"role": "assistant", "content": response})

metadata = {
"split": split,
"step": global_step,
"trajectory_index": trajectory_index,
"traj_id": self._trace_metadata_value(group_name),
}
for key in ("tags", "traj_group_id", "episode_scores", "step_scores", "sample_uuid"):
if key in group_batch.non_tensor_batch:
metadata[key] = self._trace_metadata_value(group_batch.non_tensor_batch[key][0])
for key in ("response_level_rewards", "advantages"):
if key in group_batch.batch:
metadata[key] = self._trace_metadata_value(group_batch.batch[key][0])

traces.append({"messages": messages, "metadata": metadata})

self.tracker.log_traces(f"{split}/agentic_rollouts", traces, step=global_step)

@torch.no_grad()
def run(self):
# Calculate tokens-per-second system throughput
Expand Down Expand Up @@ -527,6 +592,7 @@ def run(self):
)

log_res = []
self._log_trackio_rollout_traces(batch, global_step)
batch_grouped = batch.group_by(keys="traj_id")
for group_name, group_batch in batch_grouped.items():
if "step" in group_batch.non_tensor_batch.keys():
Expand Down
55 changes: 55 additions & 0 deletions roll/pipeline/rlvr/rlvr_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,60 @@ def __init__(self, pipeline_config: RLVRConfig):
for domain in self.rewards.keys():
self.running[domain] = RunningMoments()

def _trackio_traces_enabled(self) -> bool:
return (
self.pipeline_config.track_with == "trackio"
and getattr(self.pipeline_config, "trackio_max_traces_per_step", 0) > 0
)

@staticmethod
def _trace_metadata_value(value):
if isinstance(value, torch.Tensor):
if value.numel() == 1:
return value.detach().cpu().item()
return value.detach().cpu().tolist()
if isinstance(value, np.ndarray):
if value.size == 1:
return value.item()
return value.tolist()
if isinstance(value, np.generic):
return value.item()
return value

def _log_trackio_rollout_traces(self, batch: DataProto, global_step: int):
if not self._trackio_traces_enabled() or "prompts" not in batch.batch or "responses" not in batch.batch:
return

max_traces = min(getattr(self.pipeline_config, "trackio_max_traces_per_step", 0), len(batch))
prompts = self.tokenizer.batch_decode(batch.batch["prompts"][:max_traces], skip_special_tokens=False)
responses = self.tokenizer.batch_decode(batch.batch["responses"][:max_traces], skip_special_tokens=False)

traces = []
for sample_index, (prompt, response) in enumerate(zip(prompts, responses)):
metadata = {
"split": "train",
"step": global_step,
"sample_index": sample_index,
}
for key in ("prompt_id", "scores", "response_level_rewards"):
if key in batch.batch:
metadata[key] = self._trace_metadata_value(batch.batch[key][sample_index])
for key in ("domain", "tag", "sample_uuid"):
if key in batch.non_tensor_batch:
metadata[key] = self._trace_metadata_value(batch.non_tensor_batch[key][sample_index])

traces.append(
{
"messages": [
{"role": "user", "content": prompt},
{"role": "assistant", "content": response},
],
"metadata": metadata,
}
)

self.tracker.log_traces("rollout/rlvr", traces, step=global_step)

@torch.no_grad()
def save_metrics(self, batch):
def remove_leading_zeros(A, r_mask):
Expand Down Expand Up @@ -670,6 +724,7 @@ def run(self):
self.save_metrics(domain_batch)

batch = DataProto.concat(batch_list)
self._log_trackio_rollout_traces(batch, global_step)

if batch.batch["final_response_mask"].sum() == 0:
logger.info("Warning: final_response_mask.sum() == 0! Current step will be skipped.")
Expand Down
24 changes: 20 additions & 4 deletions roll/utils/tracking.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from concurrent import futures
import json
from concurrent import futures
from functools import wraps
from typing import Optional, Dict, Any
from typing import Any, Dict, List, Optional

import torch

from roll.utils.logging import get_logger


logger = get_logger()

tracker_registry: Dict[str, Any] = {}
Expand Down Expand Up @@ -59,6 +60,9 @@ class BaseTracker:
def log(self, values: dict, step: Optional[int], **kwargs):
pass

def log_traces(self, name: str, records: List[dict], step: Optional[int] = None):
pass

def finish(self):
pass

Expand Down Expand Up @@ -184,12 +188,13 @@ def __init__(self, config: dict, **kwargs):
group = kwargs.pop("group", None)
space_id = kwargs.pop("space_id", None)
dataset_id = kwargs.pop("dataset_id", None)
tags = kwargs.pop("tags", None)
kwargs.pop("tags", None)

auto_log_gpu = kwargs.pop("auto_log_gpu", True)
gpu_log_interval = kwargs.pop("gpu_log_interval", 2)

import trackio
self.trackio = trackio

if space_id:
logger.info(f"[Trackio] Using HF Space: {space_id}")
Expand All @@ -203,7 +208,6 @@ def __init__(self, config: dict, **kwargs):
config=config,
space_id=space_id,
dataset_id=dataset_id,
tags=tags,
auto_log_gpu=auto_log_gpu,
gpu_log_interval=gpu_log_interval,
)
Expand All @@ -218,6 +222,18 @@ def log(self, values: dict, step: Optional[int], **kwargs):
def log_system(self, values: dict):
self.run.log_system(values)

def log_traces(self, name: str, records: List[dict], step: Optional[int] = None):
if not records:
return
traces = [
self.trackio.Trace(
messages=record["messages"],
metadata=record.get("metadata"),
)
for record in records
]
self.run.log({name: traces}, step=step)

def finish(self):
self.run.finish()

Expand Down
32 changes: 32 additions & 0 deletions tests/utils/test_tracking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import sys
from types import SimpleNamespace
from unittest.mock import MagicMock

from roll.utils.tracking import BaseTracker, TrackioTracker


def test_base_tracker_log_traces_is_noop():
BaseTracker().log_traces("rollout/test", [{"messages": []}], step=1)


def test_trackio_tracker_logs_trace_records(monkeypatch):
run = MagicMock()
trace = MagicMock(return_value="trace-payload")
trackio = SimpleNamespace(init=MagicMock(return_value=run), Trace=trace)
monkeypatch.setitem(sys.modules, "trackio", trackio)

tracker = TrackioTracker(config={"model": "tiny"}, project="roll", name="trace-smoke")
records = [
{
"messages": [
{"role": "user", "content": "What is 2 + 2?"},
{"role": "assistant", "content": "4"},
],
"metadata": {"step": 3, "sample_index": 0},
}
]

tracker.log_traces("rollout/rlvr", records, step=3)

trace.assert_called_once_with(messages=records[0]["messages"], metadata=records[0]["metadata"])
run.log.assert_called_once_with({"rollout/rlvr": ["trace-payload"]}, step=3)