Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xtuner/v1/ray/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .tokenize import TokenizeControllerConfig
from .worker import (
RolloutConfig,
TrainingWorkerConfig,
Expand Down
55 changes: 55 additions & 0 deletions xtuner/v1/ray/config/tokenize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from cyclopts import Parameter
from pydantic import BaseModel, ConfigDict
from typing_extensions import Annotated

from xtuner.v1.utils import get_logger


if TYPE_CHECKING:
from xtuner.v1.ray.rollout.tokenize_controller import TokenizeController


class TokenizeControllerConfig(BaseModel):
    """Settings from which a rollout ``TokenizeController`` is built.

    Unknown keys are rejected (``extra="forbid"``). A ``num_ray_actors`` value
    of zero or less selects local tokenize mode.
    """

    model_config = ConfigDict(extra="forbid")

    num_ray_actors: Annotated[
        int,
        Parameter(help="Number of ray actors used by tokenize controller. 0 means local tokenize mode."),
    ] = 0
    num_cpus_per_actor: Annotated[
        int,
        Parameter(help="CPU cores allocated for each tokenize ray actor."),
    ] = 1
    num_processes_per_actor: Annotated[
        int,
        Parameter(help="Number of subprocesses inside each tokenize ray actor."),
    ] = 1
    request_timeout: Annotated[
        float,
        Parameter(help="Timeout duration (in seconds) for tokenize requests."),
    ] = 300.0
    enable_spread_scheduling: Annotated[
        bool,
        Parameter(help="Use SPREAD scheduling for tokenize ray actors when actor count > 1."),
    ] = True

    def build(self, tokenizer_path: str) -> TokenizeController:
        """Create a ``TokenizeController`` configured from this object.

        Args:
            tokenizer_path: Tokenizer location handed to the controller.

        Returns:
            The newly constructed controller.
        """
        # Imported here rather than at module level; the module-level import is
        # only performed under TYPE_CHECKING.
        from xtuner.v1.ray.rollout.tokenize_controller import TokenizeController

        logger = get_logger(tag="TokenizeControllerConfig")
        runs_locally = self.num_ray_actors <= 0
        if runs_locally:
            logger.info("TokenizeController uses local tokenizer mode.")

        controller_kwargs = {
            "num_ray_actors": self.num_ray_actors,
            "num_cpus_per_actor": self.num_cpus_per_actor,
            "num_processes_per_actor": self.num_processes_per_actor,
            "request_timeout": self.request_timeout,
            "enable_spread_scheduling": self.enable_spread_scheduling,
        }
        return TokenizeController(tokenizer_path=tokenizer_path, **controller_kwargs)
11 changes: 10 additions & 1 deletion xtuner/v1/ray/config/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from typing import Any, List, Literal, Optional, Union

from cyclopts import Group, Parameter
from pydantic import BaseModel, ConfigDict, PrivateAttr
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from typing_extensions import Annotated

from xtuner.v1.utils import get_logger

from .tokenize import TokenizeControllerConfig


worker_group = Group("worker", help="Types of workers available.")
train_group = Group("Training", sort_key=90, help="Training worker configuration.")
Expand Down Expand Up @@ -263,6 +265,13 @@ class RolloutConfig(BaseModel):
help="Use float32 for language model head.",
),
] = False
tokenize_controller_config: Annotated[
TokenizeControllerConfig,
Parameter(
group=infer_group,
help="Configuration of tokenize controller used by rollout controller.",
),
] = Field(default_factory=TokenizeControllerConfig)
worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir"
_logged_server_urls_per_engine: bool = PrivateAttr(default=False)

Expand Down
12 changes: 4 additions & 8 deletions xtuner/v1/ray/environment/lagent/llms/controller_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
ResponseParser,
)
from xtuner.v1.ray.environment.lagent.schema import AgentMessage
from xtuner.v1.ray.environment.lagent.tokenize import tokenize
from xtuner.v1.ray.rollout.controller import RolloutController


Expand All @@ -26,19 +25,16 @@ def __init__(
reasoning_parser: Optional[ResponseParser] = None,
tool_call_parser: Optional[ResponseParser] = None,
):
assert rollout_controller is not None or (
placement_group and rollout_cfg
), "Either rollout_controller or placement_group and rollout_cfg must be provided."
assert rollout_controller is not None or (placement_group and rollout_cfg), (
"Either rollout_controller or placement_group and rollout_cfg must be provided."
)
if rollout_controller:
self.rollout_controller = rollout_controller
self.rollout_cfg = ray.get(rollout_controller.get_rollout_info.remote())["rollout_config"] # type: ignore[call-overload, attr-defined]
else:
self.rollout_controller = RolloutController.remote(rollout_cfg, placement_group) # type: ignore[attr-defined]
self.rollout_cfg = rollout_cfg

from transformers import AutoTokenizer

self.tokenizer = AutoTokenizer.from_pretrained(self.rollout_cfg.tokenizer_path, trust_remote_code=True)
self.sample_params = sample_params or SampleParams()
# default parsers
self.reasoning_parser = (
Expand All @@ -50,7 +46,7 @@ def __init__(

async def chat(self, messages, tools: Optional[List[Dict]] = None, **kwargs):
sample_params = self.sample_params.model_copy(update=kwargs)
inputs = tokenize(self.tokenizer, messages, tools)
inputs = await self.rollout_controller.tokenize.remote(messages, tools) # type: ignore[attr-defined]
if len(inputs["input_ids"]) >= self.rollout_cfg.context_length:
response = RLRolloutResponseItem(finish_reason="length")
else:
Expand Down
13 changes: 13 additions & 0 deletions xtuner/v1/ray/environment/lagent/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,16 @@ class LagentChatCompletionRequest(BaseModel):
# XTuner-specific generation parameters
sample_params: SampleParams = Field(default_factory=SampleParams)
extra_params: Dict[str, Any] = Field(default_factory=dict)


class LagentTokenizeRequest(BaseModel):
    """Request body accepted by the ``POST /v1/tokenize`` endpoint."""

    # Model name; defaults to an empty string.
    # NOTE(review): the visible tokenize handler does not read this field — confirm whether it is needed.
    model: str = ""
    # Chat messages, forwarded unchanged to the tokenize call.
    messages: List[Dict[str, Any]]
    # Tool definitions forwarded alongside the messages; defaults to an empty list.
    tools: List[Any] = Field(default_factory=list)


class LagentTokenizeResponse(BaseModel):
    """Response body of the ``POST /v1/tokenize`` endpoint, populated from the tokenize result dict."""

    # Token ids produced for the request's messages.
    input_ids: List[int]
    # NOTE(review): presumably per-token labels aligned with input_ids — confirm against the tokenize helper.
    labels: List[int]
    # NOTE(review): presumably per-token log-probabilities — confirm against the tokenize helper.
    logprobs: List[float]
    # Optional; None when the tokenize result does not provide it.
    # NOTE(review): payload shape is not visible here — confirm before relying on it.
    routed_experts: Optional[Any] = None
18 changes: 13 additions & 5 deletions xtuner/v1/ray/rollout/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from fastapi import FastAPI
from ray.util.placement_group import PlacementGroup

from transformers import AutoTokenizer
from xtuner.v1.data_proto.rl_data import (
RLRolloutResponseItem,
RolloutExtraParams,
Expand Down Expand Up @@ -138,7 +137,7 @@ def __init__(
self.active_rollout_workers: List[RolloutWorker] = []
tokenizer_path = infer_config.tokenizer_path
assert tokenizer_path is not None, "tokenizer_path must be set before creating RolloutController"
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
self.tokenize_controller = infer_config.tokenize_controller_config.build(tokenizer_path=tokenizer_path)
self.workers, self.rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group(
self._get_worker_cls(), infer_config, placement_group
)
Expand Down Expand Up @@ -244,6 +243,9 @@ def get_rollout_info(self):
api_server_url=getattr(self, "api_server_url", None),
)

async def tokenize(self, messages: List[Dict[str, Any]], tools: Optional[List[Any]] = None) -> Dict[str, Any]:
    """Tokenize chat ``messages`` (and optional ``tools``) via the tokenize controller.

    Args:
        messages: Chat messages to tokenize.
        tools: Optional tool definitions passed through with the messages.

    Returns:
        The dict returned by ``self.tokenize_controller.tokenize``.
    """
    tokenized = await self.tokenize_controller.tokenize(messages, tools)
    return tokenized

def init_workers(self):
"""Initializes and configures the pool of RolloutWorker actors.

Expand Down Expand Up @@ -481,12 +483,13 @@ def start_api_server(self, host: str = "0.0.0.0", port: int = 8000):
LagentChatCompletionMessage,
LagentChatCompletionRequest,
LagentChoice,
LagentTokenizeRequest,
LagentTokenizeResponse,
)
from xtuner.v1.ray.environment.lagent.tokenize import tokenize

@app.post("/v1/chat/completions")
async def chat_completions(request: LagentChatCompletionRequest):
inputs = tokenize(self.tokenizer, request.messages, request.tools)
inputs = await self.tokenize(request.messages, request.tools)
response: RLRolloutResponseItem = await self.rollout(
prompt=request.messages,
input_ids=inputs["input_ids"],
Expand Down Expand Up @@ -521,6 +524,11 @@ async def chat_completions(request: LagentChatCompletionRequest):
),
).model_dump()

@app.post("/v1/tokenize")
async def tokenize(request: LagentTokenizeRequest):
inputs = await self.tokenize(request.messages, request.tools)
return LagentTokenizeResponse(**inputs).model_dump()

config = uvicorn.Config(app, host=host, port=port)
server = uvicorn.Server(config)
server_thread = threading.Thread(target=server.run, daemon=True)
Expand Down Expand Up @@ -662,5 +670,5 @@ def shutdown(self, block=True):
Args:
block (bool): Whether to block until the operation completes.
"""
return self._broadcast_to_active_workers("shutdown", block)
self.tokenize_controller.shutdown()
return self._broadcast_to_active_workers("shutdown", block)
140 changes: 140 additions & 0 deletions xtuner/v1/ray/rollout/tokenize_controller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import asyncio
import threading
from functools import partial
from typing import Any, Dict, List, Optional

import ray

from transformers import AutoTokenizer
from xtuner.v1.ray.environment.lagent.tokenize import tokenize as lagent_tokenize
from xtuner.v1.utils import get_logger
from xtuner.v1.utils.executor import SharedPoolExecutor


_PROCESS_TOKENIZER = None


def _tokenize_in_process(
    task: tuple[List[Dict[str, Any]], Optional[List[Any]]],
    tokenizer_path: str,
    enable_interleaved_thinking: bool,
    enable_thinking: bool,
) -> Dict[str, Any]:
    """Tokenize one ``(messages, tools)`` task with a tokenizer cached per process.

    The tokenizer is loaded from ``tokenizer_path`` the first time this runs in
    a process and stored in the module global ``_PROCESS_TOKENIZER``; later
    calls in the same process reuse that object.

    Args:
        task: Pair of chat messages and optional tool definitions.
        tokenizer_path: Location passed to ``AutoTokenizer.from_pretrained``.
        enable_interleaved_thinking: Forwarded to the lagent tokenize helper.
        enable_thinking: Forwarded to the lagent tokenize helper.

    Returns:
        The dict produced by the lagent tokenize helper.
    """
    global _PROCESS_TOKENIZER

    chat_messages, tool_specs = task

    # Load once per process; every later task reuses the cached instance.
    if _PROCESS_TOKENIZER is None:
        _PROCESS_TOKENIZER = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
    cached_tokenizer = _PROCESS_TOKENIZER
    assert cached_tokenizer is not None, "Process tokenizer is not initialized."

    thinking_flags = {
        "enable_interleaved_thinking": enable_interleaved_thinking,
        "enable_thinking": enable_thinking,
    }
    return lagent_tokenize(cached_tokenizer, chat_messages, tools=tool_specs, **thinking_flags)


class TokenizeWorker:
    """Tokenize chat requests in-process or through a pool of subprocesses.

    With ``num_processes`` greater than one, requests are submitted to a
    ``SharedPoolExecutor`` running ``_tokenize_in_process``; otherwise a single
    tokenizer is loaded in this object and used directly.
    """

    def __init__(
        self,
        tokenizer_path: str,
        num_processes: int = 1,
        enable_interleaved_thinking: bool = True,
        enable_thinking: bool = True,
    ):
        """Set up either the subprocess pool or the in-process tokenizer.

        Args:
            tokenizer_path: Location passed to ``AutoTokenizer.from_pretrained``.
            num_processes: Requested subprocess count; values below 1 are raised to 1.
            enable_interleaved_thinking: Forwarded to the lagent tokenize helper.
            enable_thinking: Forwarded to the lagent tokenize helper.
        """
        self.logger = get_logger(tag="TokenizeWorker")
        self.tokenizer_path = tokenizer_path
        self.enable_interleaved_thinking = enable_interleaved_thinking
        self.enable_thinking = enable_thinking
        self.num_processes = max(1, num_processes)
        self.pool: Optional[SharedPoolExecutor] = None
        self.tokenizer = None

        # Single-process mode: tokenize directly in this object, no pool.
        if self.num_processes <= 1:
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
            return

        pool_fn_kwargs = {
            "tokenizer_path": tokenizer_path,
            "enable_interleaved_thinking": self.enable_interleaved_thinking,
            "enable_thinking": self.enable_thinking,
        }
        self.pool = SharedPoolExecutor(
            fn=_tokenize_in_process,
            partial_kwargs=pool_fn_kwargs,
            max_workers=self.num_processes,
            mp_context="fork",
        )
        self.logger.info(f"Tokenize worker starts process pool, num_processes={self.num_processes}")

    async def tokenize(self, messages: List[Dict[str, Any]], tools: Optional[List[Any]] = None) -> Dict[str, Any]:
        """Tokenize ``messages`` (and optional ``tools``).

        Uses the subprocess pool when one exists, otherwise the in-process
        tokenizer.

        Returns:
            The dict produced by the lagent tokenize helper.

        Raises:
            AssertionError: If neither a pool nor a tokenizer is available.
        """
        if self.pool is None:
            assert self.tokenizer is not None, "Tokenizer is not initialized."
            return lagent_tokenize(
                self.tokenizer,
                messages,
                tools=tools,
                enable_interleaved_thinking=self.enable_interleaved_thinking,
                enable_thinking=self.enable_thinking,
            )

        # The pool hands back a concurrent future; wrap it so it can be awaited.
        pending = self.pool.submit((messages, tools))
        return await asyncio.wrap_future(pending)

    def shutdown(self):
        """Shut down the subprocess pool, if any, and forget it."""
        active_pool = self.pool
        if active_pool is None:
            return
        active_pool.shutdown()
        self.pool = None


class TokenizeController:
    """Tokenize chat requests locally or round-robin over Ray tokenize actors.

    When ``num_ray_actors`` is zero or less no actors are created and every
    request is tokenized with the tokenizer loaded in this object. Otherwise
    ``num_ray_actors`` ``TokenizeWorker`` actors are started and requests are
    handed to them in turn.
    """

    def __init__(
        self,
        tokenizer_path: str,
        num_ray_actors: int = 0,
        num_cpus_per_actor: int = 1,
        num_processes_per_actor: int = 1,
        request_timeout: float = 300.0,
        enable_spread_scheduling: bool = True,
    ):
        """Load the local tokenizer and optionally start the tokenize actors.

        Args:
            tokenizer_path: Location passed to ``AutoTokenizer.from_pretrained``
                and to every worker actor.
            num_ray_actors: Number of worker actors; ``<= 0`` means local mode.
            num_cpus_per_actor: ``num_cpus`` requested for each actor.
            num_processes_per_actor: Subprocess count handed to each worker.
            request_timeout: Seconds to wait for a worker reply in ``tokenize``.
            enable_spread_scheduling: Use the ``"SPREAD"`` scheduling strategy
                when more than one actor is created.
        """
        self.logger = get_logger(tag="TokenizeController")
        self.request_timeout = request_timeout
        # A threading lock rather than asyncio.Lock: tokenize() can be awaited
        # from more than one event loop/thread, and asyncio primitives are not
        # thread-safe. The guarded section never awaits, so holding the lock
        # cannot stall an event loop.
        self._lock = threading.Lock()
        self._next_actor_idx = 0
        self._actors: List[ray.actor.ActorHandle] = []

        # Always loaded: used for local mode and as the fallback after shutdown().
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)

        if num_ray_actors <= 0:
            return

        worker_cls = ray.remote(TokenizeWorker)
        actor_options: Dict[str, Any] = dict(num_cpus=num_cpus_per_actor)
        if enable_spread_scheduling and num_ray_actors > 1:
            actor_options["scheduling_strategy"] = "SPREAD"
            self.logger.info("TokenizeController enables SPREAD scheduling for tokenize workers.")
        for _ in range(num_ray_actors):
            actor = worker_cls.options(**actor_options).remote(
                tokenizer_path=tokenizer_path,
                num_processes=num_processes_per_actor,
            )
            self._actors.append(actor)
        # f-string, like the other log calls in this module, so the values are
        # rendered regardless of how the logger treats positional arguments.
        self.logger.info(
            f"TokenizeController starts {len(self._actors)} ray actors, "
            f"{max(1, num_processes_per_actor)} processes per actor."
        )

    async def tokenize(self, messages: List[Dict[str, Any]], tools: Optional[List[Any]] = None) -> Dict[str, Any]:
        """Tokenize ``messages`` (and optional ``tools``).

        Args:
            messages: Chat messages to tokenize.
            tools: Optional tool definitions passed through with the messages.

        Returns:
            The dict produced by the lagent tokenize helper, computed locally
            when no actors are attached or by the next actor in round-robin
            order otherwise.

        Raises:
            asyncio.TimeoutError: If the chosen actor does not reply within
                ``request_timeout`` seconds. The remote call is shielded, so
                the timeout does not cancel it.
        """
        # Pick the actor and advance the cursor atomically; the lock is never
        # held across an await.
        with self._lock:
            actor = None
            if self._actors:
                actor = self._actors[self._next_actor_idx % len(self._actors)]
                self._next_actor_idx = (self._next_actor_idx + 1) % len(self._actors)

        if actor is None:
            # NOTE(review): the local path relies on the helper's default thinking
            # flags, whereas workers pass them explicitly — confirm they agree.
            return lagent_tokenize(self.tokenizer, messages, tools=tools)

        response_ref = actor.tokenize.remote(messages, tools)
        return await asyncio.wait_for(asyncio.shield(response_ref), timeout=self.request_timeout)

    def shutdown(self):
        """Shut down every worker actor and detach from them.

        Blocks until each worker's ``shutdown`` has returned. The handles are
        dropped first, so a repeated call is a no-op and later ``tokenize``
        calls use the local tokenizer instead of workers that are shut down.
        """
        with self._lock:
            actors, self._actors = self._actors, []
            self._next_actor_idx = 0
        if actors:
            ray.get([actor.shutdown.remote() for actor in actors])
9 changes: 1 addition & 8 deletions xtuner/v1/train/cli/rl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import time
from pathlib import Path
from typing import Annotated
Expand Down Expand Up @@ -40,13 +39,7 @@ def main(
config: Annotated[Path, Parameter(group=Group("config-path", sort_key=0))],
):
if not ray.is_initialized():
if os.getenv("RAY_MASTER_ADDR"):
master_addr = os.getenv("RAY_MASTER_ADDR", "127.0.0.1")
client_port = os.getenv("RAY_CLIENT_PORT", "10001")
ray_head_address = f"ray://{master_addr}:{client_port}"
ray.init(address=ray_head_address)
else:
ray.init(num_cpus=128)
ray.init(address="auto")

# if os.getenv("XTUNER_RL_MEM_DIR"):
# print("Start to monitor actor memory")
Expand Down
Loading