Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions lightllm/distributed/pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from typing import Optional, Union, Dict, Deque, Tuple, Any
from collections import deque
import logging
import ujson as json

# ===================== import region =====================
import torch
Expand Down Expand Up @@ -85,7 +86,7 @@ def send_obj(self, obj: Any):
"""Send an object to a destination rank."""
self.expire_data()
key = f"send_to/{self.dest_id}/{self.send_dst_counter}"
self.store.set(key, pickle.dumps(obj))
self.store.set(key, json.dumps(obj).encode())
self.send_dst_counter += 1
self.entries.append((key, time.time()))

Expand All @@ -102,7 +103,7 @@ def expire_data(self):

def recv_obj(self) -> Any:
"""Receive an object from a source rank."""
obj = pickle.loads(self.store.get(f"send_to/{self.dest_id}/{self.recv_src_counter}"))
obj = json.loads(self.store.get(f"send_to/{self.dest_id}/{self.recv_src_counter}").decode())
self.recv_src_counter += 1
return obj

Expand Down
4 changes: 2 additions & 2 deletions lightllm/server/api_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ async def register_and_keep_alive(websocket: WebSocket):
while True:
# 等待接收消息,设置超时为10秒
data = await websocket.receive_bytes()
obj = pickle.loads(data)
obj = json.loads(data.decode())
await g_objs.httpserver_manager.put_to_handle_queue(obj)

except (WebSocketDisconnect, Exception, RuntimeError) as e:
Expand All @@ -328,7 +328,7 @@ async def kv_move_status(websocket: WebSocket):
while True:
# 等待接收消息,设置超时为10秒
data = await websocket.receive_bytes()
upkv_status = pickle.loads(data)
upkv_status = json.loads(data.decode())
logger.info(f"received upkv_status {upkv_status} from {(client_ip, client_port)}")
await g_objs.httpserver_manager.update_req_status(upkv_status)
except (WebSocketDisconnect, Exception, RuntimeError) as e:
Expand Down
5 changes: 3 additions & 2 deletions lightllm/server/config_server/api_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import base64
import pickle
import setproctitle
import ujson as json
import multiprocessing as mp
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Query
from threading import Lock
Expand Down Expand Up @@ -53,7 +54,7 @@ async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
client_ip, client_port = websocket.client
logger.info(f"ws connected from IP: {client_ip}, Port: {client_port}")
registered_pd_master_obj: PD_Master_Obj = pickle.loads(await websocket.receive_bytes())
registered_pd_master_obj: PD_Master_Obj = json.loads((await websocket.receive_bytes()).decode())
logger.info(f"received registered_pd_master_obj {registered_pd_master_obj}")
with registered_pd_master_obj_lock:
registered_pd_master_objs[registered_pd_master_obj.node_id] = registered_pd_master_obj
Expand All @@ -75,7 +76,7 @@ async def websocket_endpoint(websocket: WebSocket):
@app.get("/registered_objects")
async def get_registered_objects():
with registered_pd_master_obj_lock:
serialized_data = pickle.dumps(registered_pd_master_objs)
serialized_data = json.dumps(registered_pd_master_objs).encode()
base64_encoded = base64.b64encode(serialized_data).decode("utf-8")
return {"data": base64_encoded}

Expand Down
5 changes: 3 additions & 2 deletions lightllm/server/httpserver/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import hashlib
import datetime
import pickle
import ujson as json
from frozendict import frozendict

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
Expand Down Expand Up @@ -308,7 +309,7 @@ async def generate(
f"nixl prefill node upload group_req_id {group_request_id} prompt ids len : {len(prompt_ids)}"
)
await nixl_pd_upload_websocket.send(
pickle.dumps((ObjType.NIXL_UPLOAD_NP_PROMPT_IDS, group_request_id, prompt_ids))
json.dumps((ObjType.NIXL_UPLOAD_NP_PROMPT_IDS, group_request_id, prompt_ids)).encode()
)
try:
await asyncio.wait_for(nixl_pd_event.wait(), timeout=80)
Expand All @@ -317,7 +318,7 @@ async def generate(
raise Exception(f"group_req_id {group_request_id} wait nixl_pd_event time out")

decode_node_info: NIXLDecodeNodeInfo = nixl_pd_event.decode_node_info
sampling_params.nixl_params.set(pickle.dumps(decode_node_info))
sampling_params.nixl_params.set(json.dumps(decode_node_info).encode())

if decode_node_info.ready_kv_len == len(prompt_ids) - 1:
# 如果 decode 节点的 ready_kv_len 和 prefill encode 的 len(prompt ids) -1 相等,说明不需要进行 prefill
Expand Down
6 changes: 3 additions & 3 deletions lightllm/server/httpserver/pd_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O
# 接收 pd master 发来的请求,并推理后,将生成的token转发回pd master。
while True:
recv_bytes = await websocket.recv()
obj = pickle.loads(recv_bytes)
obj = json.loads(recv_bytes.decode())
if obj[0] == ObjType.REQ:
prompt, sampling_params, multimodal_params = obj[1]
group_req_id = sampling_params.group_request_id
Expand Down Expand Up @@ -183,7 +183,7 @@ async def _get_pd_master_objs(args: StartArgs) -> Optional[Dict[int, PD_Master_O
response = await client.get(uri)
if response.status_code == 200:
base64data = response.json()["data"]
id_to_pd_master_obj = pickle.loads(base64.b64decode(base64data))
id_to_pd_master_obj = json.loads(base64.b64decode(base64data).decode())
return id_to_pd_master_obj
else:
logger.error(f"get pd_master_objs error {response.status_code}")
Expand Down Expand Up @@ -231,7 +231,7 @@ async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket: Clien

if handle_list:
load_info: dict = _get_load_info()
await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, load_info)))
await websocket.send(json.dumps((ObjType.TOKEN_PACKS, handle_list, load_info)).encode())


# 获取节点负载信息
Expand Down
20 changes: 12 additions & 8 deletions lightllm/server/httpserver_for_pd_master/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ async def fetch_stream(
sampling_params.move_kv_to_decode_node.initialize(decode_node_dict if old_max_new_tokens != 1 else None)
sampling_params.suggested_dp_index = -1

await p_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params))))
await p_node.websocket.send_bytes(
json.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params))).encode()
)

while True:
await req_status.wait_to_ready()
Expand Down Expand Up @@ -210,7 +212,7 @@ async def fetch_stream(
sampling_params.suggested_dp_index = upkv_status.dp_index

await d_node.websocket.send_bytes(
pickle.dumps((ObjType.REQ, (prompt_ids, sampling_params, MultimodalParams())))
json.dumps((ObjType.REQ, (prompt_ids, sampling_params, MultimodalParams()))).encode()
)

while True:
Expand Down Expand Up @@ -244,7 +246,9 @@ async def fetch_nixl_stream(

old_max_new_tokens = sampling_params.max_new_tokens
sampling_params.max_new_tokens = 1
await p_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params))))
await p_node.websocket.send_bytes(
json.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params))).encode()
)

try:
await asyncio.wait_for(nixl_np_up_prompt_ids_event.wait(), timeout=60)
Expand All @@ -260,7 +264,7 @@ async def fetch_nixl_stream(

sampling_params.max_new_tokens = old_max_new_tokens
await d_node.websocket.send_bytes(
pickle.dumps((ObjType.REQ, (prompt_ids, sampling_params, MultimodalParams())))
json.dumps((ObjType.REQ, (prompt_ids, sampling_params, MultimodalParams()))).encode()
)

try:
Expand All @@ -272,9 +276,9 @@ async def fetch_nixl_stream(
# 将 decode 节点上报的当前请求使用的decode节点的信息下发给 p 节点,这样 p 节点才知道将 kv 传输给那个 d 节点。
upkv_status: NixlUpKVStatus = up_status_event.upkv_status
nixl_params: bytes = upkv_status.nixl_params
decode_node_info: NIXLDecodeNodeInfo = pickle.loads(nixl_params)
decode_node_info: NIXLDecodeNodeInfo = json.loads(nixl_params.decode())
await p_node.websocket.send_bytes(
pickle.dumps((ObjType.NIXL_REQ_DECODE_NODE_INFO, group_request_id, decode_node_info))
json.dumps((ObjType.NIXL_REQ_DECODE_NODE_INFO, group_request_id, decode_node_info)).encode()
)

first_token_gen = False
Expand Down Expand Up @@ -392,12 +396,12 @@ async def abort(
pass

try:
await p_node.websocket.send_bytes(pickle.dumps((ObjType.ABORT, group_request_id)))
await p_node.websocket.send_bytes(json.dumps((ObjType.ABORT, group_request_id)).encode())
except:
pass

try:
await d_node.websocket.send_bytes(pickle.dumps((ObjType.ABORT, group_request_id)))
await d_node.websocket.send_bytes(json.dumps((ObjType.ABORT, group_request_id)).encode())
except:
pass

Expand Down
4 changes: 2 additions & 2 deletions lightllm/server/httpserver_for_pd_master/register_loop.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import pickle
import websockets
import socket
import ujson as json
from lightllm.utils.net_utils import get_hostname_ip
from lightllm.utils.log_utils import init_logger
from lightllm.server.httpserver_for_pd_master.manager import HttpServerManagerForPDMaster
Expand Down Expand Up @@ -31,7 +31,7 @@ async def register_loop(manager: HttpServerManagerForPDMaster):
node_id=manager.args.pd_node_id, host_ip_port=f"{manager.host_ip}:{manager.args.port}"
)

await websocket.send(pickle.dumps(pd_master_obj))
await websocket.send(json.dumps(pd_master_obj).encode())
logger.info(f"Sent registration pd_master obj: {pd_master_obj}")

while True:
Expand Down
4 changes: 2 additions & 2 deletions lightllm/server/router/model_infer/infer_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.distributed as dist
import numpy as np
import collections
import pickle
import ujson as json

from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional, Callable, Any
Expand Down Expand Up @@ -283,7 +283,7 @@ def __init__(

# nixl decode node information
if self.shm_param.nixl_params.data_len > 0:
self.nixl_decode_node: NIXLDecodeNodeInfo = pickle.loads(self.shm_param.nixl_params.get())
self.nixl_decode_node: NIXLDecodeNodeInfo = json.loads(self.shm_param.nixl_params.get().decode())
else:
self.nixl_decode_node: NIXLDecodeNodeInfo = None

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
import json
import ujson as json
import asyncio
import threading
import websockets
Expand Down Expand Up @@ -91,7 +91,7 @@ async def up_kv_status_task(self, pd_master_obj: PD_Master_Obj):
if pd_master_obj.node_id in self.id_to_handle_queue:
task_queue = self.id_to_handle_queue[pd_master_obj.node_id]
upkv_status: UpKVStatus = await task_queue.get()
await websocket.send(pickle.dumps(upkv_status))
await websocket.send(json.dumps(upkv_status).encode())
logger.info(f"up status: {upkv_status}")
else:
await asyncio.sleep(3)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.multiprocessing as mp
import collections
import queue
import pickle
import ujson as json
from typing import List, Dict, Union, Deque, Optional
from lightllm.utils.log_utils import init_logger
from lightllm.common.kv_cache_mem_manager import MemoryManager
Expand Down Expand Up @@ -193,7 +193,7 @@ def dispatch_task_loop(self):
up_status = NixlUpKVStatus(
group_request_id=task.request_id,
pd_master_node_id=task.pd_master_node_id,
nixl_params=pickle.dumps(decode_node_info),
nixl_params=json.dumps(decode_node_info).encode(),
)

self.up_status_in_queue.put(up_status)
Expand All @@ -220,7 +220,7 @@ def accept_peer_task_loop(
for remote_agent_name, _notify_list in notifies_dict.items():
for notify in _notify_list:
try:
notify_obj = pickle.loads(notify)
notify_obj = json.loads(notify.decode())
except:
notify_obj = None

Expand All @@ -236,7 +236,7 @@ def accept_peer_task_loop(
try:
self.transporter.send_notify_to_prefill_node(
prefill_agent_name=remote_agent_name,
notify=pickle.dumps(remote_trans_task.createRetObj()),
notify=json.dumps(remote_trans_task.createRetObj()).encode(),
)
except BaseException as e:
logger.error(f"send notify to prefill node failed: {str(e)}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async def up_kv_status_task(self, pd_master_obj: PD_Master_Obj):
if pd_master_obj.node_id in self.id_to_handle_queue:
task_queue = self.id_to_handle_queue[pd_master_obj.node_id]
upkv_status: Union[UpKVStatus, NixlUpKVStatus] = await task_queue.get()
await websocket.send(pickle.dumps(upkv_status))
await websocket.send(json.dumps(upkv_status).encode())
logger.info(f"up kv status: {upkv_status}")
else:
await asyncio.sleep(3)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pickle
import copy
import ujson as json
from dataclasses import dataclass
from collections import defaultdict
from typing import Dict, List, Any, Optional, Tuple
Expand Down Expand Up @@ -125,7 +126,7 @@ def send_readtask_to_decode_node(self, trans_task: NIXLChunckedTransTask):
new_trans_task.mem_indexes = None
self.nixl_agent.send_notif(
remote_agent.agent_name,
pickle.dumps(new_trans_task),
json.dumps(new_trans_task).encode(),
)
return

Expand Down Expand Up @@ -165,7 +166,7 @@ def read_blocks_paged(
[trans_task.nixl_dst_page_index],
src_handle,
[trans_task.nixl_src_page_index],
pickle.dumps(notify_obj),
json.dumps(notify_obj).encode(),
)
if not handle:
raise RuntimeError(f"make_prepped_xfer failed for task: {trans_task.to_str()}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.multiprocessing as mp
import collections
import queue
import pickle
import ujson as json
from typing import List, Dict, Union, Deque, Optional
from lightllm.utils.log_utils import init_logger
from lightllm.common.kv_cache_mem_manager import MemoryManager
Expand Down Expand Up @@ -211,7 +211,7 @@ def update_task_status_loop(
for _, _notify_list in notifies_dict.items():
for notify in _notify_list:
try:
notify_obj = pickle.loads(notify)
notify_obj = json.loads(notify.decode())
except:
notify_obj = None

Expand Down