Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/infinicore/ops.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "ops/add.hpp"
#include "ops/add_rms_norm.hpp"
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/matmul.hpp"
Expand Down
20 changes: 20 additions & 0 deletions include/infinicore/ops/add_rms_norm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"
#include <utility>

namespace infinicore::op {
// Dispatch entry point for the fused "add + RMS normalization" operator.
class AddRMSNorm {
public:
// Signature every backend implementation must provide, in the order
// (y, residual_out, a, b, weight, epsilon).
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, float);
// Runs the operator. All five tensors are asserted to be on the same device;
// the implementation is looked up by the device type of `y`.
static void execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
// Returns the dispatcher that backend implementations are registered with.
static common::OpDispatcher<schema> &dispatcher();
};

// Fused add and RMS normalization that allocates its own outputs.
// Both results are created with the shape, dtype and device of `a`.
// Returns: (normalized_result, add_result)
// The add_result can be used as the residual for subsequent layers.
std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
// Variant that writes into caller-provided tensors: `y` receives the
// normalized result and `residual_out` receives the add result (a + b).
void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
} // namespace infinicore::op
1 change: 1 addition & 0 deletions include/infiniop.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "infiniop/handle.h"
#include "infiniop/ops/add.h"
#include "infiniop/ops/add_rms_norm.h"
#include "infiniop/ops/attention.h"
#include "infiniop/ops/causal_softmax.h"
#include "infiniop/ops/clip.h"
Expand Down
32 changes: 32 additions & 0 deletions include/infiniop/ops/add_rms_norm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#ifndef __INFINIOP_ADD_RMS_NORM_API_H__
#define __INFINIOP_ADD_RMS_NORM_API_H__

#include "../operator_descriptor.h"

/* Handle type for an AddRMSNorm operator descriptor (a pointer to the generic InfiniopDescriptor). */
typedef struct InfiniopDescriptor *infiniopAddRMSNormDescriptor_t;

/**
 * Creates an AddRMSNorm descriptor and stores it in *desc_ptr.
 *
 * y_desc describes the normalized output, a_desc and b_desc the two inputs
 * that are added, weight_desc the scale weights, and residual_out_desc the
 * output that receives the add result (a + b).
 *
 * NOTE: residual_out_desc is the LAST parameter, after epsilon. This differs
 * from the C++ infinicore layer, which orders its arguments
 * (y, residual_out, a, b, weight, epsilon).
 *
 * The descriptor is released with infiniopDestroyAddRMSNormDescriptor.
 */
__C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor(
infiniopHandle_t handle,
infiniopAddRMSNormDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon,
infiniopTensorDescriptor_t residual_out_desc);

/**
 * Writes to *size the workspace size required by `desc` (presumably in bytes,
 * to be confirmed). Callers query this before allocating the workspace
 * buffer that is passed to infiniopAddRMSNorm.
 */
__C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size);

/**
 * Runs the fused add + RMS normalization described by `desc`.
 *
 * workspace / workspace_size: scratch buffer and its size; the required size
 *     is reported by infiniopGetAddRMSNormWorkspaceSize.
 * y:            output buffer for the normalized result.
 * a, b:         read-only input buffers that are added together.
 * weight:       read-only scale weights.
 * residual_out: output buffer for the add result (a + b). As in the create
 *     call, it comes after the inputs rather than next to `y`.
 * stream:       backend stream handle to run on.
 */
__C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *a,
const void *b,
const void *weight,
void *residual_out,
void *stream);

/* Destroys a descriptor obtained from infiniopCreateAddRMSNormDescriptor. */
__C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc);

#endif
3 changes: 3 additions & 0 deletions python/infinicore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
uint8,
)
from infinicore.ops.add import add
from infinicore.ops.add_rms_norm import add_rms_norm, add_rms_norm_
from infinicore.ops.attention import attention
from infinicore.ops.matmul import matmul
from infinicore.ops.mul import mul
Expand Down Expand Up @@ -102,6 +103,8 @@
"uint8",
# Operations.
"add",
"add_rms_norm",
"add_rms_norm_",
"attention",
"matmul",
"mul",
Expand Down
47 changes: 47 additions & 0 deletions python/infinicore/ops/add_rms_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None):
    """
    Add two tensors and RMS-normalize the sum in a single fused operation.

    Args:
        a: First tensor to add.
        b: Second tensor to add.
        weight: Scale weights used by the normalization.
        epsilon: Numerical-stability constant; defaults to 1e-5.
        out: Optional ``(y, residual_out)`` pair of tensors that receive the
            results. When omitted, new result tensors are returned.

    Returns:
        Tuple ``(normalized_result, add_result)``, i.e.
        ``(RMSNorm(a + b) * weight, a + b)``. The add_result can serve as
        the residual for subsequent layers.
    """
    if out is not None:
        # Caller supplied the destinations: write into them and hand them back.
        y, residual_out = out
        _infinicore.add_rms_norm_(
            y._underlying,
            residual_out._underlying,
            a._underlying,
            b._underlying,
            weight._underlying,
            epsilon,
        )
        return (y, residual_out)

    # No destinations given: the native call returns the two raw results,
    # which are wrapped back into Tensor objects.
    normalized, summed = _infinicore.add_rms_norm(
        a._underlying, b._underlying, weight._underlying, epsilon
    )
    return (Tensor(normalized), Tensor(summed))


def add_rms_norm_(y, residual_out, a, b, weight, epsilon=1e-5):
    """Fused Add and RMS Normalization writing into ``y`` and ``residual_out``.

    ``y`` receives the normalized result and ``residual_out`` the add result
    (``a + b``). Nothing is returned.
    """
    # Unwrap the native handles in the order the binding expects.
    handles = [t._underlying for t in (y, residual_out, a, b, weight)]
    _infinicore.add_rms_norm_(*handles, epsilon)
29 changes: 29 additions & 0 deletions src/infinicore/ops/add_rms_norm/add_rms_norm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#include "infinicore/ops/add_rms_norm.hpp"

#include "../../utils.hpp"

namespace infinicore::op {

// Returns the single dispatcher instance for AddRMSNorm.
//
// The dispatcher is a function-local static, so it is constructed on first
// use. Backends register themselves from static initializers, which
// therefore always see a live dispatcher regardless of translation-unit
// initialization order.
common::OpDispatcher<AddRMSNorm::schema> &AddRMSNorm::dispatcher() {
    static common::OpDispatcher<AddRMSNorm::schema> dispatcher_;
    return dispatcher_;
}

// Validates device placement, makes the device of `y` current, and forwards
// to the implementation registered for that device type.
void AddRMSNorm::execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, residual_out, a, b, weight);

    // `y` decides both the active device and the implementation to run.
    auto target_device = y->device();
    infinicore::context::setDevice(target_device);

    dispatcher().lookup(target_device.getType())(y, residual_out, a, b, weight, epsilon);
}

// Allocating variant: creates both outputs with the shape, dtype and device
// of `a`, runs the fused operator, and returns (y, residual_out).
std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon) {
    const auto &out_shape = a->shape();
    const auto &out_dtype = a->dtype();
    const auto &out_device = a->device();

    auto y = Tensor::empty(out_shape, out_dtype, out_device);
    auto residual_out = Tensor::empty(out_shape, out_dtype, out_device);

    add_rms_norm_(y, residual_out, a, b, weight, epsilon);
    return std::pair<Tensor, Tensor>{y, residual_out};
}

// Variant that writes into caller-provided `y` and `residual_out`; simply
// forwards every argument to AddRMSNorm::execute.
void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
AddRMSNorm::execute(y, residual_out, a, b, weight, epsilon);
}

} // namespace infinicore::op
50 changes: 50 additions & 0 deletions src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/add_rms_norm.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h>

namespace infinicore::op::add_rms_norm_impl::infiniop {

// Per-thread cache of infiniop AddRMSNorm descriptors, keyed by a hash of the
// call's operands. The callback destroys a descriptor and clears the slot.
thread_local common::OpCache<size_t, infiniopAddRMSNormDescriptor_t> caches(
    100, // capacity
    [](infiniopAddRMSNormDescriptor_t &desc) {
        // Nothing to release for an empty slot.
        if (desc == nullptr) {
            return;
        }
        INFINICORE_CHECK_ERROR(infiniopDestroyAddRMSNormDescriptor(desc));
        desc = nullptr;
    });

// infiniop-backed implementation of AddRMSNorm: obtains (or creates and
// caches) a descriptor for this combination of operands, allocates the
// workspace it asks for, and launches the kernel on the current stream.
void calculate(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
    // Cache key covers every operand plus epsilon.
    size_t key = hash_combine(y, residual_out, a, b, weight, epsilon);

    auto device = context::getDevice();
    auto &cache = caches.getCache(device);

    infiniopAddRMSNormDescriptor_t desc = nullptr;
    if (auto cached = cache.get(key)) {
        desc = *cached;
    } else {
        // infiniop takes residual_out's descriptor last, after epsilon.
        INFINICORE_CHECK_ERROR(infiniopCreateAddRMSNormDescriptor(
            context::getInfiniopHandle(device), &desc,
            y->desc(), a->desc(), b->desc(), weight->desc(), epsilon, residual_out->desc()));
        cache.put(key, desc);
    }

    size_t workspace_size = 0;
    INFINICORE_CHECK_ERROR(infiniopGetAddRMSNormWorkspaceSize(desc, &workspace_size));
    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

    // Same ordering quirk as above: residual_out follows the inputs.
    INFINICORE_CHECK_ERROR(infiniopAddRMSNorm(
        desc, workspace->data(), workspace_size,
        y->data(), a->data(), b->data(), weight->data(), residual_out->data(), context::getStream()));
}

// Self-registration: the immediately-invoked lambda runs during static
// initialization of this translation unit and registers `calculate` with the
// AddRMSNorm dispatcher via registerAll. The bool only exists to give the
// lambda something to initialize.
// NOTE(review): the `false` argument is presumably an "override existing
// registration" flag; confirm against OpDispatcher::registerAll.
static bool registered = []() {
AddRMSNorm::dispatcher().registerAll(&calculate, false);
return true;
}();

} // namespace infinicore::op::add_rms_norm_impl::infiniop
2 changes: 2 additions & 0 deletions src/infinicore/pybind11/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <pybind11/pybind11.h>

#include "ops/add.hpp"
#include "ops/add_rms_norm.hpp"
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/embedding.hpp"
Expand All @@ -22,6 +23,7 @@ namespace infinicore::ops {

inline void bind(py::module &m) {
bind_add(m);
bind_add_rms_norm(m);
bind_attention(m);
bind_causal_softmax(m);
bind_random_sample(m);
Expand Down
51 changes: 51 additions & 0 deletions src/infinicore/pybind11/ops/add_rms_norm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/add_rms_norm.hpp"

namespace py = pybind11;

namespace infinicore::ops {

// Exposes the fused add + RMS-norm operator to Python on module `m`:
//   - "add_rms_norm"  -> op::add_rms_norm  (returns the pair of result tensors)
//   - "add_rms_norm_" -> op::add_rms_norm_ (writes into caller-provided y / residual_out)
// Both bindings accept keyword arguments and default `epsilon` to 1e-5.
inline void bind_add_rms_norm(py::module &m) {
m.def("add_rms_norm",
&op::add_rms_norm,
py::arg("a"),
py::arg("b"),
py::arg("weight"),
py::arg("epsilon") = 1e-5f,
R"doc(Fused Add and RMS Normalization.

Args:
a: First input tensor
b: Second input tensor
weight: Scale weights
epsilon: Small constant for numerical stability, default is 1e-5

Returns:
Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b)
The add_result can be used as residual for subsequent layers.
)doc");

m.def("add_rms_norm_",
&op::add_rms_norm_,
py::arg("y"),
py::arg("residual_out"),
py::arg("a"),
py::arg("b"),
py::arg("weight"),
py::arg("epsilon") = 1e-5f,
R"doc(In-place Fused Add and RMS Normalization.

Args:
y: Output tensor for normalized result
residual_out: Output tensor for add result (a + b) before normalization
a: First input tensor
b: Second input tensor
weight: Scale weights
epsilon: Small constant for numerical stability, default is 1e-5
)doc");
}

} // namespace infinicore::ops
53 changes: 53 additions & 0 deletions src/infiniop/ops/add_rms_norm/add_rms_norm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#ifndef ADD_RMS_NORM_H
#define ADD_RMS_NORM_H

#include "../../operator.h"
#include "info.h"

// DESCRIPTOR(NAMESPACE) declares `op::add_rms_norm::NAMESPACE::Descriptor`,
// the per-backend AddRMSNorm descriptor class derived from InfiniopDescriptor.
//
// The generated class:
//   - holds a pointer to a forward-declared `Opaque` struct (presumably
//     backend-specific state defined by each backend; confirm per backend),
//     an AddRMSNormInfo, and the workspace size;
//   - has a private constructor, so instances are obtained through the
//     static `create`, whose parameters mirror
//     infiniopCreateAddRMSNormDescriptor (residual_out_desc last, after
//     epsilon);
//   - exposes `workspaceSize()` and a const `calculate` whose parameters
//     mirror infiniopAddRMSNorm.
//
// Only the declarations are produced here; `~Descriptor`, `create` and
// `calculate` must be defined elsewhere for each NAMESPACE.
//
// Keep comments out of the macro body: a `//` comment on a continued line
// would swallow the lines that follow it.
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::add_rms_norm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
AddRMSNormInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
AddRMSNormInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc, \
infiniopTensorDescriptor_t weight_desc, \
float epsilon, \
infiniopTensorDescriptor_t residual_out_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *y, \
const void *a, \
const void *b, \
const void *weight, \
void *residual_out, \
void *stream) const; \
}; \
}

#endif // ADD_RMS_NORM_H
Loading