Skip to content

Commit bf49862

Browse files
committed
DPL Analysis: introduce binned expression
1 parent 9e322a9 commit bf49862

File tree

4 files changed

+165
-42
lines changed

4 files changed

+165
-42
lines changed

Framework/Core/include/Framework/ExpressionHelpers.h

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,18 +75,6 @@ struct ColumnOperationSpec {
7575
result.type = type;
7676
}
7777
};
78-
79-
/// helper struct used to parse trees
80-
struct NodeRecord {
81-
/// pointer to the actual tree node
82-
Node* node_ptr = nullptr;
83-
size_t index = 0;
84-
explicit NodeRecord(Node* node_, size_t index_) : node_ptr(node_), index{index_} {}
85-
bool operator!=(NodeRecord const& rhs)
86-
{
87-
return this->node_ptr != rhs.node_ptr;
88-
}
89-
};
9078
} // namespace o2::framework::expressions
9179

9280
#endif // O2_FRAMEWORK_EXPRESSIONS_HELPERS_H_

Framework/Core/include/Framework/Expressions.h

Lines changed: 130 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ struct LiteralNode {
114114
{
115115
}
116116

117+
LiteralNode(LiteralNode const& other) = default;
118+
117119
using var_t = LiteralValue::stored_type;
118120
var_t value;
119121
atype::type type = atype::NA;
@@ -132,6 +134,7 @@ struct BindingNode {
132134
/// An expression tree node corresponding to binary or unary operation
133135
struct OpNode {
134136
OpNode(BasicOp op_) : op{op_} {}
137+
OpNode(OpNode const& other) = default;
135138
BasicOp op;
136139
};
137140

@@ -147,6 +150,8 @@ struct PlaceholderNode : LiteralNode {
147150
}
148151
}
149152

153+
PlaceholderNode(PlaceholderNode const& other) = default;
154+
150155
void reset(InitContext& context)
151156
{
152157
value = retrieve(context, name.data());
@@ -156,6 +161,28 @@ struct PlaceholderNode : LiteralNode {
156161
LiteralNode::var_t (*retrieve)(InitContext&, char const*);
157162
};
158163

164+
/// A placeholder node for parameters taken from an array
165+
struct ParameterNode : LiteralNode {
166+
ParameterNode(int index_ = -1)
167+
: LiteralNode((float)0),
168+
index{index_}
169+
{
170+
}
171+
172+
ParameterNode(ParameterNode const&) = default;
173+
174+
template <typename T>
175+
void reset(T value_, int index_ = -1)
176+
{
177+
(*static_cast<LiteralNode*>(this)) = LiteralNode(value_);
178+
if (index_ > 0) {
179+
index = index_;
180+
}
181+
}
182+
183+
int index;
184+
};
185+
159186
/// A conditional node
160187
struct ConditionalNode {
161188
};
@@ -178,6 +205,10 @@ struct Node {
178205
{
179206
}
180207

208+
Node(ParameterNode&& p) : self{std::forward<ParameterNode>(p)}, left{nullptr}, right{nullptr}, condition{nullptr}
209+
{
210+
}
211+
181212
Node(ConditionalNode op, Node&& then_, Node&& else_, Node&& condition_)
182213
: self{op},
183214
left{std::make_unique<Node>(std::forward<Node>(then_))},
@@ -196,16 +227,70 @@ struct Node {
196227
right{nullptr},
197228
condition{nullptr} {}
198229

230+
Node(Node const& other)
231+
: self{other.self},
232+
index{other.index}
233+
{
234+
if (other.left != nullptr) {
235+
left = std::make_unique<Node>(*other.left);
236+
}
237+
if (other.right != nullptr) {
238+
right = std::make_unique<Node>(*other.right);
239+
}
240+
if (other.condition != nullptr) {
241+
condition = std::make_unique<Node>(*other.condition);
242+
}
243+
}
244+
199245
/// variant with possible nodes
200-
using self_t = std::variant<LiteralNode, BindingNode, OpNode, PlaceholderNode, ConditionalNode>;
246+
using self_t = std::variant<LiteralNode, BindingNode, OpNode, PlaceholderNode, ConditionalNode, ParameterNode>;
201247
self_t self;
202248
size_t index = 0;
203249
/// pointers to children
204-
std::unique_ptr<Node> left;
205-
std::unique_ptr<Node> right;
206-
std::unique_ptr<Node> condition;
250+
std::unique_ptr<Node> left = nullptr;
251+
std::unique_ptr<Node> right = nullptr;
252+
std::unique_ptr<Node> condition = nullptr;
207253
};
208254

255+
/// helper struct used to parse trees
256+
struct NodeRecord {
257+
/// pointer to the actual tree node
258+
Node* node_ptr = nullptr;
259+
size_t index = 0;
260+
explicit NodeRecord(Node* node_, size_t index_) : node_ptr(node_), index{index_} {}
261+
bool operator!=(NodeRecord const& rhs)
262+
{
263+
return this->node_ptr != rhs.node_ptr;
264+
}
265+
};
266+
267+
/// Tree-walker helper
268+
template <typename L>
269+
void walk(Node* head, L const& pred)
270+
{
271+
std::stack<NodeRecord> path;
272+
path.emplace(head, 0);
273+
while (!path.empty()) {
274+
auto& top = path.top();
275+
pred(top.node_ptr);
276+
277+
auto* leftp = top.node_ptr->left.get();
278+
auto* rightp = top.node_ptr->right.get();
279+
auto* condp = top.node_ptr->condition.get();
280+
path.pop();
281+
282+
if (leftp != nullptr) {
283+
path.emplace(leftp, 0);
284+
}
285+
if (rightp != nullptr) {
286+
path.emplace(rightp, 0);
287+
}
288+
if (condp != nullptr) {
289+
path.emplace(condp, 0);
290+
}
291+
}
292+
}
293+
209294
/// overloaded operators to build the tree from an expression
210295

211296
#define BINARY_OP_NODES(_operator_, _operation_) \
@@ -402,6 +487,47 @@ inline Node ifnode(Node&& condition_, Configurable<L1> const& then_, Configurabl
402487
return Node{ConditionalNode{}, PlaceholderNode{then_}, PlaceholderNode{else_}, std::forward<Node>(condition_)};
403488
}
404489

490+
/// Parameters
491+
inline Node par(int index)
492+
{
493+
return Node{ParameterNode{index}};
494+
}
495+
496+
/// binned functional
497+
template <typename T>
498+
inline Node binned(std::vector<T> const& binning, std::vector<T> const& parameters, Node&& binned, Node&& pexp, Node&& out)
499+
{
500+
int bins = binning.size() - 1;
501+
const auto binned_copy = binned;
502+
const auto out_copy = out;
503+
auto root = ifnode(Node{binned_copy} < binning[0], Node{out_copy}, LiteralNode{-1});
504+
root.right = std::make_unique<Node>(ifnode(Node{binned_copy} > binning[0] && Node{binned_copy} <= binning [1], updateParameters(pexp, bins, parameters, 0), LiteralNode{-1}));
505+
auto* current = root.right.get();
506+
for (auto i = 1; i < bins; ++i) {
507+
current->right = std::make_unique<Node>(ifnode(Node{binned_copy} <= binning[i + 1], updateParameters(pexp, bins, parameters, i), LiteralNode{-1}));
508+
current = current->right.get();
509+
}
510+
current->right = std::make_unique<Node>(out);
511+
512+
return root;
513+
}
514+
515+
template <typename T>
516+
Node updateParameters(Node const& pexp, int bins, std::vector<T> const& parameters, int bin)
517+
{
518+
Node result{pexp};
519+
auto updateParameter = [&bins, &parameters, &bin](Node* node)
520+
{
521+
if (node->self.index() == 5) {
522+
auto* n = std::get_if<5>(&node->self);
523+
n->reset(parameters[bin * bins + n->index]);
524+
}
525+
};
526+
walk(&result, updateParameter);
527+
return result;
528+
}
529+
530+
405531
/// A struct, containing the root of the expression tree
406532
struct Filter {
407533
Filter() = default;

Framework/Core/src/Expressions.cxx

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,13 @@ struct PlaceholderNodeHelper {
118118
return DatumSpec{node.value, node.type};
119119
}
120120
};
121+
122+
struct ParameterNodeHelper {
123+
DatumSpec operator()(ParameterNode const& node) const
124+
{
125+
return DatumSpec{node.value, node.type};
126+
}
127+
};
121128
} // namespace
122129

123130
std::shared_ptr<arrow::DataType> concreteArrowType(atype::type type)
@@ -189,37 +196,13 @@ std::ostream& operator<<(std::ostream& os, DatumSpec const& spec)
189196

190197
void updatePlaceholders(Filter& filter, InitContext& context)
191198
{
192-
std::stack<NodeRecord> path;
193-
194-
// insert the top node into stack
195-
path.emplace(filter.node.get(), 0);
196-
197199
auto updateNode = [&](Node* node) {
198200
if (node->self.index() == 3) {
199201
std::get_if<3>(&node->self)->reset(context);
200202
}
201203
};
202204

203-
// while the stack is not empty
204-
while (!path.empty()) {
205-
auto& top = path.top();
206-
updateNode(top.node_ptr);
207-
208-
auto* leftp = top.node_ptr->left.get();
209-
auto* rightp = top.node_ptr->right.get();
210-
auto* condp = top.node_ptr->condition.get();
211-
path.pop();
212-
213-
if (leftp != nullptr) {
214-
path.emplace(leftp, 0);
215-
}
216-
if (rightp != nullptr) {
217-
path.emplace(rightp, 0);
218-
}
219-
if (condp != nullptr) {
220-
path.emplace(condp, 0);
221-
}
222-
}
205+
expressions::walk(filter.node.get(), updateNode);
223206
}
224207

225208
const char* stringType(atype::type t)
@@ -267,6 +250,7 @@ Operations createOperations(Filter const& expression)
267250
[lh = LiteralNodeHelper{}](LiteralNode const& node) { return lh(node); },
268251
[bh = BindingNodeHelper{}](BindingNode const& node) { return bh(node); },
269252
[ph = PlaceholderNodeHelper{}](PlaceholderNode const& node) { return ph(node); },
253+
[pr = ParameterNodeHelper{}](ParameterNode const& node){ return pr(node); },
270254
[](auto&&) { return DatumSpec{}; }},
271255
node->self);
272256
};

Framework/Core/test/test_Expressions.cxx

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include "Framework/Configurable.h"
1313
#include "Framework/ExpressionHelpers.h"
1414
#include "Framework/AnalysisDataModel.h"
15-
#include "Framework/AODReaderHelpers.h"
1615
#include <catch_amalgamated.hpp>
1716
#include <arrow/util/config.h>
1817

@@ -283,3 +282,29 @@ TEST_CASE("TestConditionalExpressions")
283282
auto gandiva_filter2 = createFilter(schema2, gandiva_condition2);
284283
REQUIRE(gandiva_tree2->ToString() == "bool greater_than((float) fSigned1Pt, (const float) 0 raw(0)) && if (bool less_than(float absf((float) fEta), (const float) 1 raw(3f800000)) && if (bool less_than((float) fPt, (const float) 1 raw(3f800000))) { bool greater_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb)) } else { bool less_than((float) fPhi, (const float) 1.5708 raw(3fc90fdb)) }) { bool greater_than(float absf((float) fX), (const float) 1 raw(3f800000)) } else { bool greater_than(float absf((float) fY), (const float) 1 raw(3f800000)) }");
285284
}
285+
286+
TEST_CASE("TestBinnedExpressions")
287+
{
288+
std::vector<float> bins{0.5, 1.5, 2.5, 3.5, 4.5};
289+
std::vector<float> params{1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3, 3.0, 3.1, 3.2, 3.3, 4.0, 4.1, 4.2, 4.3};
290+
Projector p = binned(bins, params, o2::aod::track::pt, par(0) * o2::aod::track::x + par(1) * o2::aod::track::y + par(2) * o2::aod::track::z + par(3) * o2::aod::track::phi, LiteralNode{0.f});
291+
auto pspecs = createOperations(p);
292+
auto schema = std::make_shared<arrow::Schema>(std::vector{o2::aod::track::Pt::asArrowField(), o2::aod::track::X::asArrowField(), o2::aod::track::Y::asArrowField(), o2::aod::track::Z::asArrowField(), o2::aod::track::Phi::asArrowField()});
293+
auto tree = createExpressionTree(pspecs, schema);
294+
REQUIRE(tree->ToString() == "if (bool less_than((float) fPt, (const float) 0.5 raw(3f000000))) { (const float) 0 raw(0) } else { if (bool greater_than((float) fPt, (const float) 0.5 raw(3f000000)) && bool less_than_or_equal_to((float) fPt, (const float) 1.5 raw(3fc00000))) { float add(float add(float add(float multiply((const float) 1 raw(3f800000), (float) fX), float multiply((const float) 1.1 raw(3f8ccccd), (float) fY)), float multiply((const float) 1.2 raw(3f99999a), (float) fZ)), float multiply((const float) 1.3 raw(3fa66666), (float) fPhi)) } else { if (bool less_than_or_equal_to((float) fPt, (const float) 2.5 raw(40200000))) { float add(float add(float add(float multiply((const float) 2 raw(40000000), (float) fX), float multiply((const float) 2.1 raw(40066666), (float) fY)), float multiply((const float) 2.2 raw(400ccccd), (float) fZ)), float multiply((const float) 2.3 raw(40133333), (float) fPhi)) } else { if (bool less_than_or_equal_to((float) fPt, (const float) 3.5 raw(40600000))) { float add(float add(float add(float multiply((const float) 3 raw(40400000), (float) fX), float multiply((const float) 3.1 raw(40466666), (float) fY)), float multiply((const float) 3.2 raw(404ccccd), (float) fZ)), float multiply((const float) 3.3 raw(40533333), (float) fPhi)) } else { if (bool less_than_or_equal_to((float) fPt, (const float) 4.5 raw(40900000))) { float add(float add(float add(float multiply((const float) 4 raw(40800000), (float) fX), float multiply((const float) 4.1 raw(40833333), (float) fY)), float multiply((const float) 4.2 raw(40866666), (float) fZ)), float multiply((const float) 4.3 raw(4089999a), (float) fPhi)) } else { (const float) 0 raw(0) } } } } }");
295+
296+
std::vector<float> binning{0, o2::constants::math::PIHalf, o2::constants::math::PI, o2::constants::math::PI + o2::constants::math::PIHalf, o2::constants::math::TwoPI};
297+
std::vector<float> parameters{1.0, 1.1, 1.2, 1.3, // par 0
298+
2.0, 2.1, 2.2, 2.3, // par 1
299+
3.0, 3.1, 3.2, 3.3, // par 2
300+
4.0, 4.1, 4.2, 4.3}; // par 3
301+
302+
Projector p2 = binned((std::vector<float>)binning,
303+
(std::vector<float>)parameters,
304+
o2::aod::track::phi, par(0) * o2::aod::track::x * o2::aod::track::x + par(1) * o2::aod::track::y * o2::aod::track::y + par(2) * o2::aod::track::z * o2::aod::track::z,
305+
LiteralNode{-1.f});
306+
auto p2specs = createOperations(p2);
307+
auto schema2 = std::make_shared<arrow::Schema>(std::vector{o2::aod::track::Phi::asArrowField(), o2::aod::track::X::asArrowField(), o2::aod::track::Y::asArrowField(), o2::aod::track::Z::asArrowField()});
308+
auto tree2 = createExpressionTree(p2specs, schema2);
309+
REQUIRE(tree2->ToString() == "if (bool less_than((float) fPhi, (const float) 0 raw(0))) { (const float) -1 raw(bf800000) } else { if (bool greater_than((float) fPhi, (const float) 0 raw(0)) && bool less_than_or_equal_to((float) fPhi, (const float) 1.5708 raw(3fc90fdb))) { float add(float add(float multiply(float multiply((const float) 1 raw(3f800000), (float) fX), (float) fX), float multiply(float multiply((const float) 1.1 raw(3f8ccccd), (float) fY), (float) fY)), float multiply(float multiply((const float) 1.2 raw(3f99999a), (float) fZ), (float) fZ)) } else { if (bool less_than_or_equal_to((float) fPhi, (const float) 3.14159 raw(40490fdb))) { float add(float add(float multiply(float multiply((const float) 2 raw(40000000), (float) fX), (float) fX), float multiply(float multiply((const float) 2.1 raw(40066666), (float) fY), (float) fY)), float multiply(float multiply((const float) 2.2 raw(400ccccd), (float) fZ), (float) fZ)) } else { if (bool less_than_or_equal_to((float) fPhi, (const float) 4.71239 raw(4096cbe4))) { float add(float add(float multiply(float multiply((const float) 3 raw(40400000), (float) fX), (float) fX), float multiply(float multiply((const float) 3.1 raw(40466666), (float) fY), (float) fY)), float multiply(float multiply((const float) 3.2 raw(404ccccd), (float) fZ), (float) fZ)) } else { if (bool less_than_or_equal_to((float) fPhi, (const float) 6.28319 raw(40c90fdb))) { float add(float add(float multiply(float multiply((const float) 4 raw(40800000), (float) fX), (float) fX), float multiply(float multiply((const float) 4.1 raw(40833333), (float) fY), (float) fY)), float multiply(float multiply((const float) 4.2 raw(40866666), (float) fZ), (float) fZ)) } else { (const float) -1 raw(bf800000) } } } } }");
310+
}

0 commit comments

Comments
 (0)