Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions hls4ml/backends/vivado/passes/merge_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

 merge_config_template = """struct config{index} : nnet::merge_config {{
 static const unsigned n_elem = {n_elem};
+static const unsigned n_elem1 = {n_elem1};
+static const unsigned n_elem2 = {n_elem2};
 static const unsigned reuse_factor = {reuse};
 }};\n"""

Expand All @@ -21,8 +23,12 @@ def __init__(self):

 def format(self, node):
 params = self._default_config_params(node)
-params['n_elem'] = node.get_input_variable(node.inputs[0]).size_cpp()
-
+params['n_elem1'] = node.get_input_variable(node.inputs[0]).size_cpp()
+params['n_elem2'] = node.get_input_variable(node.inputs[1]).size_cpp()
+params['n_elem'] = max(params['n_elem1'], params['n_elem2'])
+io_type = node.model.config.get_config_value('IOType')
+if io_type != 'io_parallel':
+assert params['n_elem1'] == params['n_elem2'], 'broadcasting merge not supported non-io_parallel'
 return self.template.format(**params)


Expand Down
31 changes: 18 additions & 13 deletions hls4ml/templates/vivado/nnet_utils/nnet_merge.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
 namespace nnet {

 struct merge_config {
-static const unsigned n_elem = 10;
+static const unsigned n_elem1 = 10;
+static const unsigned n_elem2 = 10;
 static const unsigned reuse_factor = 1;
 };

Expand All @@ -34,56 +35,60 @@ struct concat_config {
};

 template <class input1_T, class input2_T, class res_T, typename CONFIG_T>
-void add(input1_T data1[CONFIG_T::n_elem], input2_T data2[CONFIG_T::n_elem], res_T res[CONFIG_T::n_elem]) {
+void add(input1_T data1[CONFIG_T::n_elem1], input2_T data2[CONFIG_T::n_elem2], res_T res[CONFIG_T::n_elem]) {
 #pragma HLS PIPELINE

 for (int ii = 0; ii < CONFIG_T::n_elem; ii++) {
-res[ii] = data1[ii] + data2[ii];
+res[ii] = data1[ii % CONFIG_T::n_elem1] + data2[ii % CONFIG_T::n_elem2];
 }
 }

 template <class input1_T, class input2_T, class res_T, typename CONFIG_T>
-void subtract(input1_T data1[CONFIG_T::n_elem], input2_T data2[CONFIG_T::n_elem], res_T res[CONFIG_T::n_elem]) {
+void subtract(input1_T data1[CONFIG_T::n_elem1], input2_T data2[CONFIG_T::n_elem2], res_T res[CONFIG_T::n_elem]) {
 #pragma HLS PIPELINE

 for (int ii = 0; ii < CONFIG_T::n_elem; ii++) {
-res[ii] = data1[ii] - data2[ii];
+res[ii] = data1[ii % CONFIG_T::n_elem1] - data2[ii % CONFIG_T::n_elem2];
 }
 }

 template <class input1_T, class input2_T, class res_T, typename CONFIG_T>
-void multiply(input1_T data1[CONFIG_T::n_elem], input2_T data2[CONFIG_T::n_elem], res_T res[CONFIG_T::n_elem]) {
+void multiply(input1_T data1[CONFIG_T::n_elem1], input2_T data2[CONFIG_T::n_elem2], res_T res[CONFIG_T::n_elem]) {
 #pragma HLS PIPELINE

 for (int ii = 0; ii < CONFIG_T::n_elem; ii++) {
-res[ii] = data1[ii] * data2[ii];
+res[ii] = data1[ii % CONFIG_T::n_elem1] * data2[ii % CONFIG_T::n_elem2];
 }
 }

 template <class input1_T, class input2_T, class res_T, typename CONFIG_T>
-void average(input1_T data1[CONFIG_T::n_elem], input2_T data2[CONFIG_T::n_elem], res_T res[CONFIG_T::n_elem]) {
+void average(input1_T data1[CONFIG_T::n_elem1], input2_T data2[CONFIG_T::n_elem2], res_T res[CONFIG_T::n_elem]) {
 #pragma HLS PIPELINE

 for (int ii = 0; ii < CONFIG_T::n_elem; ii++) {
-res[ii] = (data1[ii] + data2[ii]) * ap_ufixed<1, 0>(0.5);
+res[ii] = (data1[ii % CONFIG_T::n_elem1] + data2[ii % CONFIG_T::n_elem2]) * ap_ufixed<1, 0>(0.5);
 }
 }

 template <class input1_T, class input2_T, class res_T, typename CONFIG_T>
-void maximum(input1_T data1[CONFIG_T::n_elem], input2_T data2[CONFIG_T::n_elem], res_T res[CONFIG_T::n_elem]) {
+void maximum(input1_T data1[CONFIG_T::n_elem1], input2_T data2[CONFIG_T::n_elem2], res_T res[CONFIG_T::n_elem]) {
 #pragma HLS PIPELINE

 for (int ii = 0; ii < CONFIG_T::n_elem; ii++) {
-res[ii] = (data1[ii] > data2[ii]) ? static_cast<res_T>(data1[ii]) : static_cast<res_T>(data2[ii]);
+res[ii] = (data1[ii % CONFIG_T::n_elem1] > data2[ii % CONFIG_T::n_elem2])
+? static_cast<res_T>(data1[ii % CONFIG_T::n_elem1])
+: static_cast<res_T>(data2[ii % CONFIG_T::n_elem2]);
 }
 }

 template <class input1_T, class input2_T, class res_T, typename CONFIG_T>
-void minimum(input1_T data1[CONFIG_T::n_elem], input2_T data2[CONFIG_T::n_elem], res_T res[CONFIG_T::n_elem]) {
+void minimum(input1_T data1[CONFIG_T::n_elem1], input2_T data2[CONFIG_T::n_elem2], res_T res[CONFIG_T::n_elem]) {
 #pragma HLS PIPELINE

 for (int ii = 0; ii < CONFIG_T::n_elem; ii++) {
-res[ii] = (data1[ii] < data2[ii]) ? static_cast<res_T>(data1[ii]) : static_cast<res_T>(data2[ii]);
+res[ii] = (data1[ii % CONFIG_T::n_elem1] < data2[ii % CONFIG_T::n_elem2])
+? static_cast<res_T>(data1[ii % CONFIG_T::n_elem1])
+: static_cast<res_T>(data2[ii % CONFIG_T::n_elem2]);
 }
 }

Expand Down
Loading