Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions lightllm/distributed/pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import dataclasses
from datetime import timedelta
import pickle
from lightllm.utils.pickle_utils import safe_pickle_loads, safe_pickle_dumps
import time
from typing import Optional, Union, Dict, Deque, Tuple, Any
from collections import deque
Expand Down Expand Up @@ -85,7 +85,7 @@ def send_obj(self, obj: Any):
"""Send an object to a destination rank."""
self.expire_data()
key = f"send_to/{self.dest_id}/{self.send_dst_counter}"
self.store.set(key, pickle.dumps(obj))
self.store.set(key, safe_pickle_dumps(obj))
self.send_dst_counter += 1
self.entries.append((key, time.time()))

Expand All @@ -102,7 +102,7 @@ def expire_data(self):

def recv_obj(self) -> Any:
"""Receive an object from a source rank."""
obj = pickle.loads(self.store.get(f"send_to/{self.dest_id}/{self.recv_src_counter}"))
obj = safe_pickle_loads(self.store.get(f"send_to/{self.dest_id}/{self.recv_src_counter}"))
self.recv_src_counter += 1
return obj

Expand Down
6 changes: 3 additions & 3 deletions lightllm/server/api_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import base64
import os
from io import BytesIO
import pickle
from lightllm.utils.pickle_utils import safe_pickle_loads
import setproctitle

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
Expand Down Expand Up @@ -307,7 +307,7 @@ async def register_and_keep_alive(websocket: WebSocket):
while True:
# 等待接收消息,设置超时为10秒
data = await websocket.receive_bytes()
obj = pickle.loads(data)
obj = safe_pickle_loads(data)
await g_objs.httpserver_manager.put_to_handle_queue(obj)

except (WebSocketDisconnect, Exception, RuntimeError) as e:
Expand All @@ -328,7 +328,7 @@ async def kv_move_status(websocket: WebSocket):
while True:
# 等待接收消息,设置超时为10秒
data = await websocket.receive_bytes()
upkv_status = pickle.loads(data)
upkv_status = safe_pickle_loads(data)
logger.info(f"received upkv_status {upkv_status} from {(client_ip, client_port)}")
await g_objs.httpserver_manager.update_req_status(upkv_status)
except (WebSocketDisconnect, Exception, RuntimeError) as e:
Expand Down
6 changes: 3 additions & 3 deletions lightllm/server/config_server/api_http.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
import asyncio
import base64
import pickle
from lightllm.utils.pickle_utils import safe_pickle_loads, safe_pickle_dumps
import setproctitle
import multiprocessing as mp
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Query
Expand Down Expand Up @@ -53,7 +53,7 @@ async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
client_ip, client_port = websocket.client
logger.info(f"ws connected from IP: {client_ip}, Port: {client_port}")
registered_pd_master_obj: PD_Master_Obj = pickle.loads(await websocket.receive_bytes())
registered_pd_master_obj: PD_Master_Obj = safe_pickle_loads(await websocket.receive_bytes())
logger.info(f"received registered_pd_master_obj {registered_pd_master_obj}")
with registered_pd_master_obj_lock:
registered_pd_master_objs[registered_pd_master_obj.node_id] = registered_pd_master_obj
Expand All @@ -75,7 +75,7 @@ async def websocket_endpoint(websocket: WebSocket):
@app.get("/registered_objects")
async def get_registered_objects():
with registered_pd_master_obj_lock:
serialized_data = pickle.dumps(registered_pd_master_objs)
serialized_data = safe_pickle_dumps(registered_pd_master_objs)
base64_encoded = base64.b64encode(serialized_data).decode("utf-8")
return {"data": base64_encoded}

Expand Down
6 changes: 3 additions & 3 deletions lightllm/server/core/objs/shm_objs_io_buffer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
import pickle
from lightllm.utils.pickle_utils import safe_pickle_loads, safe_pickle_dumps
from lightllm.server.core.objs.atomic_lock import AtomicShmLock
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.envs_utils import get_unique_server_name
Expand Down Expand Up @@ -41,14 +41,14 @@ def is_ready(self):
return self.int_view[0] == self.node_world_size

def write_obj(self, obj):
obj_bytes = pickle.dumps(obj)
obj_bytes = safe_pickle_dumps(obj)
self.int_view[1] = len(obj_bytes)
self.shm.buf[8 : 8 + len(obj_bytes)] = obj_bytes
return

def read_obj(self):
bytes_len = self.int_view[1]
obj = pickle.loads(self.shm.buf[8 : 8 + bytes_len])
obj = safe_pickle_loads(self.shm.buf[8 : 8 + bytes_len])
return obj

def _create_or_link_shm(self):
Expand Down
4 changes: 2 additions & 2 deletions lightllm/server/httpserver/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ async def generate(
f"nixl prefill node upload group_req_id {group_request_id} prompt ids len : {len(prompt_ids)}"
)
await nixl_pd_upload_websocket.send(
pickle.dumps((ObjType.NIXL_UPLOAD_NP_PROMPT_IDS, group_request_id, prompt_ids))
safe_pickle_dumps((ObjType.NIXL_UPLOAD_NP_PROMPT_IDS, group_request_id, prompt_ids))
)
try:
await asyncio.wait_for(nixl_pd_event.wait(), timeout=80)
Expand All @@ -317,7 +317,7 @@ async def generate(
raise Exception(f"group_req_id {group_request_id} wait nixl_pd_event time out")

decode_node_info: NIXLDecodeNodeInfo = nixl_pd_event.decode_node_info
sampling_params.nixl_params.set(pickle.dumps(decode_node_info))
sampling_params.nixl_params.set(safe_pickle_dumps(decode_node_info))

if decode_node_info.ready_kv_len == len(prompt_ids) - 1:
# 如果 decode 节点的 ready_kv_len 和 prefill encode 的 len(prompt ids) -1 相等,说明不需要进行 prefill
Expand Down
8 changes: 4 additions & 4 deletions lightllm/server/httpserver/pd_loop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
import pickle
from lightllm.utils.pickle_utils import safe_pickle_loads, safe_pickle_dumps
import websockets
import ujson as json
import socket
Expand Down Expand Up @@ -103,7 +103,7 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O
# 接收 pd master 发来的请求,并推理后,将生成的token转发回pd master。
while True:
recv_bytes = await websocket.recv()
obj = pickle.loads(recv_bytes)
obj = safe_pickle_loads(recv_bytes)
if obj[0] == ObjType.REQ:
prompt, sampling_params, multimodal_params = obj[1]
group_req_id = sampling_params.group_request_id
Expand Down Expand Up @@ -183,7 +183,7 @@ async def _get_pd_master_objs(args: StartArgs) -> Optional[Dict[int, PD_Master_O
response = await client.get(uri)
if response.status_code == 200:
base64data = response.json()["data"]
id_to_pd_master_obj = pickle.loads(base64.b64decode(base64data))
id_to_pd_master_obj = safe_pickle_loads(base64.b64decode(base64data))
return id_to_pd_master_obj
else:
logger.error(f"get pd_master_objs error {response.status_code}")
Expand Down Expand Up @@ -231,7 +231,7 @@ async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket: Clien

if handle_list:
load_info: dict = _get_load_info()
await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, load_info)))
await websocket.send(safe_pickle_dumps((ObjType.TOKEN_PACKS, handle_list, load_info)))


# 获取节点负载信息
Expand Down
18 changes: 9 additions & 9 deletions lightllm/server/httpserver_for_pd_master/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
import datetime
import ujson as json
import pickle
from lightllm.utils.pickle_utils import safe_pickle_loads, safe_pickle_dumps

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
from typing import Union, List, Tuple, Dict, Optional
Expand Down Expand Up @@ -173,7 +173,7 @@ async def fetch_stream(
sampling_params.move_kv_to_decode_node.initialize(decode_node_dict if old_max_new_tokens != 1 else None)
sampling_params.suggested_dp_index = -1

await p_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params))))
await p_node.websocket.send_bytes(safe_pickle_dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params))))

while True:
await req_status.wait_to_ready()
Expand Down Expand Up @@ -210,7 +210,7 @@ async def fetch_stream(
sampling_params.suggested_dp_index = upkv_status.dp_index

await d_node.websocket.send_bytes(
pickle.dumps((ObjType.REQ, (prompt_ids, sampling_params, MultimodalParams())))
safe_pickle_dumps((ObjType.REQ, (prompt_ids, sampling_params, MultimodalParams())))
)

while True:
Expand Down Expand Up @@ -244,7 +244,7 @@ async def fetch_nixl_stream(

old_max_new_tokens = sampling_params.max_new_tokens
sampling_params.max_new_tokens = 1
await p_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params))))
await p_node.websocket.send_bytes(safe_pickle_dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params))))

try:
await asyncio.wait_for(nixl_np_up_prompt_ids_event.wait(), timeout=60)
Expand All @@ -260,7 +260,7 @@ async def fetch_nixl_stream(

sampling_params.max_new_tokens = old_max_new_tokens
await d_node.websocket.send_bytes(
pickle.dumps((ObjType.REQ, (prompt_ids, sampling_params, MultimodalParams())))
safe_pickle_dumps((ObjType.REQ, (prompt_ids, sampling_params, MultimodalParams())))
)

try:
Expand All @@ -272,9 +272,9 @@ async def fetch_nixl_stream(
# 将 decode 节点上报的当前请求使用的decode节点的信息下发给 p 节点,这样 p 节点才知道将 kv 传输给那个 d 节点。
upkv_status: NixlUpKVStatus = up_status_event.upkv_status
nixl_params: bytes = upkv_status.nixl_params
decode_node_info: NIXLDecodeNodeInfo = pickle.loads(nixl_params)
decode_node_info: NIXLDecodeNodeInfo = safe_pickle_loads(nixl_params)
await p_node.websocket.send_bytes(
pickle.dumps((ObjType.NIXL_REQ_DECODE_NODE_INFO, group_request_id, decode_node_info))
safe_pickle_dumps((ObjType.NIXL_REQ_DECODE_NODE_INFO, group_request_id, decode_node_info))
)

first_token_gen = False
Expand Down Expand Up @@ -392,12 +392,12 @@ async def abort(
pass

try:
await p_node.websocket.send_bytes(pickle.dumps((ObjType.ABORT, group_request_id)))
await p_node.websocket.send_bytes(safe_pickle_dumps((ObjType.ABORT, group_request_id)))
except:
pass

try:
await d_node.websocket.send_bytes(pickle.dumps((ObjType.ABORT, group_request_id)))
await d_node.websocket.send_bytes(safe_pickle_dumps((ObjType.ABORT, group_request_id)))
except:
pass

Expand Down
4 changes: 2 additions & 2 deletions lightllm/server/httpserver_for_pd_master/register_loop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
import pickle
from lightllm.utils.pickle_utils import safe_pickle_dumps
import websockets
import socket
from lightllm.utils.net_utils import get_hostname_ip
Expand Down Expand Up @@ -31,7 +31,7 @@ async def register_loop(manager: HttpServerManagerForPDMaster):
node_id=manager.args.pd_node_id, host_ip_port=f"{manager.host_ip}:{manager.args.port}"
)

await websocket.send(pickle.dumps(pd_master_obj))
await websocket.send(safe_pickle_dumps(pd_master_obj))
logger.info(f"Sent registration pd_master obj: {pd_master_obj}")

while True:
Expand Down
4 changes: 2 additions & 2 deletions lightllm/server/router/model_infer/infer_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.distributed as dist
import numpy as np
import collections
import pickle
from lightllm.utils.pickle_utils import safe_pickle_loads

from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional, Callable, Any
Expand Down Expand Up @@ -283,7 +283,7 @@ def __init__(

# nixl decode node information
if self.shm_param.nixl_params.data_len > 0:
self.nixl_decode_node: NIXLDecodeNodeInfo = pickle.loads(self.shm_param.nixl_params.get())
self.nixl_decode_node: NIXLDecodeNodeInfo = safe_pickle_loads(self.shm_param.nixl_params.get())
else:
self.nixl_decode_node: NIXLDecodeNodeInfo = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import websockets
import inspect
import setproctitle
import pickle
from lightllm.utils.pickle_utils import safe_pickle_loads, safe_pickle_dumps

from typing import Dict
from dataclasses import asdict
Expand Down Expand Up @@ -91,7 +91,7 @@ async def up_kv_status_task(self, pd_master_obj: PD_Master_Obj):
if pd_master_obj.node_id in self.id_to_handle_queue:
task_queue = self.id_to_handle_queue[pd_master_obj.node_id]
upkv_status: UpKVStatus = await task_queue.get()
await websocket.send(pickle.dumps(upkv_status))
await websocket.send(safe_pickle_dumps(upkv_status))
logger.info(f"up status: {upkv_status}")
else:
await asyncio.sleep(3)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.multiprocessing as mp
import collections
import queue
import pickle
from lightllm.utils.pickle_utils import safe_pickle_loads, safe_pickle_dumps
from typing import List, Dict, Union, Deque, Optional
from lightllm.utils.log_utils import init_logger
from lightllm.common.kv_cache_mem_manager import MemoryManager
Expand Down Expand Up @@ -193,7 +193,7 @@ def dispatch_task_loop(self):
up_status = NixlUpKVStatus(
group_request_id=task.request_id,
pd_master_node_id=task.pd_master_node_id,
nixl_params=pickle.dumps(decode_node_info),
nixl_params=safe_pickle_dumps(decode_node_info),
)

self.up_status_in_queue.put(up_status)
Expand All @@ -220,7 +220,7 @@ def accept_peer_task_loop(
for remote_agent_name, _notify_list in notifies_dict.items():
for notify in _notify_list:
try:
notify_obj = pickle.loads(notify)
notify_obj = safe_pickle_loads(notify)
except:
notify_obj = None

Expand All @@ -236,7 +236,7 @@ def accept_peer_task_loop(
try:
self.transporter.send_notify_to_prefill_node(
prefill_agent_name=remote_agent_name,
notify=pickle.dumps(remote_trans_task.createRetObj()),
notify=safe_pickle_dumps(remote_trans_task.createRetObj()),
)
except BaseException as e:
logger.error(f"send notify to prefill node failed: {str(e)}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import threading
import websockets
import inspect
import pickle
from lightllm.utils.pickle_utils import safe_pickle_dumps

from typing import Dict, Union
from dataclasses import asdict
Expand Down Expand Up @@ -88,7 +88,7 @@ async def up_kv_status_task(self, pd_master_obj: PD_Master_Obj):
if pd_master_obj.node_id in self.id_to_handle_queue:
task_queue = self.id_to_handle_queue[pd_master_obj.node_id]
upkv_status: Union[UpKVStatus, NixlUpKVStatus] = await task_queue.get()
await websocket.send(pickle.dumps(upkv_status))
await websocket.send(safe_pickle_dumps(upkv_status))
logger.info(f"up kv status: {upkv_status}")
else:
await asyncio.sleep(3)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import pickle
from lightllm.utils.pickle_utils import safe_pickle_dumps
import copy
from dataclasses import dataclass
from collections import defaultdict
Expand Down Expand Up @@ -125,7 +125,7 @@ def send_readtask_to_decode_node(self, trans_task: NIXLChunckedTransTask):
new_trans_task.mem_indexes = None
self.nixl_agent.send_notif(
remote_agent.agent_name,
pickle.dumps(new_trans_task),
safe_pickle_dumps(new_trans_task),
)
return

Expand Down Expand Up @@ -165,7 +165,7 @@ def read_blocks_paged(
[trans_task.nixl_dst_page_index],
src_handle,
[trans_task.nixl_src_page_index],
pickle.dumps(notify_obj),
safe_pickle_dumps(notify_obj),
)
if not handle:
raise RuntimeError(f"make_prepped_xfer failed for task: {trans_task.to_str()}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.multiprocessing as mp
import collections
import queue
import pickle
from lightllm.utils.pickle_utils import safe_pickle_loads
from typing import List, Dict, Union, Deque, Optional
from lightllm.utils.log_utils import init_logger
from lightllm.common.kv_cache_mem_manager import MemoryManager
Expand Down Expand Up @@ -211,7 +211,7 @@ def update_task_status_loop(
for _, _notify_list in notifies_dict.items():
for notify in _notify_list:
try:
notify_obj = pickle.loads(notify)
notify_obj = safe_pickle_loads(notify)
except:
notify_obj = None

Expand Down
Loading