4 changes: 4 additions & 0 deletions .github/workflows/ci-platform-generic.yml
@@ -73,6 +73,10 @@ jobs:
testFloatSoftmax
testFloatTranspose
testFloatMul
testFloatPowScalar
testFloatPowVector
testFloatSqrt
testFloatRMSNorm
Quant
Dequant
QuantizedLinear
2 changes: 2 additions & 0 deletions .gitignore
@@ -24,6 +24,8 @@ package-lock.json
.mypy_cache
node_modules

.venv/*

compile_commands.json

docs/_autosummary
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,7 @@ This file contains the changelog for the Deeploy project. The changelog is divid
## Unreleased (Planned Release Target: v0.2.1)

### List of Pull Requests
- Support for RMSNorm (Pow and Sqrt operators) [#136](https://github.com/pulp-platform/Deeploy/pull/136)
- Demo TinyViT compatibility with tiled Siracusa [#124](https://github.com/pulp-platform/Deeploy/pull/124)
- TinyViT on non-tiled Siracusa [#117](https://github.com/pulp-platform/Deeploy/pull/117)
- Support Fully Asynchronous DMAs [#114](https://github.com/pulp-platform/Deeploy/pull/114)
@@ -26,6 +27,8 @@ This file contains the changelog for the Deeploy project. The changelog is divid
- Fix bias hoisting in generic GEMM with no bias [#126](https://github.com/pulp-platform/Deeploy/pull/126)

### Added
- Support for the RMSNorm operation via operator decomposition (see the decomposition sketch after this diff).
- Support for the `Pow` (power) and `Sqrt` (square root) operators (parsers, layers, bindings, templates, and FP32 kernels) on the Generic platform.
- Support for input tiling for PULP FP regular and DW conv 2D.
- CI tests for tiled Siracusa FP regular and DW conv 2D, with and without bias, for skip connections, and for the demo version of TinyViT.
- Documentation for PULP FP regular and DW conv 2D and MatMul tile constraints.
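For context, RMSNorm normalizes each feature vector by the root of its mean square:

\mathrm{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^2 + \epsilon}}

which maps onto the ONNX operators Pow(x, 2) -> ReduceMean -> Add(eps) -> Sqrt -> Div, i.e. only the newly added Pow and Sqrt kernels plus operators the Generic platform already supports. This is a sketch of the expected decomposition; the exact testFloatRMSNorm graph, its epsilon placement, and any learned scale factor may differ.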
20 changes: 15 additions & 5 deletions Deeploy/Targets/Generic/Bindings.py
@@ -15,11 +15,11 @@
ConvTransposeTemplate, DebugPrintTemplate, DequantTemplate, DummyTemplate, DWConvTemplate, FloatAddTemplate, \
FloatConvTemplate, FloatDivTemplate, FloatDWConvTemplate, FloatGELUTemplate, FloatGemmTemplate, \
FloatLayernormTemplate, FloatMatMulTemplate, FloatMaxPoolTemplate, FloatMulTemplate, FloatPadTemplate, \
FloatReduceMeanTemplate, FloatReluTemplate, FloatSoftmaxTemplate, GatherTemplate, GemmTemplate, \
IntegerDivTemplate, ITAMaxTemplate, ITAPartialMaxTemplate, MatMulTemplate, MaxPoolTemplate, MulTemplate, \
PadTemplate, QuantTemplate, ReduceMeanTemplate, ReduceSumTemplate, RequantShiftTemplate, ReshapeTemplate, \
RQIntegerDivTemplate, RQSiGELUTemplate, SliceTemplate, TransposeTemplate, iGELUTemplate, iLayernormTemplate, \
iRMSNormTemplate, iSoftmaxTemplate
FloatPowTemplate, FloatReduceMeanTemplate, FloatReluTemplate, FloatSoftmaxTemplate, FloatSqrtTemplate, \
GatherTemplate, GemmTemplate, IntegerDivTemplate, ITAMaxTemplate, ITAPartialMaxTemplate, MatMulTemplate, \
MaxPoolTemplate, MulTemplate, PadTemplate, QuantTemplate, ReduceMeanTemplate, ReduceSumTemplate, \
RequantShiftTemplate, ReshapeTemplate, RQIntegerDivTemplate, RQSiGELUTemplate, SliceTemplate, TransposeTemplate, \
iGELUTemplate, iLayernormTemplate, iRMSNormTemplate, iSoftmaxTemplate
from Deeploy.Targets.Generic.TypeCheckers import AddChecker, BatchNormChecker, ConcatChecker, ConvChecker, \
DebugPrintChecker, DequantChecker, DivChecker, DummyChecker, GatherChecker, GELUChecker, GEMMChecker, \
LayerNormChecker, MatMulChecker, MaxPoolChecker, MulChecker, PadChecker, QuantChecker, ReduceMeanChecker, \
@@ -118,6 +118,16 @@
BasicTransformer)
]

BasicPowBindings = [
NodeBinding(DummyChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
FloatPowTemplate.referenceTemplate, BasicTransformer),
]

BasicSqrtBindings = [
NodeBinding(DummyChecker([PointerClass(float32_t)], [PointerClass(float32_t)]), FloatSqrtTemplate.referenceTemplate,
BasicTransformer),
]

BasicDivBindings = [
NodeBinding(DivChecker([PointerClass(int32_t), PointerClass(int32_t)], [PointerClass(int32_t)]),
IntegerDivTemplate.referenceTemplate, BasicTransformer)
12 changes: 12 additions & 0 deletions Deeploy/Targets/Generic/Layers.py
@@ -227,6 +227,18 @@ def computeOps(self):
return matmul + rqs


class PowLayer(ONNXLayer):

def __init__(self, maps: List[NodeMapper]):
super().__init__(maps)


class SqrtLayer(ONNXLayer):

def __init__(self, maps: List[NodeMapper]):
super().__init__(maps)


class DivLayer(ONNXLayer):

def __init__(self, maps: List[NodeMapper]):
51 changes: 50 additions & 1 deletion Deeploy/Targets/Generic/Parsers.py
@@ -8,7 +8,7 @@
import numpy as np
import onnx_graphsurgeon as gs

from Deeploy.DeeployTypes import NetworkContext, NodeParser, VariableBuffer
from Deeploy.DeeployTypes import ConstantBuffer, NetworkContext, NodeParser, VariableBuffer


class ConcatParser(NodeParser):
@@ -1964,6 +1964,32 @@ def parseNodeCtxt(self,
return ctxt, True


class PowParser(NodeParser):

def __init__(self):
super().__init__()

def parseNode(self, node: gs.Node) -> bool:
return node.op == 'Pow' and len(node.inputs) == 2 and len(node.outputs) == 1

def parseNodeCtxt(self,
ctxt: NetworkContext,
node: gs.Node,
channels_first: bool = True) -> Tuple[NetworkContext, bool]:

# Lookup both inputs (data and exponent)
data_in = ctxt.lookup(node.inputs[0].name)
exponent_tensor = ctxt.lookup(node.inputs[1].name)
data_out = ctxt.lookup(node.outputs[0].name)

self.operatorRepresentation['data_in'] = data_in.name
self.operatorRepresentation['exponent'] = exponent_tensor.name
self.operatorRepresentation['data_out'] = data_out.name
self.operatorRepresentation['size'] = int(np.prod(data_in.shape))

return ctxt, True


class DivParser(NodeParser):

def __init__(self):
@@ -2747,3 +2773,26 @@ def parseNodeCtxt(self,
"ch_im_out"] * self.operatorRepresentation["dim_im_out_y"]
return newCtxt, True
return ctxt, False


class SqrtParser(NodeParser):

def __init__(self):
super().__init__()

def parseNode(self, node: gs.Node) -> bool:
return node.op == 'Sqrt' and len(node.inputs) == 1 and len(node.outputs) == 1

def parseNodeCtxt(self,
ctxt: NetworkContext,
node: gs.Node,
channels_first: bool = True) -> Tuple[NetworkContext, bool]:

data_in = ctxt.lookup(node.inputs[0].name)
data_out = ctxt.lookup(node.outputs[0].name)

self.operatorRepresentation['data_in'] = data_in.name
self.operatorRepresentation['data_out'] = data_out.name
self.operatorRepresentation['size'] = int(np.prod(data_in.shape))

return ctxt, True
23 changes: 14 additions & 9 deletions Deeploy/Targets/Generic/Platform.py
@@ -11,21 +11,22 @@
BasicDequantBindings, BasicDivBindings, BasicDWConv1DBinding, BasicDWConv2DBindings, BasicGatherBindings, \
BasicGELUBindings, BasicGEMMBindings, BasicITAPartialSoftmaxBinding, BasicITASoftmaxBinding, \
BasicLayerNormBindings, BasicMatMulBindings, BasicMaxPool1DBindings, BasicMaxPool2DBindings, BasicMulBindings, \
BasicPad1DBindings, BasicPad2DBindings, BasicQuantBindings, BasicReduceMeanBindings, BasicReduceSumBindings, \
BasicReluBinding, BasicReshapeBindings, BasicRQIntegerDivBinding, BasicRQSBindings, BasicRQSGELUBinding, \
BasicSliceBindings, BasicSoftmaxBindings, BasicTransposeBindings, DummyBinding
BasicPad1DBindings, BasicPad2DBindings, BasicPowBindings, BasicQuantBindings, BasicReduceMeanBindings, \
BasicReduceSumBindings, BasicReluBinding, BasicReshapeBindings, BasicRQIntegerDivBinding, BasicRQSBindings, \
BasicRQSGELUBinding, BasicSliceBindings, BasicSoftmaxBindings, BasicSqrtBindings, BasicTransposeBindings, \
DummyBinding
from Deeploy.Targets.Generic.Layers import AddLayer, BatchNormalizationLayer, ConcatLayer, ConvLayer, \
ConvTransposeLayer, DebugPrintLayer, DequantLayer, DivLayer, GatherLayer, GELULayer, GEMMLayer, ITAMaxLayer, \
LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, QuantLayer, ReduceMeanLayer, ReduceSumLayer, \
ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, SliceLayer, SoftmaxLayer, \
TransposeLayer
LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, PowLayer, QuantLayer, ReduceMeanLayer, \
ReduceSumLayer, ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, SliceLayer, \
SoftmaxLayer, SqrtLayer, TransposeLayer
from Deeploy.Targets.Generic.Parsers import AddParser, BatchNormParser, ConcatParser, ConvTranspose1DParser, \
DebugParser, DequantParser, DivParser, DummyParser, FlattenParser, GatherParser, GELUParser, GenericConv1DParser, \
GenericConv2DParser, GenericDWConv1DParser, GenericDWConv2DParser, GenericGEMMParser, GenericMaxPool2DParser, \
IntegerDivParser, ITAMaxParser, ITAPartialMaxParser, LayerNormParser, MatMulParser, MaxPool1DParser, MulParser, \
Pad1DParser, Pad2DParser, QuantParser, ReduceMeanParser, ReduceSumParser, ReluParser, RequantShiftParser, \
ReshapeParser, RQIntegerDivParser, RQSiGELUParser, SliceParser, SoftmaxParser, TransposeParser, UnsqueezeParser, \
iLayerNormParser, iSoftmaxParser
Pad1DParser, Pad2DParser, PowParser, QuantParser, ReduceMeanParser, ReduceSumParser, ReluParser, \
RequantShiftParser, ReshapeParser, RQIntegerDivParser, RQSiGELUParser, SliceParser, SoftmaxParser, SqrtParser, \
TransposeParser, UnsqueezeParser, iLayerNormParser, iSoftmaxParser
from Deeploy.Targets.Generic.Templates import AllocateTemplate, FreeTemplate
from Deeploy.Targets.Generic.TopologyOptimizationPasses.Passes import DequantPatternPass, ExtractPaddingFromConvPass, \
ExtractPaddingFromPoolPass, MatMulAddMergePass, MergeConstAddAndRequantPass, QuantPatternPass, \
@@ -52,6 +53,8 @@
MaxPoolMapper = NodeMapper(GenericMaxPool2DParser(), BasicMaxPool2DBindings)
MaxPool1DMapper = NodeMapper(MaxPool1DParser(), BasicMaxPool1DBindings)
MulMapper = NodeMapper(MulParser(), BasicMulBindings)
PowMapper = NodeMapper(PowParser(), BasicPowBindings)
SqrtMapper = NodeMapper(SqrtParser(), BasicSqrtBindings)
Pad1DMapper = NodeMapper(Pad1DParser(), BasicPad1DBindings)
Pad2DMapper = NodeMapper(Pad2DParser(), BasicPad2DBindings)
ReduceMeanMapper = NodeMapper(ReduceMeanParser(), BasicReduceMeanBindings)
@@ -98,6 +101,8 @@
'MatMulInteger': MatMulLayer([MatMulMapper]),
'MaxPool': MaxPoolLayer([MaxPool1DMapper, MaxPoolMapper]),
'Mul': MulLayer([MulMapper]),
'Pow': PowLayer([PowMapper]),
'Sqrt': SqrtLayer([SqrtMapper]),
'Pad': PadLayer([Pad1DMapper, Pad2DMapper]),
'ReduceMean': ReduceMeanLayer([ReduceMeanMapper]),
'ReduceSum': ReduceSumLayer([ReduceSumMapper]),
59 changes: 59 additions & 0 deletions Deeploy/Targets/Generic/Templates/FloatPowTemplate.py
@@ -0,0 +1,59 @@
# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna
#
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Tuple

import numpy as np

from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation


class _PowTemplate(NodeTemplate):

def alignToContext(self, ctxt: NetworkContext,
operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]:
# Get input and output tensors
data_in = ctxt.lookup(operatorRepresentation['data_in'])
exponent = ctxt.lookup(operatorRepresentation['exponent'])
data_out = ctxt.lookup(operatorRepresentation['data_out'])

# Get data type (fp32)
data_type = data_in._type.typeName
operatorRepresentation['data_type'] = data_type

# Get type width dynamically (e.g., 32, 64)
type_width = data_in._type.referencedType.typeWidth
operatorRepresentation['type_width'] = type_width

# Calculate size
input_size = int(np.prod(data_in.shape))
exponent_size = int(np.prod(exponent.shape))
operatorRepresentation['size'] = input_size

# Check if exponent is scalar (broadcasting)
if exponent_size == 1:
operatorRepresentation['is_scalar'] = True
# Get the full variable name with prefix
exponent_name = operatorRepresentation['exponent']
operatorRepresentation['exponent_scalar'] = f"DeeployNetwork_{exponent_name}[0]"
else:
# The kernel currently only supports equally sized base and exponent tensors,
# so for non-scalar exponents, check that data_in and exponent have the same length.
if input_size != exponent_size:
raise ValueError(f"Pow operator mismatch: input size ({input_size}) "
f"must equal exponent size ({exponent_size}) for non-scalar exponents.")

operatorRepresentation['is_scalar'] = False
operatorRepresentation['exponent_scalar'] = "NULL"

return ctxt, operatorRepresentation, []


referenceTemplate = _PowTemplate("""
// Pow (Name: ${nodeName}, Op: ${nodeOp})
% if is_scalar:
Pow_fp${type_width}_scalar_fp${type_width}(${data_in}, ${exponent_scalar}, ${data_out}, ${size});
% else:
Pow_fp${type_width}_fp${type_width}_fp${type_width}(${data_in}, ${exponent}, ${data_out}, ${size});
% endif
""")
35 changes: 35 additions & 0 deletions Deeploy/Targets/Generic/Templates/FloatSqrtTemplate.py
@@ -0,0 +1,35 @@
# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna
#
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Tuple

import numpy as np

from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation


class _SqrtTemplate(NodeTemplate):

def alignToContext(self, ctxt: NetworkContext,
operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]:
# Get input and output tensors
data_in = ctxt.lookup(operatorRepresentation['data_in'])
data_out = ctxt.lookup(operatorRepresentation['data_out'])

# Get data type (fp32)
data_type = data_in._type.typeName
operatorRepresentation['data_type'] = data_type

type_width = data_in._type.referencedType.typeWidth
operatorRepresentation['type_width'] = type_width

# Calculate size
operatorRepresentation['size'] = int(np.prod(data_in.shape))

return ctxt, operatorRepresentation, []


referenceTemplate = _SqrtTemplate("""
// Sqrt (Name: ${nodeName}, Op: ${nodeOp})
Sqrt_fp${type_width}_fp${type_width}(${data_in}, ${data_out}, ${size});
""")
Binary file added DeeployTest/Tests/testFloatPowScalar/inputs.npz
Binary file added DeeployTest/Tests/testFloatPowScalar/network.onnx
23 changes: 23 additions & 0 deletions DeeployTest/Tests/testFloatPowVector/network.onnx
(Binary ONNX protobuf rendered as text by the diff viewer; it defines the test_float_pow_vector graph with a single "Pow" node named Pow_Vector_Test, inputs data_in and exponent, and output data_out.)
Binary file added DeeployTest/Tests/testFloatRMSNorm/inputs.npz
Binary file added DeeployTest/Tests/testFloatRMSNorm/network.onnx
Binary file added DeeployTest/Tests/testFloatRMSNorm/outputs.npz
Binary file added DeeployTest/Tests/testFloatSqrt/inputs.npz
Binary file added DeeployTest/Tests/testFloatSqrt/network.onnx
Binary file added DeeployTest/Tests/testFloatSqrt/outputs.npz
2 changes: 2 additions & 0 deletions TargetLibraries/Generic/inc/DeeployBasicMath.h
@@ -44,12 +44,14 @@
#include "kernel/MatMul.h"
#include "kernel/MaxPool.h"
#include "kernel/MaxPool1d.h"
#include "kernel/Pow.h"
#include "kernel/RMSNorm.h"
#include "kernel/RQDiv.h"
#include "kernel/RQGELU.h"
#include "kernel/RQHardswish.h"
#include "kernel/Relu.h"
#include "kernel/RequantShift.h"
#include "kernel/Softmax.h"
#include "kernel/Sqrt.h"

#endif //__DEEPLOY_BASIC_MATH_HEADER_
24 changes: 24 additions & 0 deletions TargetLibraries/Generic/inc/kernel/Pow.h
@@ -0,0 +1,24 @@
/*
* SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna
*
* SPDX-License-Identifier: Apache-2.0
*/

/*
 * This file declares the element-wise binary power operation kernels.
*/

#ifndef __DEEPLOY_MATH_POW_KERNEL_HEADER_
#define __DEEPLOY_MATH_POW_KERNEL_HEADER_

#include "DeeployBasicMath.h"

void Pow_fp32_fp32_fp32(const float32_t *__restrict__ data_in,
const float32_t *__restrict__ exponent,
float32_t *__restrict__ data_out, int32_t size);

void Pow_fp32_scalar_fp32(const float32_t *__restrict__ data_in,
float32_t exponent, float32_t *__restrict__ data_out,
int32_t size);

#endif
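The matching C definitions are not part of this diff excerpt; a minimal sketch of what they plausibly look like, built on powf from <math.h> (the PR's actual kernels may differ in loop structure or optimizations):

#include <math.h>

#include "DeeployBasicMath.h"

// Element-wise power with an exponent tensor of the same length as the input.
void Pow_fp32_fp32_fp32(const float32_t *__restrict__ data_in,
                        const float32_t *__restrict__ exponent,
                        float32_t *__restrict__ data_out, int32_t size) {
  for (int32_t i = 0; i < size; i++) {
    data_out[i] = powf(data_in[i], exponent[i]);
  }
}

// Element-wise power with a single broadcast scalar exponent.
void Pow_fp32_scalar_fp32(const float32_t *__restrict__ data_in,
                          float32_t exponent, float32_t *__restrict__ data_out,
                          int32_t size) {
  for (int32_t i = 0; i < size; i++) {
    data_out[i] = powf(data_in[i], exponent);
  }
}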
22 changes: 22 additions & 0 deletions TargetLibraries/Generic/inc/kernel/Sqrt.h
@@ -0,0 +1,22 @@
/*
* SPDX-FileCopyrightText: 2020 ETH Zurich and University of Bologna
*
* SPDX-License-Identifier: Apache-2.0
*/

#ifndef __DEEPLOY_BASIC_MATH_SQRT_KERNEL_HEADER_
#define __DEEPLOY_BASIC_MATH_SQRT_KERNEL_HEADER_

#include "DeeployBasicMath.h"

/*
* Square root operation - computes sqrt for each element
*/

/******************************************************************************/
/* Sqrt */
/******************************************************************************/

void Sqrt_fp32_fp32(float32_t *data_in, float32_t *data_out, int32_t size);

#endif //__DEEPLOY_BASIC_MATH_SQRT_KERNEL_HEADER_
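As with Pow.h, the kernel body is outside this excerpt; a minimal sketch under the same assumptions, using sqrtf from <math.h>:

#include <math.h>

#include "DeeployBasicMath.h"

// Element-wise square root over a flat fp32 buffer.
void Sqrt_fp32_fp32(float32_t *data_in, float32_t *data_out, int32_t size) {
  for (int32_t i = 0; i < size; i++) {
    data_out[i] = sqrtf(data_in[i]);
  }
}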