Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions custom_ops/xpu_ops/src/ops/swap_cache_layout.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <xpu/runtime.h>
#include "paddle/extension.h"

/*
 * XPU KV cache layout : layer_num * [block_num, head_num, block_size, head_dim]
 * CPU pinned buf layout: [block_num, layer_num, head_num, block_size, head_dim]
 *
 * mode 0 : XPU -> CPU
 * mode 1 : CPU -> XPU
 *
 * Address computation:
 *   cache_block_stride = head_num * block_size * head_dim
 *                      (= cache_shape[1] * cache_shape[2] * cache_shape[3])
 *   XPU ptr = tensor[layer_idx].data() + xpu_block_id * cache_block_stride
 *   CPU ptr = cpu_base_ptr
 *           + cpu_block_id * cache_block_stride * layer_number  // block dim
 *           + layer_idx * cache_block_stride                    // layer dim
 */

// Copy the selected KV-cache blocks between the per-layer XPU tensors and a
// pinned CPU staging buffer (see the layout comment above this function).
//
// @param cache_xpu_tensors one cache tensor per layer, each
//                          [block_num, head_num, block_size, head_dim]
// @param cache_cpu_pointer raw address of the pinned host buffer
// @param cache_shape       XPU tensor shape; only dims [1:] are used here
// @param xpu_block_ids     device-side block indices, paired with cpu_block_ids
// @param cpu_block_ids     host-side block indices
// @param mode              0: XPU -> CPU, 1: CPU -> XPU
template <typename T>
void SwapCacheImpLayout(const std::vector<paddle::Tensor>& cache_xpu_tensors,
                        const int64_t& cache_cpu_pointer,
                        const std::vector<int64_t>& cache_shape,
                        const std::vector<int64_t>& xpu_block_ids,
                        const std::vector<int64_t>& cpu_block_ids,
                        int mode) {
  const int64_t layer_number = static_cast<int64_t>(cache_xpu_tensors.size());

  // The two id lists are consumed pairwise; a size mismatch would read
  // out of bounds below.
  PD_CHECK(xpu_block_ids.size() == cpu_block_ids.size(),
           "xpu_block_ids and cpu_block_ids must have the same length");

  // Elements per block = product(cache_shape[1:]).
  int64_t cache_block_stride = 1;
  for (size_t i = 1; i < cache_shape.size(); i++) {
    cache_block_stride *= cache_shape[i];
  }

  const XPUMemcpyKind copy_kind =
      (mode == 0) ? XPU_DEVICE_TO_HOST : XPU_HOST_TO_DEVICE;

  // Host base pointer is invariant across layers; compute it once.
  auto* cache_cpu_ptr = reinterpret_cast<T*>(cache_cpu_pointer);

  for (int64_t layer_idx = 0; layer_idx < layer_number; layer_idx++) {
    const paddle::Tensor& cache_xpu = cache_xpu_tensors[layer_idx];
    // const_cast is required because mode 1 writes into the tensor in place.
    T* cache_xpu_ptr = const_cast<T*>(cache_xpu.data<T>());

    for (size_t block_idx = 0; block_idx < xpu_block_ids.size(); block_idx++) {
      const int64_t cur_xpu_block_id = xpu_block_ids[block_idx];
      const int64_t cur_cpu_block_id = cpu_block_ids[block_idx];

      T* xpu_ptr_now = cache_xpu_ptr + cur_xpu_block_id * cache_block_stride;
      // CPU buffer is block-major: one block spans all layers contiguously.
      T* cpu_ptr_now = cache_cpu_ptr +
                       cur_cpu_block_id * cache_block_stride * layer_number +
                       layer_idx * cache_block_stride;

      void* dst = (mode == 0) ? static_cast<void*>(cpu_ptr_now)
                              : static_cast<void*>(xpu_ptr_now);
      void* src = (mode == 0) ? static_cast<void*>(xpu_ptr_now)
                              : static_cast<void*>(cpu_ptr_now);

      // NOTE(review): one synchronous memcpy per (layer, block) pair. For
      // many layers/blocks, consider async/batched submission if this
      // becomes a throughput bottleneck.
      int ret = xpu_memcpy(dst, src, cache_block_stride * sizeof(T), copy_kind);
      PD_CHECK(
          ret == XPU_SUCCESS, "xpu_memcpy failed with error code: %d", ret);
    }
  }
}

// Entry point for the swap_cache_layout custom op: selects the element type
// from the first cache tensor's dtype and forwards to the typed
// implementation.
//
// `gpu_block_ids` carries XPU-side block ids; the parameter name is kept to
// mirror the GPU version of this op.
void SwapCacheLayout(const std::vector<paddle::Tensor>& cache_xpu_tensors,
                     const int64_t& cache_cpu_ptrs,
                     const std::vector<int64_t>& cache_shape,
                     const std::vector<int64_t>& gpu_block_ids,
                     const std::vector<int64_t>& cpu_block_ids,
                     int rank,
                     int mode) {
  xpu_set_device(rank);  // used for distributed launch
  PD_CHECK(cache_xpu_tensors.size() > 0, "cache_xpu_tensors must not be empty");

  // Typed dispatch helper: the tag value only carries the element type.
  const auto dispatch = [&](auto type_tag) {
    using ElemT = decltype(type_tag);
    SwapCacheImpLayout<ElemT>(cache_xpu_tensors,
                              cache_cpu_ptrs,
                              cache_shape,
                              gpu_block_ids,
                              cpu_block_ids,
                              mode);
  };

  switch (cache_xpu_tensors[0].dtype()) {
    case paddle::DataType::FLOAT16:
      dispatch(paddle::float16{});
      break;
    case paddle::DataType::BFLOAT16:
      dispatch(paddle::bfloat16{});
      break;
    case paddle::DataType::UINT8:
      dispatch(uint8_t{});
      break;
    case paddle::DataType::INT8:
      dispatch(int8_t{});
      break;
    default:
      PD_THROW("Unsupported data type.");
  }
}

// Register the custom op with Paddle.
// The XPU cache tensors are mapped inplace (inputs == outputs) because
// mode 1 (CPU -> XPU) writes into them directly via raw pointers.
// "gpu_block_ids" holds XPU-side block ids; the attribute name matches the
// GPU version of this op so callers can share code.
PD_BUILD_OP(swap_cache_layout)
    .Inputs({paddle::Vec("cache_xpu_tensors")})
    .Attrs({
        "cache_cpu_ptrs: int64_t",
        "cache_shape: std::vector<int64_t>",
        "gpu_block_ids: std::vector<int64_t>",
        "cpu_block_ids: std::vector<int64_t>",
        "rank: int",
        "mode: int",
    })
    .Outputs({paddle::Vec("cache_dst_outs")})
    .SetInplaceMap({{paddle::Vec("cache_xpu_tensors"),
                     paddle::Vec("cache_dst_outs")}})
    .SetKernelFn(PD_KERNEL(SwapCacheLayout));
138 changes: 138 additions & 0 deletions custom_ops/xpu_ops/test/test_swap_cache_layout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import ctypes
import random
import time
import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.xpu import (
cuda_host_alloc,
cuda_host_free,
swap_cache_layout,
)


class TestAllocCachePinned(unittest.TestCase):
    """Exercise pinned host memory: allocate, write/read raw bytes, free."""

    def test_alloc_free(self):
        num_bytes = 16 * 1024 * 1024
        host_ptr = cuda_host_alloc(num_bytes)
        self.assertNotEqual(host_ptr, 0, "cuda_host_alloc returned null")

        try:
            # View the first four bytes of the pinned buffer through ctypes.
            window = (ctypes.c_uint8 * 4).from_address(host_ptr)
            pattern = [0xDE, 0xAD, 0xBE, 0xEF]
            for offset, byte in enumerate(pattern):
                window[offset] = byte
            self.assertEqual(list(window), pattern)
        finally:
            cuda_host_free(host_ptr)


class TestSwapCacheLayout(unittest.TestCase):
    """Round-trip correctness and throughput tests for swap_cache_layout.

    Cache geometry: layer_num XPU tensors, each shaped
    [block_num, head_num, block_size, head_dim], swapped against a single
    pinned CPU buffer laid out block-major.
    """

    layer_num = 8
    block_num = 128
    head_num = 4
    block_size = 16
    head_dim = 64
    swap_block_num = 16

    def setUp(self):
        self.cache_shape = [self.block_num, self.head_num, self.block_size, self.head_dim]
        self.block_stride = self.head_num * self.block_size * self.head_dim
        # Caches are float16 below, so each element is 2 bytes.
        self.block_bytes = self.block_stride * 2

        # Pinned host buffer sized for [swap_block_num, layer_num, block].
        buffer_total_bytes = self.swap_block_num * self.layer_num * self.block_bytes
        self.cpu_buffer = cuda_host_alloc(buffer_total_bytes)

        # Random distinct device blocks mapped onto dense host slots.
        self.xpu_block_ids = random.sample(range(self.block_num), self.swap_block_num)
        self.cpu_block_ids = list(range(self.swap_block_num))

    def tearDown(self):
        cuda_host_free(self.cpu_buffer)

    def _make_cache(self, fill_value=None):
        """Build per-layer fp16 caches.

        By default each layer is filled with its own layer index so a
        round-trip can verify that data lands in the correct layer slot.
        """
        cache = []
        for layer_idx in range(self.layer_num):
            value = float(layer_idx) if fill_value is None else float(fill_value)
            cache.append(paddle.full(self.cache_shape, fill_value=value, dtype=paddle.float16))
        paddle.device.synchronize()
        return cache

    def test_roundtrip(self):
        # Swap out of `src` (mode 0), then back into `dst` (mode 1); the
        # swapped blocks of dst must match src's per-layer fill values.
        src = self._make_cache()
        dst = self._make_cache(fill_value=-1)

        swap_cache_layout(
            src,
            self.cpu_buffer,
            self.cache_shape,
            self.xpu_block_ids,
            self.cpu_block_ids,
            0,
            0,
        )
        swap_cache_layout(
            dst,
            self.cpu_buffer,
            self.cache_shape,
            self.xpu_block_ids,
            self.cpu_block_ids,
            0,
            1,
        )

        for layer_idx in range(self.layer_num):
            got = dst[layer_idx][self.xpu_block_ids].numpy()
            expected = np.full_like(got, float(layer_idx))
            self.assertTrue(
                np.allclose(got, expected, atol=1e-2),
                f"roundtrip mismatch at layer={layer_idx}",
            )

    def _run_and_report(self, mode, label):
        """Time one swap in the given direction and print the bandwidth."""
        cache = self._make_cache()
        total_gb = self.swap_block_num * self.layer_num * self.block_bytes / 1073741824

        start = time.time()
        swap_cache_layout(
            cache,
            self.cpu_buffer,
            self.cache_shape,
            self.xpu_block_ids,
            self.cpu_block_ids,
            0,
            mode,
        )
        paddle.device.synchronize()
        cost_time = time.time() - start
        print(
            f"swap cache layout ({label}), total_gb: {total_gb:.6f}GB, "
            f"cost_time: {cost_time:.6f}s, speed: {total_gb / cost_time:.6f}GB/s"
        )

    def test_performance(self):
        # Repeat each direction a few times; the first run warms up the path.
        for _ in range(3):
            self._run_and_report(0, "device to host")
        for _ in range(3):
            self._run_and_report(1, "host to device")


# Allow running this file directly; verbosity=2 prints each test name.
if __name__ == "__main__":
    unittest.main(verbosity=2)
2 changes: 1 addition & 1 deletion fastdeploy/cache_manager/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ def swap_cache_layout(*args, **kwargs):
set_data_ipc,
share_external_data,
swap_cache_all_layers,
swap_cache_layout,
)

unset_data_ipc = None
swap_cache_layout = None
memory_allocated = paddle.device.xpu.memory_allocated

def get_data_ptr_ipc(*args, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def create() -> "MooncakeStoreConfig":
rdma_devices = config.get("rdma_devices", "")
master_server_addr = config.get("master_server_addr")

if rdma_devices == "" and current_platform.is_cuda():
if rdma_devices == "" and (current_platform.is_cuda() or current_platform.is_xpu()):
# FIXME: use auto-select NICs in MooncakeStore will raise error and roll back to using TCP
rdma_devices = get_rdma_nics()
logger.info(f"No RDMA devices specified, defaulting to all available devices: {rdma_devices}")
Expand Down
Loading