Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions Framework/Core/include/Framework/Expressions.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ struct ExpressionInfo {

namespace o2::framework::expressions
{
void unknownParameterUsed(const char* name);
const char* stringType(atype::type t);

template <typename... T>
Expand Down Expand Up @@ -147,7 +148,7 @@ struct PlaceholderNode : LiteralNode {
if constexpr (variant_trait_v<typename std::decay<T>::type> != VariantType::Unknown) {
retrieve = [](InitContext& context, char const* name) { return LiteralNode::var_t{context.options().get<T>(name)}; };
} else {
runtime_error("Unknown parameter used in expression.");
unknownParameterUsed(name.c_str());
}
}

Expand Down Expand Up @@ -188,6 +189,19 @@ struct ParameterNode : LiteralNode {
struct ConditionalNode {
};

/// concepts
template <typename T>
concept is_literal_like = std::same_as<T, LiteralNode> || std::same_as<T, PlaceholderNode> || std::same_as<T, ParameterNode>;

template <typename T>
concept is_binding = std::same_as<T, BindingNode>;

template <typename T>
concept is_operation = std::same_as<T, OpNode>;

template <typename T>
concept is_conditional = std::same_as<T, ConditionalNode>;

/// A generic tree node
struct Node {
Node(LiteralNode&& v) : self{std::forward<LiteralNode>(v)}, left{nullptr}, right{nullptr}, condition{nullptr}
Expand Down Expand Up @@ -267,7 +281,7 @@ struct NodeRecord {

/// Tree-walker helper
template <typename L>
void walk(Node* head, L const& pred)
void walk(Node* head, L&& pred)
{
std::stack<NodeRecord> path;
path.emplace(head, 0);
Expand Down Expand Up @@ -512,16 +526,15 @@ inline Node binned(std::vector<T> const& binning, std::vector<T> const& paramete
}

template <typename T>
Node updateParameters(Node const& pexp, int bins, std::vector<T> const& parameters, int bin)
inline Node updateParameters(Node const& pexp, int bins, std::vector<T> const& parameters, int bin)
{
Node result{pexp};
auto updateParameter = [&bins, &parameters, &bin](Node* node) {
walk(&result, [&bins, &parameters, &bin](Node* node) {
if (node->self.index() == 5) {
auto* n = std::get_if<5>(&node->self);
n->reset(parameters[n->index * bins + bin]);
}
};
walk(&result, updateParameter);
});
return result;
}

Expand Down Expand Up @@ -594,12 +607,6 @@ gandiva::ExpressionPtr makeExpression(gandiva::NodePtr node, gandiva::FieldPtr r
/// Update placeholder nodes from context
void updatePlaceholders(Filter& filter, InitContext& context);

template <typename... C>
std::vector<expressions::Projector> makeProjectors(framework::pack<C...>)
{
return {C::Projector()...};
}

std::shared_ptr<gandiva::Projector> createProjectorHelper(size_t nColumns, expressions::Projector* projectors,
std::shared_ptr<arrow::Schema> schema,
std::vector<std::shared_ptr<arrow::Field>> const& fields);
Expand Down
100 changes: 46 additions & 54 deletions Framework/Core/src/Expressions.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ using namespace o2::framework;

namespace o2::framework::expressions
{
void unknownParameterUsed(const char* name)
{
runtime_error_f("Unknown parameter used in expression: %s", name);
}

/// a map between BasicOp and gandiva node definitions
/// note that logical 'and' and 'or' are created separately
Expand Down Expand Up @@ -89,43 +93,41 @@ size_t Filter::designateSubtrees(Node* node, size_t index)
return index;
}

namespace
template <typename T>
constexpr inline auto makeDatum(T const&)
{
struct LiteralNodeHelper {
DatumSpec operator()(LiteralNode const& node) const
{
return DatumSpec{node.value, node.type};
}
};
return DatumSpec{};
}

struct BindingNodeHelper {
DatumSpec operator()(BindingNode const& node) const
{
return DatumSpec{node.name, node.hash, node.type};
}
};
template <is_literal_like T>
constexpr inline auto makeDatum(T const& node)
{
return DatumSpec{node.value, node.type};
}

struct OpNodeHelper {
ColumnOperationSpec operator()(OpNode const& node) const
{
return ColumnOperationSpec{node.op};
}
};
template <is_binding T>
constexpr inline auto makeDatum(T const& node)
{
return DatumSpec{node.name, node.hash, node.type};
}

struct PlaceholderNodeHelper {
DatumSpec operator()(PlaceholderNode const& node) const
{
return DatumSpec{node.value, node.type};
}
};
template <typename T>
constexpr inline auto makeOp(T const&, size_t const&)
{
return ColumnOperationSpec{};
}

struct ParameterNodeHelper {
DatumSpec operator()(ParameterNode const& node) const
{
return DatumSpec{node.value, node.type};
}
};
} // namespace
template <is_operation T>
constexpr inline auto makeOp(T const& node, size_t const& index)
{
return ColumnOperationSpec{node.op, index};
}

template <is_conditional T>
constexpr inline auto makeOp(T const&, size_t const& index)
{
return ColumnOperationSpec{BasicOp::Conditional, index};
}

std::shared_ptr<arrow::DataType> concreteArrowType(atype::type type)
{
Expand Down Expand Up @@ -169,7 +171,7 @@ std::string upcastTo(atype::type f)
case atype::DOUBLE:
return "castFLOAT8";
default:
throw runtime_error_f("Do not know how to cast to %d", f);
throw runtime_error_f("Do not know how to cast to %s", stringType(f));
}
}

Expand All @@ -196,13 +198,11 @@ std::ostream& operator<<(std::ostream& os, DatumSpec const& spec)

void updatePlaceholders(Filter& filter, InitContext& context)
{
auto updateNode = [&](Node* node) {
expressions::walk(filter.node.get(), [&](Node* node) {
if (node->self.index() == 3) {
std::get_if<3>(&node->self)->reset(context);
}
};

expressions::walk(filter.node.get(), updateNode);
});
}

const char* stringType(atype::type t)
Expand Down Expand Up @@ -246,12 +246,7 @@ Operations createOperations(Filter const& expression)

auto processLeaf = [](Node const* const node) {
return std::visit(
overloaded{
[lh = LiteralNodeHelper{}](LiteralNode const& node) { return lh(node); },
[bh = BindingNodeHelper{}](BindingNode const& node) { return bh(node); },
[ph = PlaceholderNodeHelper{}](PlaceholderNode const& node) { return ph(node); },
[pr = ParameterNodeHelper{}](ParameterNode const& node) { return pr(node); },
[](auto&&) { return DatumSpec{}; }},
[](auto const& n) { return makeDatum(n); },
node->self);
};

Expand All @@ -266,10 +261,7 @@ Operations createOperations(Filter const& expression)
// create operation spec, pop the node and add its children
auto operationSpec =
std::visit(
overloaded{
[&](OpNode node) { return ColumnOperationSpec{node.op, top.node_ptr->index}; },
[&](ConditionalNode) { return ColumnOperationSpec{BasicOp::Conditional, top.node_ptr->index}; },
[](auto&&) { return ColumnOperationSpec{}; }},
[&](auto const& n) { return makeOp(n, top.node_ptr->index); },
top.node_ptr->self);

operationSpec.result = DatumSpec{top.index, operationSpec.type};
Expand Down Expand Up @@ -623,15 +615,15 @@ gandiva::NodePtr createExpressionTree(Operations const& opSpecs,
auto rightNode = datumNode(it->right);
auto condNode = datumNode(it->condition);

auto insertUpcastNode = [&](gandiva::NodePtr node, atype::type t) {
if (t != it->type) {
auto upcast = gandiva::TreeExprBuilder::MakeFunction(upcastTo(it->type), {node}, concreteArrowType(it->type));
auto insertUpcastNode = [](gandiva::NodePtr node, atype::type t0, atype::type t) {
if (t != t0) {
auto upcast = gandiva::TreeExprBuilder::MakeFunction(upcastTo(t0), {node}, concreteArrowType(t0));
node = upcast;
}
return node;
};

auto insertEqualizeUpcastNode = [&](gandiva::NodePtr& node1, gandiva::NodePtr& node2, atype::type t1, atype::type t2) {
auto insertEqualizeUpcastNode = [](gandiva::NodePtr& node1, gandiva::NodePtr& node2, atype::type t1, atype::type t2) {
if (t2 > t1) {
auto upcast = gandiva::TreeExprBuilder::MakeFunction(upcastTo(t2), {node1}, concreteArrowType(t2));
node1 = upcast;
Expand All @@ -656,14 +648,14 @@ gandiva::NodePtr createExpressionTree(Operations const& opSpecs,
default:
if (it->op < BasicOp::Sqrt) {
if (it->type != atype::BOOL) {
leftNode = insertUpcastNode(leftNode, it->left.type);
rightNode = insertUpcastNode(rightNode, it->right.type);
leftNode = insertUpcastNode(leftNode, it->type, it->left.type);
rightNode = insertUpcastNode(rightNode, it->type, it->right.type);
} else if (it->op == BasicOp::Equal || it->op == BasicOp::NotEqual) {
insertEqualizeUpcastNode(leftNode, rightNode, it->left.type, it->right.type);
}
temp_node = gandiva::TreeExprBuilder::MakeFunction(basicOperationsMap[it->op], {leftNode, rightNode}, concreteArrowType(it->type));
} else {
leftNode = insertUpcastNode(leftNode, it->left.type);
leftNode = insertUpcastNode(leftNode, it->type, it->left.type);
temp_node = gandiva::TreeExprBuilder::MakeFunction(basicOperationsMap[it->op], {leftNode}, concreteArrowType(it->type));
}
break;
Expand Down