8 changes: 4 additions & 4 deletions CMakeLists.txt
@@ -71,11 +71,9 @@ FetchContent_MakeAvailable(repo-cutlass)
FetchContent_Declare(
yaml-cpp
GIT_REPOSITORY https://github.com/jbeder/yaml-cpp.git
GIT_TAG 0.8.0
GIT_TAG 65c1c270dbe7eec37b2df2531d7497c4eea79aee
GIT_PROGRESS TRUE
USES_TERMINAL_DOWNLOAD TRUE
PATCH_COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/yaml-cpp_cmake_policy.patch
UPDATE_DISCONNECTED 1
)
set(YAML_BUILD_SHARED_LIBS OFF CACHE BOOL "Build static library of yaml-cpp")
FetchContent_MakeAvailable(yaml-cpp)
@@ -87,7 +85,6 @@ FetchContent_Declare(
GIT_SUBMODULES "3rdparty/dlpack"
GIT_PROGRESS TRUE
USES_TERMINAL_DOWNLOAD TRUE
UPDATE_DISCONNECTED 1
)

FetchContent_GetProperties(xgrammar)
@@ -110,6 +107,7 @@ endif()

# the environment variable
# ASAN_OPTIONS=protect_shadow_gap=0,intercept_tls_get_addr=0
# LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libasan.so.6:/usr/lib/x86_64-linux-gnu/libstdc++.so.6
# must be set at runtime
# https://github.com/google/sanitizers/issues/1322
if (LMDEPLOY_ASAN_ENABLE)
@@ -333,6 +331,8 @@ if (MSVC)
CMAKE_CUDA_FLAGS_RELWITHDEBINFO)
string(REGEX REPLACE "-Wall" " /W0 " ${flag_var} "${${flag_var}}")
endforeach()
# avoid min/max macro in "windows.h" conflict with std::min/std::max
add_definitions(-DNOMINMAX=1)
endif()

include_directories(
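The ASAN hunk above notes that `ASAN_OPTIONS` and `LD_PRELOAD` must already be set when the process starts. A minimal sketch of one way to guarantee that, assuming a hypothetical wrapper script `asan_run.py` and a user-supplied command (the library paths come from the comment in the diff; nothing here is part of this PR):

```python
import os
import sys

# Illustrative wrapper: re-exec the given command with the ASAN runtime
# environment required when LMDEPLOY_ASAN_ENABLE is on (see the comment above).
if len(sys.argv) < 2:
    sys.exit('usage: python asan_run.py <command> [args...]')

env = dict(os.environ)
env['ASAN_OPTIONS'] = 'protect_shadow_gap=0,intercept_tls_get_addr=0'
env['LD_PRELOAD'] = ('/usr/lib/x86_64-linux-gnu/libasan.so.6:'
                     '/usr/lib/x86_64-linux-gnu/libstdc++.so.6')

cmd = sys.argv[1:]
os.execvpe(cmd[0], cmd, env)  # replace this process; env is inherited by the target
```
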
2 changes: 2 additions & 0 deletions benchmark/profile_throughput.py
@@ -331,6 +331,7 @@ def parse_args():
ArgumentHelper.model_format(tb_group, default='hf')
ArgumentHelper.num_tokens_per_iter(tb_group)
ArgumentHelper.max_prefill_iters(tb_group)
ArgumentHelper.async_(tb_group)
ArgumentHelper.communicator(tb_group)

args = parser.parse_args()
@@ -352,6 +353,7 @@ def main():
quant_policy=args.quant_policy,
num_tokens_per_iter=args.num_tokens_per_iter,
max_prefill_iters=args.max_prefill_iters,
async_=args.async_,
enable_prefix_caching=args.enable_prefix_caching,
dtype=args.dtype,
communicator=args.communicator,
1 change: 1 addition & 0 deletions lmdeploy/cli/cli.py
@@ -78,6 +78,7 @@ def add_parser_chat():
ArgumentHelper.rope_scaling_factor(tb_group)
ArgumentHelper.communicator(tb_group)
ArgumentHelper.cp(tb_group)
ArgumentHelper.async_(tb_group)

# speculative decoding
ArgumentHelper.add_spec_group(parser)
2 changes: 2 additions & 0 deletions lmdeploy/cli/serve.py
@@ -146,6 +146,7 @@ def add_parser_api_server():
ArgumentHelper.rope_scaling_factor(tb_group)
ArgumentHelper.num_tokens_per_iter(tb_group)
ArgumentHelper.max_prefill_iters(tb_group)
ArgumentHelper.async_(tb_group)
ArgumentHelper.communicator(tb_group)
ArgumentHelper.dist_init_addr(tb_group)

@@ -262,6 +263,7 @@ def api_server(args):
max_prefill_token_num=args.max_prefill_token_num,
num_tokens_per_iter=args.num_tokens_per_iter,
max_prefill_iters=args.max_prefill_iters,
async_=args.async_,
communicator=args.communicator,
enable_metrics=not args.disable_metrics,
hf_overrides=args.hf_overrides)
10 changes: 10 additions & 0 deletions lmdeploy/cli/utils.py
@@ -562,6 +562,16 @@ def max_prefill_iters(parser):
default=1,
help='the max number of forward passes in prefill stage')

@staticmethod
def async_(parser):
return parser.add_argument('--async',
type=int,
default=1,
choices=[0, 1],
dest='async_',
help='Enable async execution (default: 1, enabled). '
'Set to 0 to disable async mode, 1 to enable it.')

@staticmethod
def max_prefill_token_num(parser):
return parser.add_argument('--max-prefill-token-num',
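Because `async` is a reserved keyword in Python 3, the new helper above maps the flag onto `dest='async_'` so the parsed value is reachable as an attribute. A standalone sketch of the same pattern, outside of `ArgumentHelper`:

```python
import argparse

# Mirrors the --async flag added by ArgumentHelper.async_ in lmdeploy/cli/utils.py.
parser = argparse.ArgumentParser()
parser.add_argument('--async',
                    type=int,
                    default=1,
                    choices=[0, 1],
                    dest='async_',  # `async` is a keyword, so the attribute is exposed as async_
                    help='Enable async execution (default: 1). Set to 0 to disable.')

args = parser.parse_args(['--async', '0'])
assert args.async_ == 0
```
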
5 changes: 5 additions & 0 deletions lmdeploy/messages.py
@@ -226,6 +226,8 @@ class TurbomindEngineConfig:
"Dynamic SplitFuse"-like scheduling
max_prefill_iters(int): the max number of forward pass during prefill
stage
async_ (int): enable async execution, default to 1 (enabled).
When set to 0, async mode is disabled. When set to 1, async mode is enabled.
devices(List[int]): the used devices
empty_init (bool): Whether to load the model weights, you should set
it to True if you want to update weights after create the pipeline
@@ -264,6 +266,7 @@ class TurbomindEngineConfig:
max_prefill_token_num: int = 8192
num_tokens_per_iter: int = 0
max_prefill_iters: int = 1
async_: int = 1
devices: Optional[List[int]] = None
empty_init: bool = False
communicator: str = 'nccl'
@@ -280,6 +283,7 @@ def __post_init__(self):
assert self.max_prefill_token_num >= 0, \
'invalid max_prefill_token_num'
assert self.num_tokens_per_iter >= 0, 'invalid num_tokens_per_iter'
assert self.async_ in (0, 1), 'async_ must be 0 (disabled) or 1 (enabled)'


@dataclass
@@ -442,6 +446,7 @@ class ResponseType(enum.Enum):
INTERNAL_ENGINE_ERROR = enum.auto()
CANCEL = enum.auto()
PREFIX_CACHE_CONFLICT_INTERACTIVE_MODE = enum.auto()
NO_QUEUE = enum.auto()


@dataclass
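For reference, the new `async_` field can be set through the engine config in the usual way. A minimal sketch using lmdeploy's public `pipeline`/`TurbomindEngineConfig` API (the model path is a placeholder, not taken from this PR):

```python
from lmdeploy import TurbomindEngineConfig, pipeline

# Sketch: disable the new async execution mode for the TurboMind backend.
backend_config = TurbomindEngineConfig(async_=0)

# '/path/to/model' is a placeholder; substitute any TurboMind-supported model.
pipe = pipeline('/path/to/model', backend_config=backend_config)
print(pipe('Hello'))
```
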
95 changes: 30 additions & 65 deletions lmdeploy/turbomind/turbomind.py
@@ -15,11 +15,9 @@
from queue import Queue
from typing import Any, Dict, List, Optional

import numpy as np
import pybase64
import torch
import yaml
from torch.nn.utils.rnn import pad_sequence

import lmdeploy
from lmdeploy.messages import EngineOutput, GenerationConfig, ResponseType, ScheduleMetrics, TurbomindEngineConfig
@@ -195,28 +193,22 @@ def _load_weights(self):
def _process_weights(self):
"""Process weight."""
with ThreadPoolExecutor(max_workers=self.gpu_count) as e:
ranks = [self.node_id * self.gpu_count + device_id for device_id in range(self.gpu_count)]
for _ in e.map(self.model_comm.process_weight, range(self.gpu_count), ranks):
for _ in e.map(self.model_comm.process_weight, range(self.gpu_count)):
pass

def _create_engine(self):
"""Create engine."""
with ThreadPoolExecutor(max_workers=self.gpu_count) as e:
ranks = [self.node_id * self.gpu_count + device_id for device_id in range(self.gpu_count)]
for _ in e.map(self.model_comm.create_engine, range(self.gpu_count), ranks):
for _ in e.map(self.model_comm.create_engine, range(self.gpu_count)):
pass
self._engine_created = True

def _create_weight(self, model_comm):
"""Allocate weight buffer, load params if from_workspace."""

engine_cfg = self.config_dict['engine_config']
self.node_id = engine_cfg['node_rank']

# create weight
def _create_weight_func(device_id):
rank = self.node_id * self.gpu_count + device_id
model_comm.create_shared_weights(device_id, rank)
model_comm.create_weights(device_id)

with ThreadPoolExecutor(max_workers=self.gpu_count) as executor:
futures = []
@@ -233,8 +225,7 @@ def _get_model_params(self):
tm_params.clear()

def _get_params(device_id, que):
rank = self.node_id * self.gpu_count + device_id
out = model_comm.get_params(device_id, rank)
out = model_comm.get_weights(device_id)
que.put(out)

que = Queue()
@@ -266,12 +257,6 @@ def _postprocess_config(self, tm_config: TurbomindModelConfig, engine_config: Tu
# update some attributes of `engine_config` which depends on
# `session_len`
self.engine_config = engine_config
if engine_config.max_prefill_token_num is not None \
and engine_config.num_tokens_per_iter == 0:
self.engine_config.num_tokens_per_iter = \
engine_config.max_prefill_token_num
self.engine_config.max_prefill_iters = (self.config.session_len + engine_config.max_prefill_token_num -
1) // engine_config.max_prefill_token_num

# pack `self.config` and `self.engine_config` into a dict
self.config_dict = self.config.to_dict()
@@ -290,9 +275,9 @@ def _from_hf(self, model_path: str, engine_config: TurbomindEngineConfig):

self._postprocess_config(tm_model.tm_config, engine_config)

model_comm = _tm.AbstractTransformerModel.create_llama_model(model_dir='',
config=yaml.safe_dump(self.config_dict),
weight_type=self.config.model_config.weight_type)
model_comm = _tm.TurboMind.create(model_dir='',
config=yaml.safe_dump(self.config_dict),
weight_type=self.config.model_config.weight_type)

# create empty weight
self._create_weight(model_comm)
@@ -311,8 +296,7 @@ def wakeup(self, tags: Optional[list[str]] = None):
if tags is None:
tags = ['weights', 'kv_cache']
with ThreadPoolExecutor(max_workers=self.gpu_count) as e:
ranks = [self.node_id * self.gpu_count + device_id for device_id in range(self.gpu_count)]
for _ in e.map(self.model_comm.wakeup, range(self.gpu_count), [tags] * self.gpu_count, ranks):
for _ in e.map(self.model_comm.wakeup, range(self.gpu_count), [tags] * self.gpu_count):
pass

def update_params(self, request: UpdateParamsRequest):
@@ -501,7 +485,7 @@ def _func(out: EngineOutput, step: int, **kwargs):
out.req_metrics = RequestMetrics(token_timestamp=time.time())
else:
events = [
EngineEvent(EventType.QUEUED, metrics.enque_time / 1000000),
EngineEvent(EventType.QUEUED, metrics.enqueue_time / 1000000),
EngineEvent(EventType.SCHEDULED, metrics.scheduled_time / 1000000),
]
out.req_metrics = RequestMetrics(token_timestamp=time.time(), engine_events=events)
@@ -547,7 +531,7 @@ def __init__(self, tm_model: TurboMind, config: TurbomindModelConfig, cuda_strea

# create model instances
lazy_init = self.tm_model.config_dict['engine_config'].get('empty_init', False)
self._model_inst = None if lazy_init else self._create_model_instance(0)
self._model_inst = None if lazy_init else self._create_model_instance()

self.config = config
self.lock = None
@@ -564,17 +548,18 @@ def __init__(self, tm_model: TurboMind, config: TurbomindModelConfig, cuda_strea
7: ResponseType.FINISH,
8: ResponseType.CANCEL,
9: ResponseType.PREFIX_CACHE_CONFLICT_INTERACTIVE_MODE,
10: ResponseType.NO_QUEUE,
-1: ResponseType.INTERNAL_ENGINE_ERROR,
}

@property
def model_inst(self):
if self._model_inst is None:
self._model_inst = self._create_model_instance(0)
self._model_inst = self._create_model_instance()
return self._model_inst

def _create_model_instance(self, device_id):
model_inst = self.tm_model.model_comm.create_model_instance(device_id)
def _create_model_instance(self):
model_inst = self.tm_model.model_comm.create_request()
return model_inst

def _get_extra_output_processors(self, outputs: Dict[str, torch.Tensor], gen_config: GenerationConfig,
@@ -598,47 +583,27 @@ def _get_offset(type):

def prepare_embeddings(self, input_embeddings=None, input_embedding_ranges=None):
"""Convert embeddings."""
if input_embeddings is None:
if not input_embeddings:
return None, None

assert isinstance(input_embeddings, List)
assert isinstance(input_embedding_ranges, List)
assert len(input_embeddings) == len(input_embedding_ranges)
if not isinstance(input_embeddings[0], (list, type(None))):
input_embeddings = [input_embeddings]
input_embedding_ranges = [input_embedding_ranges]

if all([isinstance(x, type(None)) for x in input_embeddings]):
return None, None
length = sum([x.shape[0] for x in input_embeddings])

_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16)
dtype = _MAP[self.tm_model.config.model_config.data_type]

values = torch.empty((length, input_embeddings[0].shape[-1]), dtype=dtype, device='cpu')
ranges = torch.tensor(input_embedding_ranges, dtype=torch.int32, device='cpu')

offset = 0
for embeds in input_embeddings:
values[offset:offset + embeds.shape[0]].copy_(embeds)
offset += embeds.shape[0]

hidden_dim = None
for embeddings in input_embeddings:
if embeddings is not None:
hidden_dim = embeddings[0].squeeze().shape[-1]
break
assert hidden_dim is not None

# construct input_embeddings
for i in range(len(input_embeddings)):
item = input_embeddings[i] or []
# convert to torch.Tensor if input is np.ndarray
if item and isinstance(item[0], np.ndarray):
item = [torch.from_numpy(x).squeeze() for x in item]
# convert to lookup table type
_MAP = dict(float=torch.float, bfloat16=torch.bfloat16, float16=torch.float16, fp8=torch.bfloat16)
dtype = _MAP.get(self.tm_model.config.weight_type, torch.float16)
item = [x.to(dtype=dtype) for x in item]
item = item or [torch.zeros(0, hidden_dim, dtype=dtype)]
input_embeddings[i] = item
input_embeddings = [torch.cat(x) for x in input_embeddings]
input_embeddings = pad_sequence(input_embeddings, batch_first=True)
input_embeddings = input_embeddings.reshape(input_embeddings.shape[0], -1).view(torch.int8)
# construct input_embedding_ranges
for i in range(len(input_embedding_ranges)):
item = input_embedding_ranges[i] or []
item = torch.IntTensor(item).reshape(-1, 2)
input_embedding_ranges[i] = item
input_embedding_ranges = pad_sequence(input_embedding_ranges, batch_first=True, padding_value=-1)

return input_embeddings, input_embedding_ranges
return values, ranges

def prepare_mrope(self, input_meta: Dict[str, Any], input_len: int):
mrope_position_ids = input_meta['mrope_position_ids']
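The rewritten `prepare_embeddings` above packs every embedding block into one contiguous `(total_rows, hidden)` CPU tensor plus an int32 range table, replacing the old per-request `pad_sequence` layout. A standalone sketch of the new packing, with illustrative shapes and dtype (not taken from the PR):

```python
import torch

# Sketch of the packing now done in prepare_embeddings: stack all embedding
# rows into one (total_rows, hidden) buffer and keep the token ranges they cover.
embeddings = [torch.randn(3, 4096, dtype=torch.bfloat16),
              torch.randn(2, 4096, dtype=torch.bfloat16)]
embedding_ranges = [[5, 8], [40, 42]]  # token positions covered by each block

total = sum(e.shape[0] for e in embeddings)
values = torch.empty((total, embeddings[0].shape[-1]), dtype=torch.bfloat16, device='cpu')
ranges = torch.tensor(embedding_ranges, dtype=torch.int32, device='cpu')

offset = 0
for e in embeddings:
    values[offset:offset + e.shape[0]].copy_(e)
    offset += e.shape[0]

assert values.shape == (5, 4096) and ranges.shape == (2, 2)
```
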
2 changes: 1 addition & 1 deletion lmdeploy/vl/model/base.py
@@ -293,7 +293,7 @@ def to_turbomind_aux(self, messages, prompt, IMAGE_TOKEN, tokenizer, sequence_st
# collect image features from messages
features = [x['content'] for x in messages if x['role'] == 'forward']
features = features[0]
features = [x.cpu().numpy() for x in features]
features = [x.cpu() for x in features]
# split prompt into segments and validate data
segs = prompt.split(IMAGE_TOKEN)
assert len(segs) == len(features) + 1, (f'the number of {IMAGE_TOKEN} is not equal '
21 changes: 16 additions & 5 deletions src/turbomind/CMakeLists.txt
@@ -15,14 +15,25 @@
add_subdirectory(utils)
add_subdirectory(core)
add_subdirectory(kernels)
add_subdirectory(layers)
add_subdirectory(comm)
add_subdirectory(generation)
add_subdirectory(models)
add_subdirectory(engine)
if(BUILD_PYT)
add_subdirectory(th_op)
endif()

if(BUILD_PY_FFI)
add_subdirectory(python)
endif()
add_subdirectory(triton_backend)

add_library(turbomind STATIC turbomind.cc)
set_property(TARGET turbomind PROPERTY POSITION_INDEPENDENT_CODE ON)
target_link_libraries(turbomind PUBLIC
engine
models
device_comm
host_comm
core
memory_utils
nvtx_utils
CUDA::cublasLt
CUDA::cudart
yaml-cpp::yaml-cpp)
3 changes: 2 additions & 1 deletion src/turbomind/core/CMakeLists.txt
@@ -11,7 +11,8 @@ add_library(core STATIC
layout.cc
tensor.cc
tensor.cu
module.cc)
module.cc
copy.cc)

target_link_libraries(core PUBLIC cuda_utils logger CUDA::cudart CUDA::cuda_driver)

1 change: 1 addition & 0 deletions src/turbomind/core/allocator.h
@@ -1,5 +1,6 @@
#pragma once

#include <algorithm>
#include <functional>

#include "src/turbomind/core/check.h"