Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tmva/sofie/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofie
TMVA/ROperator_Einsum.hxx
TMVA/ROperator_Random.hxx
TMVA/ROperator_ScatterElements.hxx
TMVA/ROperator_ScatterND.hxx
TMVA/ROperator_Gather.hxx
TMVA/ROperator_GatherND.hxx
TMVA/ROperator_NonZero.hxx
Expand Down
192 changes: 192 additions & 0 deletions tmva/sofie/inc/TMVA/ROperator_ScatterND.hxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
#ifndef TMVA_SOFIE_ROPERATOR_ScatterND
#define TMVA_SOFIE_ROPERATOR_ScatterND

#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <sstream>
#include <stdexcept>
#include <string>

namespace TMVA{
namespace Experimental{
namespace SOFIE{

Comment on lines +12 to +15
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
namespace TMVA{
namespace Experimental{
namespace SOFIE{
namespace TMVA::Experimental::SOFIE {

Same here maybe.

class ROperator_ScatterND final : public ROperator
{
private:


std::string fNX;
std::string fNI;
std::string fNU;
std::string fNY;
std::string fReduction;

std::vector<Dim> fShapeX;
std::vector<Dim> fShapeI;
std::vector<Dim> fShapeY;


std::vector<int64_t> fIndices; // indices vector in case they are known at initialization

std::string fType;


public:
ROperator_ScatterND(){}
ROperator_ScatterND(const std::string & nameX, const std::string & nameI, const std::string & nameU, const std::string & nameY,
std::string reduction):
fNX(UTILITY::Clean_name(nameX)), fNI(UTILITY::Clean_name(nameI)), fNU(UTILITY::Clean_name(nameU)),
fNY(UTILITY::Clean_name(nameY)), fReduction(reduction)
{
fInputTensorNames = { fNX, fNI, fNU };
fOutputTensorNames = { fNY };
}

void Initialize(RModel& model) override {

// input must be a graph input, or already initialized intermediate tensor
if (!model.CheckIfTensorAlreadyExist(fNX)){
throw std::runtime_error(std::string("TMVA SOFIE ScatterND Op Input Tensor ") + fNX + "is not found in model");
}
if (!model.CheckIfTensorAlreadyExist(fNI)) {
throw std::runtime_error(std::string("TMVA SOFIE ScatterND Op Input Tensor ") + fNI + "is not found in model");
}
if (!model.CheckIfTensorAlreadyExist(fNU)) {
throw std::runtime_error(std::string("TMVA SOFIE ScatterND Op Input Tensor ") + fNU + "is not found in model");
}
//tbd check for constant tensors

fShapeX = model.GetDimTensorShape(fNX);
fShapeI = model.GetDimTensorShape(fNI);
auto shapeU = model.GetDimTensorShape(fNU);

// Validate inputs if fShapeI last is not dynamic

//if (!model.IsDynamicTensor(fNI)) {
const size_t r = fShapeX.size(); // rank of data
const size_t q = fShapeI.size(); // rank of indices
if (!(fShapeI.back().isParam) ) {
const size_t k = fShapeI.back().dim; // index depth

if (k > r)
throw std::invalid_argument(
"ScatterND: last dim of indices (" + std::to_string(k) +
") must be <= rank of data (" + std::to_string(r) + ")");

// Expected updates rank = q - 1 + r - k
int64_t expected_updates_rank = q - 1 + r - k;
if ((int64_t) shapeU.size() != expected_updates_rank)
throw std::invalid_argument("ScatterND: updates rank mismatch");
} else {
// Assumption is that last dimension of index shape is known (is not dynamic)
throw std::runtime_error("TMVA SOFIE ScatterND : Index_shape(-1) is not known. This case is not supported");
}

// output shape is equal to input shape
fShapeY = fShapeX;

model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
if (model.Verbose()) {
std::cout << "ScatterElements: input: " << ConvertDimShapeToString(fShapeX)
<< " indices " << ConvertDimShapeToString(fShapeI)
<< " update " << ConvertDimShapeToString(shapeU);
std::cout << "\t----> " << ConvertDimShapeToString(fShapeY) << std::endl;
}
}

std::string Generate(std::string opName) override {
if (fIsOutputConstant) {
// no code to generate here for constant output. Tensor output is defined in Session constructor
return "//---------------------------------------\n";
}
opName = "op_" + opName;
std::stringstream out;
out << "//--------- ScatterND " << opName << " --> " << ConvertDimShapeToString(fShapeY) << "\n";

size_t r = fShapeX.size();

// Strides
auto stridesX = UTILITY::ComputeStrideFromShape(fShapeX);
auto stridesY = UTILITY::ComputeStrideFromShape(fShapeY);
auto stridesI = UTILITY::ComputeStrideFromShape(fShapeI);

// case input_index_shape == rank of input
size_t k = fShapeI.back().dim;

// Total number of index tuples = product of indices dims except last
std::vector<Dim> shapeIndFirst(fShapeI.begin(), fShapeI.begin()+ fShapeI.size()-1);
auto num_index_tuples = ConvertDimShapeToLength(shapeIndFirst);

//slice size (is product of input from k to r)
std::vector<Dim> shapeSlice(fShapeX.begin()+k, fShapeX.end());
auto slice_size = ConvertDimShapeToLength(shapeSlice);

auto data_length = ConvertDimShapeToLength(fShapeX);

//step1: input->output
out << SP << "// Step 1: copy input data to output\n";
out << SP << "std::copy(tensor_" << fNX << ", tensor_" << fNX << " + " << data_length << ", tensor_" << fNY << ");\n";

// Step 2: Emit strides as a static constexpr array
out << SP << "// Step 2: data strides (row-major)\n";
out << SP << "static constexpr int64_t " << opName << "_data_strides[" << r << "] = {";
for (size_t i = 0; i < r; ++i)
out << stridesX[i] << (i + 1 < r ? ", " : "");
out << "};\n\n";

// Step 3: Scatter loop
out << SP << "// Step 3: scatter updates into output\n";
out << SP << "for (int64_t idx = 0; idx < " << num_index_tuples << "; idx++) {\n";

// Resolve flat data offset from k-dimensional index tuple
out << SP << SP << "int64_t data_offset = 0;\n";
for (size_t dim = 0; dim < k; ++dim) {
out << SP << SP << "{\n";
out << SP << SP << SP << "int64_t coord = tensor_" << fNI
<< "[idx * " << k << " + " << dim << "];\n";
// Support negative indices
out << SP << SP << SP << "if (coord < 0) coord += " << fShapeX[dim] << ";\n";
out << SP << SP << SP << "data_offset += coord * "
<< opName << "_data_strides[" << dim << "];\n";
out << SP << SP << "}\n";
}

// Apply updates with reduction
out << SP << SP << "for (int64_t s = 0; s < " << slice_size << "; s++) {\n";
out << SP << SP << SP << "auto upd = tensor_" << fNU
<< "[idx * " << slice_size << " + s];\n";

if (fReduction.empty() || fReduction == "none") {
out << SP << SP << SP << "tensor_" << fNY << "[data_offset + s] = upd;\n";
} else if (fReduction == "add") {
out << SP << SP << SP << "tensor_" << fNY<< "[data_offset + s] += upd;\n";
} else if (fReduction == "mul") {
out << SP << SP << SP << "tensor_" << fNY << "[data_offset + s] *= upd;\n";
} else if (fReduction == "min") {
out << SP << SP << SP << "tensor_" << fNY<< "[data_offset + s] = "
<< "std::min(tensor_" << fNY << "[data_offset + s], upd);\n";
} else if (fReduction == "max") {
out << SP << SP << SP << "tensor_" << fNY << "[data_offset + s] = "
<< "std::max(tensor_" << fNY << "[data_offset + s], upd);\n";
} else {
throw std::runtime_error(
"TMVA SOFIE ScatterND: unsupported reduction '" + fReduction + "'");
}

out << SP << SP << "}\n"; // end slice loop
out << SP << "}\n"; // end index tuple loop

return out.str();
}

};

}//SOFIE
}//Experimental
}//TMVA


#endif //TMVA_SOFIE_ROPERATOR_ScatterND
54 changes: 54 additions & 0 deletions tmva/sofie/test/TestCustomModelsFromONNX.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -3006,3 +3006,57 @@ TEST(ONNX, NotIsNaN)
}
}

TEST(ONNX, ScatterND_1)
{
   // 1-D ScatterND: index depth k = 1, scalar slices, no reduction
   std::vector<float> input = {1., 2., 3., 4., 5.};        // data, shape {5}
   std::vector<int64_t> indices = {0, 2, 4};               // index tuples, shape {3,1}
   std::vector<float> updates = {10., 30., 50.};           // shape {3}
   std::vector<float> expected = {10., 2., 30., 4., 50.};

   ASSERT_INCLUDE_AND_RUN(std::vector<float>, "ScatterND_1", input, indices, updates);

   // output must match the reference element count
   EXPECT_EQ(output.size(), expected.size());
   // element-wise comparison within tolerance
   for (size_t pos = 0; pos < output.size(); ++pos) {
      EXPECT_LE(std::abs(output[pos] - expected[pos]), DEFAULT_TOLERANCE);
   }
}

TEST(ONNX, ScatterND_2)
{
   // 2-D ScatterND scattering whole rows, with reduction = "add"
   std::vector<float> input = {1., 1., 2., 2., 3., 3.};        // data, shape {3,2}
   std::vector<int64_t> indices = {0, 1};                      // row indices, shape {2,1}
   std::vector<float> updates = {10., 10., 20., 20.};          // shape {2,2}
   std::vector<float> expected = {11., 11., 22., 22., 3., 3.};

   ASSERT_INCLUDE_AND_RUN(std::vector<float>, "ScatterND_2", input, indices, updates);

   // output must match the reference element count
   EXPECT_EQ(output.size(), expected.size());
   // element-wise comparison within tolerance
   for (size_t pos = 0; pos < output.size(); ++pos) {
      EXPECT_LE(std::abs(output[pos] - expected[pos]), DEFAULT_TOLERANCE);
   }
}

TEST(ONNX, ScatterND_3)
{
   // element-wise ScatterND (index depth k == rank of data), reduction = "mul"
   std::vector<float> input = {1., 2., 3., 4.};        // data, shape {2,2}
   std::vector<int64_t> indices = {0, 0, 1, 1};        // index tuples, shape {2,2}
   std::vector<float> updates = {11., 22.};            // scalar updates, shape {2}
   std::vector<float> expected = {11., 2., 3., 88.};

   ASSERT_INCLUDE_AND_RUN(std::vector<float>, "ScatterND_3", input, indices, updates);

   // output must match the reference element count
   EXPECT_EQ(output.size(), expected.size());
   // element-wise comparison within tolerance
   for (size_t pos = 0; pos < output.size(); ++pos) {
      EXPECT_LE(std::abs(output[pos] - expected[pos]), DEFAULT_TOLERANCE);
   }
}

21 changes: 21 additions & 0 deletions tmva/sofie/test/input_models/ScatterND_1.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
 onnx-example:”
+
data
indices
updatesoutput" ScatterND TestGraphZ
data


Z
indices


Z
updates


b
output


B
Expand Down
22 changes: 22 additions & 0 deletions tmva/sofie/test/input_models/ScatterND_2.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
 onnx-example:µ
@
data
indices
updatesoutput" ScatterND*
reduction"add  TestGraphZ
data


Z
indices


Z
updates


b
output


B
Expand Down
22 changes: 22 additions & 0 deletions tmva/sofie/test/input_models/ScatterND_3.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
 onnx-example:±
@
data
indices
updatesoutput" ScatterND*
reduction"mul  TestGraphZ
data


Z
indices


Z
updates


b
output


B
Expand Down
1 change: 1 addition & 0 deletions tmva/sofie_parsers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofieParser
src/ParseEinsum.cxx
src/ParseRandom.cxx
src/ParseScatterElements.cxx
src/ParseScatterND.cxx
src/ParseNonZero.cxx
src/ParseNot.cxx
${PROTO_SRCS}
Expand Down
58 changes: 58 additions & 0 deletions tmva/sofie_parsers/src/ParseScatterND.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#include "TMVA/RModelParser_ONNX.hxx"
#include "TMVA/ROperator_ScatterND.hxx"
#include "onnx_proto3.pb.h"

namespace TMVA {
namespace Experimental {
namespace SOFIE {
Comment on lines +5 to +7
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
namespace TMVA {
namespace Experimental {
namespace SOFIE {
namespace TMVA::Experimental::SOFIE {

Personal preference maybe, but I think this is more readable.


/// Parse an ONNX ScatterND node and build the corresponding SOFIE operator.
/// Inputs are (data, indices, updates); the optional "reduction" attribute
/// selects how updates are combined with the data.
ParserFuncSignature ParseScatterND = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {

   // ScatterND takes exactly three inputs: data, indices, updates
   if (nodeproto.input_size() != 3) {
      throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterND op has invalid input size");
   }
   // every input must already have a registered tensor type
   // (loop replaces three copy-pasted checks)
   for (int i = 0; i < nodeproto.input_size(); i++) {
      if (!parser.IsRegisteredTensorType(nodeproto.input(i))) {
         throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterND op has input tensor " + nodeproto.input(i)
            + " but its type is not yet registered");
      }
   }
   // data (input 0) and updates (input 2) must have the same element type
   ETensorType input_type = parser.GetTensorType(nodeproto.input(0));
   if (parser.GetTensorType(nodeproto.input(2)) != input_type) {
      throw std::runtime_error("TMVA::SOFIE ONNX parser ScatterND op has input tensors of different types: " +
            nodeproto.input(2) + " : " + ConvertTypeToString(parser.GetTensorType(nodeproto.input(2))) +
            " and " + nodeproto.input(0) + " : " + ConvertTypeToString(input_type));
   }

   // optional "reduction" attribute (none/add/mul/min/max); empty means "none"
   std::string reduction;
   for (int i = 0; i < nodeproto.attribute_size(); i++) {
      if (nodeproto.attribute(i).name() == "reduction")
         reduction = nodeproto.attribute(i).s();
   }

   const std::string output_name = nodeproto.output(0);
   // prefer make_unique over a naked new (no raw new/delete)
   std::unique_ptr<ROperator> op = std::make_unique<ROperator_ScatterND>(
      nodeproto.input(0), nodeproto.input(1), nodeproto.input(2), output_name, reduction);

   // the output inherits the element type of the data input
   if (!parser.IsRegisteredTensorType(output_name)) {
      parser.RegisterTensorType(output_name, input_type);
   }

   return op;
};


} // namespace SOFIE
} // namespace Experimental
} // namespace TMVA
Loading
Loading