Skip to content

Commit 466ba06

Browse files
authored
DPL Analysis: introduce binned expression (#14174)
1 parent a946be8 commit 466ba06

File tree

4 files changed

+162
-42
lines changed

4 files changed

+162
-42
lines changed

Framework/Core/include/Framework/ExpressionHelpers.h

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,18 +75,6 @@ struct ColumnOperationSpec {
7575
result.type = type;
7676
}
7777
};
78-
79-
/// helper struct used to parse trees
80-
struct NodeRecord {
81-
/// pointer to the actual tree node
82-
Node* node_ptr = nullptr;
83-
size_t index = 0;
84-
explicit NodeRecord(Node* node_, size_t index_) : node_ptr(node_), index{index_} {}
85-
bool operator!=(NodeRecord const& rhs)
86-
{
87-
return this->node_ptr != rhs.node_ptr;
88-
}
89-
};
9078
} // namespace o2::framework::expressions
9179

9280
#endif // O2_FRAMEWORK_EXPRESSIONS_HELPERS_H_

Framework/Core/include/Framework/Expressions.h

Lines changed: 127 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class Projector;
4141
#include <string>
4242
#include <memory>
4343
#include <set>
44+
#include <stack>
4445
namespace gandiva
4546
{
4647
using Selection = std::shared_ptr<gandiva::SelectionVector>;
@@ -114,6 +115,8 @@ struct LiteralNode {
114115
{
115116
}
116117

118+
LiteralNode(LiteralNode const& other) = default;
119+
117120
using var_t = LiteralValue::stored_type;
118121
var_t value;
119122
atype::type type = atype::NA;
@@ -132,6 +135,7 @@ struct BindingNode {
132135
/// An expression tree node corresponding to binary or unary operation
133136
struct OpNode {
134137
OpNode(BasicOp op_) : op{op_} {}
138+
OpNode(OpNode const& other) = default;
135139
BasicOp op;
136140
};
137141

@@ -147,6 +151,8 @@ struct PlaceholderNode : LiteralNode {
147151
}
148152
}
149153

154+
PlaceholderNode(PlaceholderNode const& other) = default;
155+
150156
void reset(InitContext& context)
151157
{
152158
value = retrieve(context, name.data());
@@ -156,6 +162,28 @@ struct PlaceholderNode : LiteralNode {
156162
LiteralNode::var_t (*retrieve)(InitContext&, char const*);
157163
};
158164

165+
/// A placeholder node for parameters taken from an array
166+
struct ParameterNode : LiteralNode {
167+
ParameterNode(int index_ = -1)
168+
: LiteralNode((float)0),
169+
index{index_}
170+
{
171+
}
172+
173+
ParameterNode(ParameterNode const&) = default;
174+
175+
template <typename T>
176+
void reset(T value_, int index_ = -1)
177+
{
178+
(*static_cast<LiteralNode*>(this)) = LiteralNode(value_);
179+
if (index_ > 0) {
180+
index = index_;
181+
}
182+
}
183+
184+
int index;
185+
};
186+
159187
/// A conditional node
160188
struct ConditionalNode {
161189
};
@@ -178,6 +206,10 @@ struct Node {
178206
{
179207
}
180208

209+
Node(ParameterNode&& p) : self{std::forward<ParameterNode>(p)}, left{nullptr}, right{nullptr}, condition{nullptr}
210+
{
211+
}
212+
181213
Node(ConditionalNode op, Node&& then_, Node&& else_, Node&& condition_)
182214
: self{op},
183215
left{std::make_unique<Node>(std::forward<Node>(then_))},
@@ -196,16 +228,70 @@ struct Node {
196228
right{nullptr},
197229
condition{nullptr} {}
198230

231+
Node(Node const& other)
232+
: self{other.self},
233+
index{other.index}
234+
{
235+
if (other.left != nullptr) {
236+
left = std::make_unique<Node>(*other.left);
237+
}
238+
if (other.right != nullptr) {
239+
right = std::make_unique<Node>(*other.right);
240+
}
241+
if (other.condition != nullptr) {
242+
condition = std::make_unique<Node>(*other.condition);
243+
}
244+
}
245+
199246
/// variant with possible nodes
200-
using self_t = std::variant<LiteralNode, BindingNode, OpNode, PlaceholderNode, ConditionalNode>;
247+
using self_t = std::variant<LiteralNode, BindingNode, OpNode, PlaceholderNode, ConditionalNode, ParameterNode>;
201248
self_t self;
202249
size_t index = 0;
203250
/// pointers to children
204-
std::unique_ptr<Node> left;
205-
std::unique_ptr<Node> right;
206-
std::unique_ptr<Node> condition;
251+
std::unique_ptr<Node> left = nullptr;
252+
std::unique_ptr<Node> right = nullptr;
253+
std::unique_ptr<Node> condition = nullptr;
254+
};
255+
256+
/// helper struct used to parse trees
257+
struct NodeRecord {
258+
/// pointer to the actual tree node
259+
Node* node_ptr = nullptr;
260+
size_t index = 0;
261+
explicit NodeRecord(Node* node_, size_t index_) : node_ptr(node_), index{index_} {}
262+
bool operator!=(NodeRecord const& rhs)
263+
{
264+
return this->node_ptr != rhs.node_ptr;
265+
}
207266
};
208267

268+
/// Tree-walker helper
269+
template <typename L>
270+
void walk(Node* head, L const& pred)
271+
{
272+
std::stack<NodeRecord> path;
273+
path.emplace(head, 0);
274+
while (!path.empty()) {
275+
auto& top = path.top();
276+
pred(top.node_ptr);
277+
278+
auto* leftp = top.node_ptr->left.get();
279+
auto* rightp = top.node_ptr->right.get();
280+
auto* condp = top.node_ptr->condition.get();
281+
path.pop();
282+
283+
if (leftp != nullptr) {
284+
path.emplace(leftp, 0);
285+
}
286+
if (rightp != nullptr) {
287+
path.emplace(rightp, 0);
288+
}
289+
if (condp != nullptr) {
290+
path.emplace(condp, 0);
291+
}
292+
}
293+
}
294+
209295
/// overloaded operators to build the tree from an expression
210296

211297
#define BINARY_OP_NODES(_operator_, _operation_) \
@@ -402,6 +488,43 @@ inline Node ifnode(Node&& condition_, Configurable<L1> const& then_, Configurabl
402488
return Node{ConditionalNode{}, PlaceholderNode{then_}, PlaceholderNode{else_}, std::forward<Node>(condition_)};
403489
}
404490

491+
/// Parameters
492+
inline Node par(int index)
493+
{
494+
return Node{ParameterNode{index}};
495+
}
496+
497+
/// binned functional
498+
template <typename T>
499+
inline Node binned(std::vector<T> const& binning, std::vector<T> const& parameters, Node&& binned, Node&& pexp, Node&& out)
500+
{
501+
int bins = binning.size() - 1;
502+
const auto binned_copy = binned;
503+
const auto out_copy = out;
504+
auto root = ifnode(Node{binned_copy} < binning[0], Node{out_copy}, LiteralNode{-1});
505+
auto* current = &root;
506+
for (auto i = 0; i < bins; ++i) {
507+
current->right = std::make_unique<Node>(ifnode(Node{binned_copy} < binning[i + 1], updateParameters(pexp, bins, parameters, i), LiteralNode{-1}));
508+
current = current->right.get();
509+
}
510+
current->right = std::make_unique<Node>(out);
511+
return root;
512+
}
513+
514+
template <typename T>
515+
Node updateParameters(Node const& pexp, int bins, std::vector<T> const& parameters, int bin)
516+
{
517+
Node result{pexp};
518+
auto updateParameter = [&bins, &parameters, &bin](Node* node) {
519+
if (node->self.index() == 5) {
520+
auto* n = std::get_if<5>(&node->self);
521+
n->reset(parameters[n->index * bins + bin]);
522+
}
523+
};
524+
walk(&result, updateParameter);
525+
return result;
526+
}
527+
405528
/// A struct, containing the root of the expression tree
406529
struct Filter {
407530
Filter() = default;

Framework/Core/src/Expressions.cxx

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,13 @@ struct PlaceholderNodeHelper {
118118
return DatumSpec{node.value, node.type};
119119
}
120120
};
121+
122+
struct ParameterNodeHelper {
123+
DatumSpec operator()(ParameterNode const& node) const
124+
{
125+
return DatumSpec{node.value, node.type};
126+
}
127+
};
121128
} // namespace
122129

123130
std::shared_ptr<arrow::DataType> concreteArrowType(atype::type type)
@@ -189,37 +196,13 @@ std::ostream& operator<<(std::ostream& os, DatumSpec const& spec)
189196

190197
void updatePlaceholders(Filter& filter, InitContext& context)
191198
{
192-
std::stack<NodeRecord> path;
193-
194-
// insert the top node into stack
195-
path.emplace(filter.node.get(), 0);
196-
197199
auto updateNode = [&](Node* node) {
198200
if (node->self.index() == 3) {
199201
std::get_if<3>(&node->self)->reset(context);
200202
}
201203
};
202204

203-
// while the stack is not empty
204-
while (!path.empty()) {
205-
auto& top = path.top();
206-
updateNode(top.node_ptr);
207-
208-
auto* leftp = top.node_ptr->left.get();
209-
auto* rightp = top.node_ptr->right.get();
210-
auto* condp = top.node_ptr->condition.get();
211-
path.pop();
212-
213-
if (leftp != nullptr) {
214-
path.emplace(leftp, 0);
215-
}
216-
if (rightp != nullptr) {
217-
path.emplace(rightp, 0);
218-
}
219-
if (condp != nullptr) {
220-
path.emplace(condp, 0);
221-
}
222-
}
205+
expressions::walk(filter.node.get(), updateNode);
223206
}
224207

225208
const char* stringType(atype::type t)
@@ -267,6 +250,7 @@ Operations createOperations(Filter const& expression)
267250
[lh = LiteralNodeHelper{}](LiteralNode const& node) { return lh(node); },
268251
[bh = BindingNodeHelper{}](BindingNode const& node) { return bh(node); },
269252
[ph = PlaceholderNodeHelper{}](PlaceholderNode const& node) { return ph(node); },
253+
[pr = ParameterNodeHelper{}](ParameterNode const& node) { return pr(node); },
270254
[](auto&&) { return DatumSpec{}; }},
271255
node->self);
272256
};

Framework/Core/test/test_Expressions.cxx

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include "Framework/Configurable.h"
1313
#include "Framework/ExpressionHelpers.h"
1414
#include "Framework/AnalysisDataModel.h"
15-
#include "Framework/AODReaderHelpers.h"
1615
#include <catch_amalgamated.hpp>
1716
#include <arrow/util/config.h>
1817

@@ -283,3 +282,29 @@ TEST_CASE("TestConditionalExpressions")
283282
auto gandiva_filter2 = createFilter(schema2, gandiva_condition2);
284283
REQUIRE(gandiva_tree2->ToString() == "bool greater_than((float) fSigned1Pt, (const float) 0 raw(0)) && if (bool less_than(float absf((float) fEta), (const float) 1 raw(3f800000)) && if (bool less_than((float) fPt, (const float) 1 raw(3f800000))) { bool greater_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb)) } else { bool less_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb)) }) { bool greater_than(float absf((float) fX), (const float) 1 raw(3f800000)) } else { bool greater_than(float absf((float) fY), (const float) 1 raw(3f800000)) }");
285284
}
285+
286+
TEST_CASE("TestBinnedExpressions")
287+
{
288+
std::vector<float> bins{0.5, 1.5, 2.5, 3.5, 4.5};
289+
std::vector<float> params{1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3, 3.0, 3.1, 3.2, 3.3, 4.0, 4.1, 4.2, 4.3};
290+
Projector p = binned(bins, params, o2::aod::track::pt, par(0) * o2::aod::track::x + par(1) * o2::aod::track::y + par(2) * o2::aod::track::z + par(3) * o2::aod::track::phi, LiteralNode{0.f});
291+
auto pspecs = createOperations(p);
292+
auto schema = std::make_shared<arrow::Schema>(std::vector{o2::aod::track::Pt::asArrowField(), o2::aod::track::X::asArrowField(), o2::aod::track::Y::asArrowField(), o2::aod::track::Z::asArrowField(), o2::aod::track::Phi::asArrowField()});
293+
auto tree = createExpressionTree(pspecs, schema);
294+
REQUIRE(tree->ToString() == "if (bool less_than((float) fPt, (const float) 0.5 raw(3f000000))) { (const float) 0 raw(0) } else { if (bool less_than((float) fPt, (const float) 1.5 raw(3fc00000))) { float add(float add(float add(float multiply((const float) 1 raw(3f800000), (float) fX), float multiply((const float) 2 raw(40000000), (float) fY)), float multiply((const float) 3 raw(40400000), (float) fZ)), float multiply((const float) 4 raw(40800000), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 2.5 raw(40200000))) { float add(float add(float add(float multiply((const float) 1.1 raw(3f8ccccd), (float) fX), float multiply((const float) 2.1 raw(40066666), (float) fY)), float multiply((const float) 3.1 raw(40466666), (float) fZ)), float multiply((const float) 4.1 raw(40833333), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 3.5 raw(40600000))) { float add(float add(float add(float multiply((const float) 1.2 raw(3f99999a), (float) fX), float multiply((const float) 2.2 raw(400ccccd), (float) fY)), float multiply((const float) 3.2 raw(404ccccd), (float) fZ)), float multiply((const float) 4.2 raw(40866666), (float) fPhi)) } else { if (bool less_than((float) fPt, (const float) 4.5 raw(40900000))) { float add(float add(float add(float multiply((const float) 1.3 raw(3fa66666), (float) fX), float multiply((const float) 2.3 raw(40133333), (float) fY)), float multiply((const float) 3.3 raw(40533333), (float) fZ)), float multiply((const float) 4.3 raw(4089999a), (float) fPhi)) } else { (const float) 0 raw(0) } } } } }");
295+
296+
std::vector<float> binning{0, o2::constants::math::PIHalf, o2::constants::math::PI, o2::constants::math::PI + o2::constants::math::PIHalf, o2::constants::math::TwoPI};
297+
std::vector<float> parameters{1.0, 1.1, 1.2, 1.3, // par 0
298+
2.0, 2.1, 2.2, 2.3, // par 1
299+
3.0, 3.1, 3.2, 3.3, // par 2
300+
4.0, 4.1, 4.2, 4.3}; // par 3
301+
302+
Projector p2 = binned((std::vector<float>)binning,
303+
(std::vector<float>)parameters,
304+
o2::aod::track::phi, par(0) * o2::aod::track::x * o2::aod::track::x + par(1) * o2::aod::track::y * o2::aod::track::y + par(2) * o2::aod::track::z * o2::aod::track::z,
305+
LiteralNode{-1.f});
306+
auto p2specs = createOperations(p2);
307+
auto schema2 = std::make_shared<arrow::Schema>(std::vector{o2::aod::track::Phi::asArrowField(), o2::aod::track::X::asArrowField(), o2::aod::track::Y::asArrowField(), o2::aod::track::Z::asArrowField()});
308+
auto tree2 = createExpressionTree(p2specs, schema2);
309+
REQUIRE(tree2->ToString() == "if (bool less_than((float) fPhi, (const float) 0 raw(0))) { (const float) -1 raw(bf800000) } else { if (bool less_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb))) { float add(float add(float multiply(float multiply((const float) 1 raw(3f800000), (float) fX), (float) fX), float multiply(float multiply((const float) 2 raw(40000000), (float) fY), (float) fY)), float multiply(float multiply((const float) 3 raw(40400000), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 3.14159 raw(40490fdb))) { float add(float add(float multiply(float multiply((const float) 1.1 raw(3f8ccccd), (float) fX), (float) fX), float multiply(float multiply((const float) 2.1 raw(40066666), (float) fY), (float) fY)), float multiply(float multiply((const float) 3.1 raw(40466666), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 4.71239 raw(4096cbe4))) { float add(float add(float multiply(float multiply((const float) 1.2 raw(3f99999a), (float) fX), (float) fX), float multiply(float multiply((const float) 2.2 raw(400ccccd), (float) fY), (float) fY)), float multiply(float multiply((const float) 3.2 raw(404ccccd), (float) fZ), (float) fZ)) } else { if (bool less_than((float) fPhi, (const float) 6.28319 raw(40c90fdb))) { float add(float add(float multiply(float multiply((const float) 1.3 raw(3fa66666), (float) fX), (float) fX), float multiply(float multiply((const float) 2.3 raw(40133333), (float) fY), (float) fY)), float multiply(float multiply((const float) 3.3 raw(40533333), (float) fZ), (float) fZ)) } else { (const float) -1 raw(bf800000) } } } } }");
310+
}

0 commit comments

Comments
 (0)