@@ -41,6 +41,7 @@ class Projector;
4141#include < string>
4242#include < memory>
4343#include < set>
44+ #include < stack>
4445namespace gandiva
4546{
4647using Selection = std::shared_ptr<gandiva::SelectionVector>;
@@ -114,6 +115,8 @@ struct LiteralNode {
114115 {
115116 }
116117
118+ LiteralNode (LiteralNode const & other) = default ;
119+
117120 using var_t = LiteralValue::stored_type;
118121 var_t value;
119122 atype::type type = atype::NA;
@@ -132,6 +135,7 @@ struct BindingNode {
132135// / An expression tree node corresponding to binary or unary operation
133136struct OpNode {
134137 OpNode (BasicOp op_) : op{op_} {}
138+ OpNode (OpNode const & other) = default ;
135139 BasicOp op;
136140};
137141
@@ -147,6 +151,8 @@ struct PlaceholderNode : LiteralNode {
147151 }
148152 }
149153
154+ PlaceholderNode (PlaceholderNode const & other) = default ;
155+
150156 void reset (InitContext& context)
151157 {
152158 value = retrieve (context, name.data ());
@@ -156,6 +162,28 @@ struct PlaceholderNode : LiteralNode {
156162 LiteralNode::var_t (*retrieve)(InitContext&, char const *);
157163};
158164
165+ // / A placeholder node for parameters taken from an array
166+ struct ParameterNode : LiteralNode {
167+ ParameterNode (int index_ = -1 )
168+ : LiteralNode((float )0 ),
169+ index{index_}
170+ {
171+ }
172+
173+ ParameterNode (ParameterNode const &) = default ;
174+
175+ template <typename T>
176+ void reset (T value_, int index_ = -1 )
177+ {
178+ (*static_cast <LiteralNode*>(this )) = LiteralNode (value_);
179+ if (index_ > 0 ) {
180+ index = index_;
181+ }
182+ }
183+
184+ int index;
185+ };
186+
159187// / A conditional node
160188struct ConditionalNode {
161189};
@@ -178,6 +206,10 @@ struct Node {
178206 {
179207 }
180208
209+ Node (ParameterNode&& p) : self{std::forward<ParameterNode>(p)}, left{nullptr }, right{nullptr }, condition{nullptr }
210+ {
211+ }
212+
181213 Node (ConditionalNode op, Node&& then_, Node&& else_, Node&& condition_)
182214 : self{op},
183215 left{std::make_unique<Node>(std::forward<Node>(then_))},
@@ -196,16 +228,70 @@ struct Node {
196228 right{nullptr },
197229 condition{nullptr } {}
198230
231+ Node (Node const & other)
232+ : self{other.self },
233+ index{other.index }
234+ {
235+ if (other.left != nullptr ) {
236+ left = std::make_unique<Node>(*other.left );
237+ }
238+ if (other.right != nullptr ) {
239+ right = std::make_unique<Node>(*other.right );
240+ }
241+ if (other.condition != nullptr ) {
242+ condition = std::make_unique<Node>(*other.condition );
243+ }
244+ }
245+
199246 // / variant with possible nodes
200- using self_t = std::variant<LiteralNode, BindingNode, OpNode, PlaceholderNode, ConditionalNode>;
247+ using self_t = std::variant<LiteralNode, BindingNode, OpNode, PlaceholderNode, ConditionalNode, ParameterNode >;
201248 self_t self;
202249 size_t index = 0 ;
203250 // / pointers to children
204- std::unique_ptr<Node> left;
205- std::unique_ptr<Node> right;
206- std::unique_ptr<Node> condition;
251+ std::unique_ptr<Node> left = nullptr ;
252+ std::unique_ptr<Node> right = nullptr ;
253+ std::unique_ptr<Node> condition = nullptr ;
254+ };
255+
256+ // / helper struct used to parse trees
257+ struct NodeRecord {
258+ // / pointer to the actual tree node
259+ Node* node_ptr = nullptr ;
260+ size_t index = 0 ;
261+ explicit NodeRecord (Node* node_, size_t index_) : node_ptr(node_), index{index_} {}
262+ bool operator !=(NodeRecord const & rhs)
263+ {
264+ return this ->node_ptr != rhs.node_ptr ;
265+ }
207266};
208267
268+ // / Tree-walker helper
269+ template <typename L>
270+ void walk (Node* head, L const & pred)
271+ {
272+ std::stack<NodeRecord> path;
273+ path.emplace (head, 0 );
274+ while (!path.empty ()) {
275+ auto & top = path.top ();
276+ pred (top.node_ptr );
277+
278+ auto * leftp = top.node_ptr ->left .get ();
279+ auto * rightp = top.node_ptr ->right .get ();
280+ auto * condp = top.node_ptr ->condition .get ();
281+ path.pop ();
282+
283+ if (leftp != nullptr ) {
284+ path.emplace (leftp, 0 );
285+ }
286+ if (rightp != nullptr ) {
287+ path.emplace (rightp, 0 );
288+ }
289+ if (condp != nullptr ) {
290+ path.emplace (condp, 0 );
291+ }
292+ }
293+ }
294+
209295// / overloaded operators to build the tree from an expression
210296
211297#define BINARY_OP_NODES (_operator_, _operation_ ) \
@@ -402,6 +488,43 @@ inline Node ifnode(Node&& condition_, Configurable<L1> const& then_, Configurabl
402488 return Node{ConditionalNode{}, PlaceholderNode{then_}, PlaceholderNode{else_}, std::forward<Node>(condition_)};
403489}
404490
491+ // / Parameters
492+ inline Node par (int index)
493+ {
494+ return Node{ParameterNode{index}};
495+ }
496+
497+ // / binned functional
498+ template <typename T>
499+ inline Node binned (std::vector<T> const & binning, std::vector<T> const & parameters, Node&& binned, Node&& pexp, Node&& out)
500+ {
501+ int bins = binning.size () - 1 ;
502+ const auto binned_copy = binned;
503+ const auto out_copy = out;
504+ auto root = ifnode (Node{binned_copy} < binning[0 ], Node{out_copy}, LiteralNode{-1 });
505+ auto * current = &root;
506+ for (auto i = 0 ; i < bins; ++i) {
507+ current->right = std::make_unique<Node>(ifnode (Node{binned_copy} < binning[i + 1 ], updateParameters (pexp, bins, parameters, i), LiteralNode{-1 }));
508+ current = current->right .get ();
509+ }
510+ current->right = std::make_unique<Node>(out);
511+ return root;
512+ }
513+
514+ template <typename T>
515+ Node updateParameters (Node const & pexp, int bins, std::vector<T> const & parameters, int bin)
516+ {
517+ Node result{pexp};
518+ auto updateParameter = [&bins, ¶meters, &bin](Node* node) {
519+ if (node->self .index () == 5 ) {
520+ auto * n = std::get_if<5 >(&node->self );
521+ n->reset (parameters[n->index * bins + bin]);
522+ }
523+ };
524+ walk (&result, updateParameter);
525+ return result;
526+ }
527+
405528// / A struct, containing the root of the expression tree
406529struct Filter {
407530 Filter () = default ;
0 commit comments