2 changes: 2 additions & 0 deletions include/infinicore/ops.hpp
@@ -5,6 +5,8 @@
#include "ops/causal_softmax.hpp"
#include "ops/matmul.hpp"
#include "ops/ones.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_caching.hpp"
#include "ops/random_sample.hpp"
#include "ops/rearrange.hpp"
#include "ops/rms_norm.hpp"
18 changes: 18 additions & 0 deletions include/infinicore/ops/paged_attention.hpp
@@ -0,0 +1,18 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"
#include <optional>

namespace infinicore::op {

class PagedAttention {
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale);
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale);
} // namespace infinicore::op
17 changes: 17 additions & 0 deletions include/infinicore/ops/paged_caching.hpp
@@ -0,0 +1,17 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {

class PagedCaching {
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor);
static void execute(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping);
static common::OpDispatcher<schema> &dispatcher();
};

void paged_caching_(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping);

} // namespace infinicore::op
4 changes: 4 additions & 0 deletions python/infinicore/__init__.py
@@ -44,6 +44,8 @@
from infinicore.ops.matmul import matmul
from infinicore.ops.mul import mul
from infinicore.ops.narrow import narrow
from infinicore.ops.paged_attention import paged_attention
from infinicore.ops.paged_caching import paged_caching
from infinicore.ops.rearrange import rearrange
from infinicore.ops.squeeze import squeeze
from infinicore.ops.unsqueeze import unsqueeze
@@ -115,6 +117,8 @@
"from_list",
"from_numpy",
"from_torch",
"paged_caching",
"paged_attention",
"ones",
"strided_empty",
"strided_from_blob",
40 changes: 40 additions & 0 deletions python/infinicore/ops/paged_attention.py
@@ -0,0 +1,40 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def paged_attention(
q: Tensor,
k_cache: Tensor,
v_cache: Tensor,
block_tables: Tensor,
seq_lens: Tensor,
alibi_slopes: Tensor | None = None,
scale: float = 1.0,
*,
out: Tensor | None = None,
):
if out is None:
return Tensor(
_infinicore.paged_attention(
q._underlying,
k_cache._underlying,
v_cache._underlying,
block_tables._underlying,
seq_lens._underlying,
alibi_slopes._underlying if alibi_slopes is not None else None,
scale,
)
)

_infinicore.paged_attention_(
out._underlying,
q._underlying,
k_cache._underlying,
v_cache._underlying,
block_tables._underlying,
seq_lens._underlying,
alibi_slopes._underlying if alibi_slopes is not None else None,
scale,
)

return out
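For reference, a minimal usage sketch of the new Python entry point. The tensor shapes and cache layout are assumptions (this PR does not pin them down), and `head_size` is a hypothetical variable; treat this as illustrative, not as the library's documented contract:

import infinicore

# Assumed shapes (not specified in this diff):
#   q:               [num_seqs, num_heads, head_size]
#   k_cache/v_cache: paged block pools addressed through block_tables
#   block_tables:    [num_seqs, max_blocks_per_seq], integer block indices
#   seq_lens:        [num_seqs], current token count per sequence
out = infinicore.paged_attention(
    q, k_cache, v_cache, block_tables, seq_lens,
    alibi_slopes=None,               # optional per-head ALiBi slopes
    scale=1.0 / (head_size ** 0.5),  # conventional softmax scaling; an assumption here
)

# The keyword-only `out=` form writes into a preallocated tensor instead of allocating:
infinicore.paged_attention(q, k_cache, v_cache, block_tables, seq_lens, out=out)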
21 changes: 21 additions & 0 deletions python/infinicore/ops/paged_caching.py
@@ -0,0 +1,21 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def paged_caching(
k: Tensor,
v: Tensor,
k_cache: Tensor,
v_cache: Tensor,
slot_mapping: Tensor,
):
    # The underlying op updates k_cache/v_cache in place and returns nothing,
    # so the result must not be wrapped in Tensor().
    _infinicore.paged_caching_(
        k._underlying,
        v._underlying,
        k_cache._underlying,
        v_cache._underlying,
        slot_mapping._underlying,
    )
return (k_cache, v_cache)
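Correspondingly, a usage sketch for paged_caching, assuming slot_mapping assigns one flat cache slot per token (semantics inferred from the signature, not stated in the diff):

import infinicore

# k, v: this step's keys/values; slot_mapping picks the destination slot per token
k_cache, v_cache = infinicore.paged_caching(k, v, k_cache, v_cache, slot_mapping)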
28 changes: 28 additions & 0 deletions src/infinicore/ops/paged_attention/paged_attention.cc
@@ -0,0 +1,28 @@
#include "infinicore/ops/paged_attention.hpp"

#include "../../utils.hpp"

namespace infinicore::op {

common::OpDispatcher<PagedAttention::schema> &PagedAttention::dispatcher() {
static common::OpDispatcher<PagedAttention::schema> dispatcher_;
return dispatcher_;
}

void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, seq_lens);
infinicore::context::setDevice(out->device());
dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
}

Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
paged_attention_(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
return out;
}

void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
PagedAttention::execute(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
}

} // namespace infinicore::op
54 changes: 54 additions & 0 deletions src/infinicore/ops/paged_attention/paged_attention_infiniop.cc
@@ -0,0 +1,54 @@
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/paged_attention.hpp"
#include <infiniop.h>

namespace infinicore::op::paged_attention_impl::infiniop {

thread_local common::OpCache<size_t, infiniopPagedAttentionDescriptor_t> caches(
100, // capacity
[](infiniopPagedAttentionDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyPagedAttentionDescriptor(desc));
desc = nullptr;
}
});

void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, seq_lens);

auto device = context::getDevice();
auto &cache = caches.getCache(device);

auto desc_opt = cache.get(seed);
infiniopPagedAttentionDescriptor_t desc = nullptr;

if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionDescriptor(
context::getInfiniopHandle(device), &desc,
out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), seq_lens->desc(),
alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr,
scale));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}

size_t workspace_size = 0;
INFINICORE_CHECK_ERROR(infiniopGetPagedAttentionWorkspaceSize(desc, &workspace_size));
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

INFINICORE_CHECK_ERROR(infiniopPagedAttention(
desc, workspace->data(), workspace_size,
out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), seq_lens->data(),
alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
context::getStream()));
}

static bool registered = []() {
PagedAttention::dispatcher().registerAll(&calculate, false);
return true;
}();

} // namespace infinicore::op::paged_attention_impl::infiniop
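The infiniop backend keeps a thread-local descriptor cache keyed by a hash of the participating tensors, destroying evicted descriptors via the callback. A rough Python analogue of that lookup-or-create pattern, assuming OpCache has LRU semantics (the capacity of 100 and the destructor callback come from the code above; the class here is illustrative only):

from collections import OrderedDict

class DescriptorCache:
    """Illustrative stand-in for common::OpCache above (LRU behavior assumed)."""

    def __init__(self, capacity, destroy):
        self.capacity = capacity   # 100 in the code above
        self.destroy = destroy     # invoked on evicted descriptors
        self.entries = OrderedDict()

    def get(self, key):
        if key in self.entries:
            self.entries.move_to_end(key)  # refresh recency
            return self.entries[key]
        return None                        # caller creates a descriptor and put()s it

    def put(self, key, desc):
        self.entries[key] = desc
        self.entries.move_to_end(key)
        if len(self.entries) > self.capacity:
            _, evicted = self.entries.popitem(last=False)  # drop least-recently-used
            self.destroy(evicted)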
22 changes: 22 additions & 0 deletions src/infinicore/ops/paged_caching/paged_caching.cc
@@ -0,0 +1,22 @@
#include "infinicore/ops/paged_caching.hpp"

#include "../../utils.hpp"

namespace infinicore::op {

common::OpDispatcher<PagedCaching::schema> &PagedCaching::dispatcher() {
static common::OpDispatcher<PagedCaching::schema> dispatcher_;
return dispatcher_;
}

void PagedCaching::execute(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k, v, k_cache, v_cache, slot_mapping);
infinicore::context::setDevice(k->device());
dispatcher().lookup(k->device().getType())(k, v, k_cache, v_cache, slot_mapping);
}

void paged_caching_(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) {
PagedCaching::execute(k, v, k_cache, v_cache, slot_mapping);
}

} // namespace infinicore::op
50 changes: 50 additions & 0 deletions src/infinicore/ops/paged_caching/paged_caching_infiniop.cc
@@ -0,0 +1,50 @@
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/paged_caching.hpp"
#include <infiniop.h>

namespace infinicore::op::paged_caching_impl::infiniop {

thread_local common::OpCache<size_t, infiniopPagedCachingDescriptor_t> caches(
100, // capacity
[](infiniopPagedCachingDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyPagedCachingDescriptor(desc));
desc = nullptr;
}
});

void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) {
size_t seed = hash_combine(k, v, k_cache, v_cache, slot_mapping);

auto device = context::getDevice();
auto &cache = caches.getCache(device);

auto desc_opt = cache.get(seed);
infiniopPagedCachingDescriptor_t desc = nullptr;

if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreatePagedCachingDescriptor(
context::getInfiniopHandle(device), &desc,
k->desc(), v->desc(), k_cache->desc(), v_cache->desc(), slot_mapping->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}

size_t workspace_size = 0;
INFINICORE_CHECK_ERROR(infiniopGetPagedCachingWorkspaceSize(desc, &workspace_size));
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

INFINICORE_CHECK_ERROR(infiniopPagedCaching(
desc, workspace->data(), workspace_size,
k->data(), v->data(), k_cache->data(), v_cache->data(), slot_mapping->data(), context::getStream()));
}

static bool registered = []() {
PagedCaching::dispatcher().registerAll(&calculate, false);
return true;
}();

} // namespace infinicore::op::paged_caching_impl::infiniop
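Putting the two ops together: a decode step would first scatter the step's keys/values into the paged pools, then attend over the cached history. A sketch under the same assumptions as above (the function and variable names are illustrative, not part of this PR):

import infinicore

def decode_step(q, k, v, k_cache, v_cache, block_tables, seq_lens, slot_mapping, scale):
    # Scatter this step's keys/values into their assigned cache slots.
    infinicore.paged_caching(k, v, k_cache, v_cache, slot_mapping)
    # Attend each query against its sequence's cached blocks.
    return infinicore.paged_attention(
        q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes=None, scale=scale
    )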
4 changes: 4 additions & 0 deletions src/infinicore/pybind11/ops.hpp
@@ -9,6 +9,8 @@
#include "ops/linear.hpp"
#include "ops/matmul.hpp"
#include "ops/mul.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_caching.hpp"
#include "ops/random_sample.hpp"
#include "ops/rearrange.hpp"
#include "ops/rms_norm.hpp"
@@ -28,6 +30,8 @@ inline void bind(py::module &m) {
bind_linear(m);
bind_matmul(m);
bind_mul(m);
bind_paged_attention(m);
bind_paged_caching(m);
bind_rearrange(m);
bind_rms_norm(m);
bind_silu(m);
53 changes: 53 additions & 0 deletions src/infinicore/pybind11/ops/paged_attention.hpp
@@ -0,0 +1,53 @@
#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/paged_attention.hpp"

namespace py = pybind11;

namespace infinicore::ops {

Tensor py_paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, pybind11::object alibi_slopes, float scale) {
std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
if (!alibi_slopes.is_none()) {
alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
}
return op::paged_attention(q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes_tensor, scale);
}

void py_paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, pybind11::object alibi_slopes, float scale) {
std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
if (!alibi_slopes.is_none()) {
alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
}

op::paged_attention_(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes_tensor, scale);
}

inline void bind_paged_attention(py::module &m) {
m.def("paged_attention",
&ops::py_paged_attention,
py::arg("q"),
py::arg("k_cache"),
py::arg("v_cache"),
py::arg("block_tables"),
py::arg("seq_lens"),
py::arg("alibi_slopes"),
py::arg("scale"),
R"doc(Paged attention of query and key cache tensors.)doc");

m.def("paged_attention_",
&ops::py_paged_attention_,
py::arg("out"),
py::arg("q"),
py::arg("k_cache"),
py::arg("v_cache"),
py::arg("block_tables"),
py::arg("seq_lens"),
py::arg("alibi_slopes"),
py::arg("scale"),
R"doc(In-place paged attention of query and key cache tensors.)doc");
}

} // namespace infinicore::ops
22 changes: 22 additions & 0 deletions src/infinicore/pybind11/ops/paged_caching.hpp
@@ -0,0 +1,22 @@
#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/paged_caching.hpp"

namespace py = pybind11;

namespace infinicore::ops {

inline void bind_paged_caching(py::module &m) {
m.def("paged_caching_",
&op::paged_caching_,
py::arg("k"),
py::arg("v"),
py::arg("k_cache"),
py::arg("v_cache"),
py::arg("slot_mapping"),
R"doc(Paged caching of key and value tensors.)doc");
}

} // namespace infinicore::ops