Skip to content

Commit aa1fa27

Browse files
aalkin and alibuild authored
DPL Analysis: generalize aod-spawner (#14808)
Co-authored-by: ALICE Action Bot <alibuild@cern.ch>
1 parent 72e6d01 commit aa1fa27

15 files changed

+1203
-76
lines changed

Framework/Core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ o2_add_library(Framework
142142
src/Array2D.cxx
143143
src/Variant.cxx
144144
src/VariantJSONHelpers.cxx
145+
src/ExpressionJSONHelpers.cxx
145146
src/VariantPropertyTreeHelpers.cxx
146147
src/WorkflowCustomizationHelpers.cxx
147148
src/WorkflowHelpers.cxx

Framework/Core/include/Framework/AODReaderHelpers.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,7 @@
1212
#ifndef O2_FRAMEWORK_AODREADERHELPERS_H_
1313
#define O2_FRAMEWORK_AODREADERHELPERS_H_
1414

15-
#include "Framework/TableBuilder.h"
1615
#include "Framework/AlgorithmSpec.h"
17-
#include "Framework/Logger.h"
18-
#include "Framework/RootMessageContext.h"
1916
#include <uv.h>
2017

2118
namespace o2::framework::readers
@@ -24,7 +21,7 @@ namespace o2::framework::readers
2421

2522
struct AODReaderHelpers {
2623
static AlgorithmSpec rootFileReaderCallback();
27-
static AlgorithmSpec aodSpawnerCallback(std::vector<InputSpec>& requested);
24+
static AlgorithmSpec aodSpawnerCallback(ConfigContext const& ctx);
2825
static AlgorithmSpec indexBuilderCallback(std::vector<InputSpec>& requested);
2926
};
3027

Framework/Core/include/Framework/ASoA.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1270,6 +1270,7 @@ struct TableIterator : IP, C... {
12701270

12711271
struct ArrowHelpers {
12721272
static std::shared_ptr<arrow::Table> joinTables(std::vector<std::shared_ptr<arrow::Table>>&& tables, std::span<const char* const> labels);
1273+
static std::shared_ptr<arrow::Table> joinTables(std::vector<std::shared_ptr<arrow::Table>>&& tables, std::span<const std::string> labels);
12731274
static std::shared_ptr<arrow::Table> concatTables(std::vector<std::shared_ptr<arrow::Table>>&& tables);
12741275
};
12751276

@@ -1293,7 +1294,14 @@ concept with_ccdb_urls = requires {
12931294
};
12941295

12951296
template <typename T>
1296-
concept with_base_table = not_void<typename aod::MetadataTrait<o2::aod::Hash<T::ref.desc_hash>>::metadata::base_table_t>;
1297+
concept with_base_table = requires {
1298+
typename aod::MetadataTrait<o2::aod::Hash<T::ref.desc_hash>>::metadata::base_table_t;
1299+
};
1300+
1301+
template <typename T>
1302+
concept with_expression_pack = requires {
1303+
typename T::expression_pack_t{};
1304+
};
12971305

12981306
template <size_t N1, std::array<TableRef, N1> os1, size_t N2, std::array<TableRef, N2> os2>
12991307
consteval bool is_compatible()

Framework/Core/include/Framework/AnalysisHelpers.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@
2626
#include "Framework/Traits.h"
2727

2828
#include <string>
29+
namespace o2::framework
30+
{
31+
std::string serializeProjectors(std::vector<framework::expressions::Projector>& projectors);
32+
std::string serializeSchema(std::shared_ptr<arrow::Schema>& schema);
33+
} // namespace o2::framework
34+
2935
namespace o2::soa
3036
{
3137
template <TableRef R>
@@ -97,6 +103,32 @@ constexpr auto getCCDBMetadata() -> std::vector<framework::ConfigParamSpec>
97103
{
98104
return {};
99105
}
106+
107+
template <soa::with_expression_pack T>
108+
constexpr auto getExpressionMetadata() -> std::vector<framework::ConfigParamSpec>
109+
{
110+
using expression_pack_t = T::expression_pack_t;
111+
112+
auto projectors = []<typename... C>(framework::pack<C...>) -> std::vector<framework::expressions::Projector> {
113+
std::vector<framework::expressions::Projector> result;
114+
(result.emplace_back(std::move(C::Projector())), ...);
115+
return result;
116+
}(expression_pack_t{});
117+
118+
auto schema = std::make_shared<arrow::Schema>(o2::soa::createFieldsFromColumns(expression_pack_t{}));
119+
120+
auto json = framework::serializeProjectors(projectors);
121+
return {framework::ConfigParamSpec{"projectors", framework::VariantType::String, json, {"\"\""}},
122+
framework::ConfigParamSpec{"schema", framework::VariantType::String, framework::serializeSchema(schema), {"\"\""}}};
123+
}
124+
125+
template <typename T>
126+
requires(!soa::with_expression_pack<T>)
127+
constexpr auto getExpressionMetadata() -> std::vector<framework::ConfigParamSpec>
128+
{
129+
return {};
130+
}
131+
100132
} // namespace
101133

102134
template <TableRef R>
@@ -107,6 +139,8 @@ constexpr auto tableRef2InputSpec()
107139
metadata.insert(metadata.end(), m.begin(), m.end());
108140
auto ccdbMetadata = getCCDBMetadata<typename o2::aod::MetadataTrait<o2::aod::Hash<R.desc_hash>>::metadata>();
109141
metadata.insert(metadata.end(), ccdbMetadata.begin(), ccdbMetadata.end());
142+
auto p = getExpressionMetadata<typename o2::aod::MetadataTrait<o2::aod::Hash<R.desc_hash>>::metadata>();
143+
metadata.insert(metadata.end(), p.begin(), p.end());
110144

111145
return framework::InputSpec{
112146
o2::aod::label<R>(),

Framework/Core/include/Framework/Expressions.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ std::string upcastTo(atype::type f);
110110

111111
/// An expression tree node corresponding to a literal value
112112
struct LiteralNode {
113+
using var_t = LiteralValue::stored_type;
114+
113115
LiteralNode()
114116
: value{-1},
115117
type{atype::INT32}
@@ -120,7 +122,12 @@ struct LiteralNode {
120122
{
121123
}
122124

123-
using var_t = LiteralValue::stored_type;
125+
LiteralNode(var_t v, atype::type t)
126+
: value{v},
127+
type{t}
128+
{
129+
}
130+
124131
var_t value;
125132
atype::type type = atype::NA;
126133
};
@@ -617,14 +624,19 @@ inline Node ncfg(T defaultValue, std::string path)
617624
struct Filter {
618625
Filter() = default;
619626

627+
Filter(std::unique_ptr<Node>&& ptr)
628+
{
629+
node = std::move(ptr);
630+
(void)designateSubtrees(node.get());
631+
}
632+
620633
Filter(Node&& node_) : node{std::make_unique<Node>(std::forward<Node>(node_))}
621634
{
622635
(void)designateSubtrees(node.get());
623636
}
624637

625638
Filter(Filter&& other) : node{std::forward<std::unique_ptr<Node>>(other.node)}
626639
{
627-
(void)designateSubtrees(node.get());
628640
}
629641

630642
Filter(std::string const& input_) : input{input_} {}

Framework/Core/include/Framework/TableBuilder.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include "arrow/type_traits.h"
1919

2020
// Apparently needs to be on top of the arrow includes.
21-
#include <sstream>
2221

2322
#include <arrow/chunked_array.h>
2423
#include <arrow/status.h>
@@ -796,6 +795,10 @@ auto makeEmptyTable(const char* name, framework::pack<Cs...> p)
796795
std::shared_ptr<arrow::Table> spawnerHelper(std::shared_ptr<arrow::Table> const& fullTable, std::shared_ptr<arrow::Schema> newSchema, size_t nColumns,
797796
expressions::Projector* projectors, const char* name, std::shared_ptr<gandiva::Projector>& projector);
798797

798+
std::shared_ptr<arrow::Table> spawnerHelper(std::shared_ptr<arrow::Table> const& fullTable, std::shared_ptr<arrow::Schema> newSchema,
799+
const char* name, size_t nColumns,
800+
const std::shared_ptr<gandiva::Projector>& projector);
801+
799802
/// Expression-based column generator to materialize columns
800803
template <aod::is_aod_hash D>
801804
requires(soa::has_configurable_extension<typename o2::aod::MetadataTrait<D>::metadata>)

Framework/Core/src/AODReaderHelpers.cxx

Lines changed: 131 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,16 @@
1212
#include "Framework/AODReaderHelpers.h"
1313
#include "Framework/AnalysisHelpers.h"
1414
#include "Framework/AnalysisDataModelHelpers.h"
15-
#include "Framework/DataProcessingHelpers.h"
1615
#include "Framework/ExpressionHelpers.h"
16+
#include "Framework/DataProcessingHelpers.h"
1717
#include "Framework/AlgorithmSpec.h"
1818
#include "Framework/ControlService.h"
1919
#include "Framework/CallbackService.h"
2020
#include "Framework/EndOfStreamContext.h"
2121
#include "Framework/DataSpecUtils.h"
22+
#include "ExpressionJSONHelpers.h"
23+
#include "Framework/ConfigContext.h"
24+
#include "Framework/AnalysisContext.h"
2225

2326
#include <Monitoring/Monitoring.h>
2427

@@ -44,28 +47,6 @@ auto setEOSCallback(InitContext& ic)
4447
});
4548
}
4649

47-
template <typename... Ts>
48-
static inline auto doExtractOriginal(framework::pack<Ts...>, ProcessingContext& pc)
49-
{
50-
if constexpr (sizeof...(Ts) == 1) {
51-
return pc.inputs().get<TableConsumer>(aod::MetadataTrait<framework::pack_element_t<0, framework::pack<Ts...>>>::metadata::tableLabel())->asArrowTable();
52-
} else {
53-
return std::vector{pc.inputs().get<TableConsumer>(aod::MetadataTrait<Ts>::metadata::tableLabel())->asArrowTable()...};
54-
}
55-
}
56-
57-
template <typename... Os>
58-
static inline auto extractOriginalsTuple(framework::pack<Os...>, ProcessingContext& pc)
59-
{
60-
return std::make_tuple(extractTypedOriginal<Os>(pc)...);
61-
}
62-
63-
template <typename... Os>
64-
static inline auto extractOriginalsVector(framework::pack<Os...>, ProcessingContext& pc)
65-
{
66-
return std::vector{extractOriginal<Os>(pc)...};
67-
}
68-
6950
template <size_t N, std::array<soa::TableRef, N> refs>
7051
static inline auto extractOriginals(ProcessingContext& pc)
7152
{
@@ -156,53 +137,137 @@ auto make_spawn(InputSpec const& input, ProcessingContext& pc)
156137
(typename metadata_t::expression_pack_t{});
157138
return o2::framework::spawner<D>(extractOriginals<sources.size(), sources>(pc), input.binding.c_str(), projectors.data(), projector, schema);
158139
}
140+
141+
struct Maker {
142+
std::string binding;
143+
std::vector<std::string> labels;
144+
std::vector<std::shared_ptr<gandiva::Expression>> expressions;
145+
std::shared_ptr<gandiva::Projector> projector = nullptr;
146+
std::shared_ptr<arrow::Schema> schema;
147+
148+
header::DataOrigin origin;
149+
header::DataDescription description;
150+
header::DataHeader::SubSpecificationType version;
151+
152+
std::shared_ptr<arrow::Table> make(ProcessingContext& pc)
153+
{
154+
std::vector<std::shared_ptr<arrow::Table>> originals;
155+
for (auto const& label : labels) {
156+
originals.push_back(pc.inputs().get<TableConsumer>(label)->asArrowTable());
157+
}
158+
auto fullTable = soa::ArrowHelpers::joinTables(std::move(originals), std::span{labels.begin(), labels.size()});
159+
if (projector == nullptr) {
160+
auto s = gandiva::Projector::Make(
161+
fullTable->schema(),
162+
expressions,
163+
&projector);
164+
if (!s.ok()) {
165+
throw o2::framework::runtime_error_f("Failed to create projector: %s", s.ToString().c_str());
166+
}
167+
}
168+
169+
return spawnerHelper(fullTable, schema, binding.c_str(), schema->num_fields(), projector);
170+
}
171+
};
172+
173+
struct Spawnable {
174+
std::string binding;
175+
std::vector<std::string> labels;
176+
std::vector<expressions::Projector> projectors;
177+
std::vector<std::shared_ptr<gandiva::Expression>> expressions;
178+
std::shared_ptr<arrow::Schema> outputSchema;
179+
std::shared_ptr<arrow::Schema> inputSchema;
180+
181+
header::DataOrigin origin;
182+
header::DataDescription description;
183+
header::DataHeader::SubSpecificationType version;
184+
185+
Spawnable(InputSpec const& spec)
186+
: binding{spec.binding}
187+
{
188+
auto&& [origin_, description_, version_] = DataSpecUtils::asConcreteDataMatcher(spec);
189+
origin = origin_;
190+
description = description_;
191+
version = version_;
192+
auto loc = std::find_if(spec.metadata.begin(), spec.metadata.end(), [](ConfigParamSpec const& cps) { return cps.name.compare("projectors") == 0; });
193+
std::stringstream iws(loc->defaultValue.get<std::string>());
194+
projectors = ExpressionJSONHelpers::read(iws);
195+
196+
loc = std::find_if(spec.metadata.begin(), spec.metadata.end(), [](ConfigParamSpec const& cps) { return cps.name.compare("schema") == 0; });
197+
iws.clear();
198+
iws.str(loc->defaultValue.get<std::string>());
199+
outputSchema = ArrowJSONHelpers::read(iws);
200+
201+
for (auto& i : spec.metadata) {
202+
if (i.name.starts_with("input:")) {
203+
labels.emplace_back(i.name.substr(6));
204+
}
205+
}
206+
207+
std::vector<std::shared_ptr<arrow::Field>> fields;
208+
for (auto& p : projectors) {
209+
expressions::walk(p.node.get(),
210+
[&fields](expressions::Node* n) mutable {
211+
if (n->self.index() == 1) {
212+
auto& b = std::get<expressions::BindingNode>(n->self);
213+
if (std::find_if(fields.begin(), fields.end(), [&b](std::shared_ptr<arrow::Field> const& field) { return field->name() == b.name; }) == fields.end()) {
214+
fields.emplace_back(std::make_shared<arrow::Field>(b.name, expressions::concreteArrowType(b.type)));
215+
}
216+
}
217+
});
218+
}
219+
inputSchema = std::make_shared<arrow::Schema>(fields);
220+
221+
int i = 0;
222+
for (auto& p : projectors) {
223+
expressions.push_back(
224+
expressions::makeExpression(
225+
expressions::createExpressionTree(
226+
expressions::createOperations(p),
227+
inputSchema),
228+
outputSchema->field(i)));
229+
++i;
230+
}
231+
}
232+
233+
std::shared_ptr<gandiva::Projector> makeProjector()
234+
{
235+
return expressions::createProjectorHelper(projectors.size(), projectors.data(), inputSchema, outputSchema->fields());
236+
}
237+
238+
Maker createMaker()
239+
{
240+
return {
241+
binding,
242+
labels,
243+
expressions,
244+
nullptr,
245+
outputSchema,
246+
origin,
247+
description,
248+
version};
249+
}
250+
};
251+
159252
} // namespace
160253

161-
AlgorithmSpec AODReaderHelpers::aodSpawnerCallback(std::vector<InputSpec>& requested)
254+
AlgorithmSpec AODReaderHelpers::aodSpawnerCallback(/*std::vector<InputSpec>& requested*/ ConfigContext const& ctx)
162255
{
163-
return AlgorithmSpec::InitCallback{[requested](InitContext& /*ic*/) {
164-
return [requested](ProcessingContext& pc) {
256+
auto& ac = ctx.services().get<AnalysisContext>();
257+
return AlgorithmSpec::InitCallback{[requested = ac.spawnerInputs](InitContext& /*ic*/) {
258+
std::vector<Spawnable> spawnables;
259+
for (auto& i : requested) {
260+
spawnables.emplace_back(i);
261+
}
262+
std::vector<Maker> makers;
263+
for (auto& s : spawnables) {
264+
makers.push_back(s.createMaker());
265+
}
266+
267+
return [makers](ProcessingContext& pc) mutable {
165268
auto outputs = pc.outputs();
166-
// spawn tables
167-
for (auto& input : requested) {
168-
auto&& [origin, description, version] = DataSpecUtils::asConcreteDataMatcher(input);
169-
if (description == header::DataDescription{"EXTRACK"}) {
170-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXTRACK/0"_h>>(input, pc));
171-
} else if (description == header::DataDescription{"EXTRACK_IU"}) {
172-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXTRACK_IU/0"_h>>(input, pc));
173-
} else if (description == header::DataDescription{"EXTRACKCOV"}) {
174-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXTRACKCOV/0"_h>>(input, pc));
175-
} else if (description == header::DataDescription{"EXTRACKCOV_IU"}) {
176-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXTRACKCOV_IU/0"_h>>(input, pc));
177-
} else if (description == header::DataDescription{"EXTRACKEXTRA"}) {
178-
if (version == 0U) {
179-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXTRACKEXTRA/0"_h>>(input, pc));
180-
} else if (version == 1U) {
181-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXTRACKEXTRA/1"_h>>(input, pc));
182-
} else if (version == 2U) {
183-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXTRACKEXTRA/2"_h>>(input, pc));
184-
}
185-
} else if (description == header::DataDescription{"EXMFTTRACK"}) {
186-
if (version == 0U) {
187-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXMFTTRACK/0"_h>>(input, pc));
188-
} else if (version == 1U) {
189-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXMFTTRACK/1"_h>>(input, pc));
190-
}
191-
} else if (description == header::DataDescription{"EXMFTTRACKCOV"}) {
192-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXMFTTRACKCOV/0"_h>>(input, pc));
193-
} else if (description == header::DataDescription{"EXFWDTRACK"}) {
194-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXFWDTRACK/0"_h>>(input, pc));
195-
} else if (description == header::DataDescription{"EXFWDTRACKCOV"}) {
196-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXFWDTRACKCOV/0"_h>>(input, pc));
197-
} else if (description == header::DataDescription{"EXMCPARTICLE"}) {
198-
if (version == 0U) {
199-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXMCPARTICLE/0"_h>>(input, pc));
200-
} else if (version == 1U) {
201-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXMCPARTICLE/1"_h>>(input, pc));
202-
}
203-
} else {
204-
throw runtime_error("Not an extended table");
205-
}
269+
for (auto& maker : makers) {
270+
outputs.adopt(Output{maker.origin, maker.description, maker.version}, maker.make(pc));
206271
}
207272
};
208273
}};

0 commit comments

Comments
 (0)