Skip to content

Commit bac0d59

Browse files
committed
generalize spawner
1 parent e458133 commit bac0d59

File tree

8 files changed

+200
-55
lines changed

8 files changed

+200
-55
lines changed

Framework/Core/include/Framework/AODReaderHelpers.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,7 @@
1212
#ifndef O2_FRAMEWORK_AODREADERHELPERS_H_
1313
#define O2_FRAMEWORK_AODREADERHELPERS_H_
1414

15-
#include "Framework/TableBuilder.h"
1615
#include "Framework/AlgorithmSpec.h"
17-
#include "Framework/Logger.h"
18-
#include "Framework/RootMessageContext.h"
1916
#include <uv.h>
2017

2118
namespace o2::framework::readers
@@ -24,7 +21,7 @@ namespace o2::framework::readers
2421

2522
// Collection of static factory functions, each returning an AlgorithmSpec.
// NOTE(review): judging by the names these back the internal AOD reader,
// spawner and index-builder devices -- confirm against the call sites.
struct AODReaderHelpers {
2623
static AlgorithmSpec rootFileReaderCallback();
27-
// Signature removed by this commit: the spawner inputs were handed in explicitly.
static AlgorithmSpec aodSpawnerCallback(std::vector<InputSpec>& requested);
24+
// Signature added by this commit: the callback is built from the ConfigContext instead.
static AlgorithmSpec aodSpawnerCallback(ConfigContext const& ctx);
2825
static AlgorithmSpec indexBuilderCallback(std::vector<InputSpec>& requested);
2926
};
3027

Framework/Core/include/Framework/ASoA.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,6 +1270,7 @@ struct TableIterator : IP, C... {
12701270

12711271
// Static helpers that combine several arrow tables into one.
struct ArrowHelpers {
12721272
// Join tables; `labels` (C strings) name the tables.
static std::shared_ptr<arrow::Table> joinTables(std::vector<std::shared_ptr<arrow::Table>>&& tables, std::span<const char* const> labels);
1273+
// Overload added by this commit: same operation with the labels given as std::string.
static std::shared_ptr<arrow::Table> joinTables(std::vector<std::shared_ptr<arrow::Table>>&& tables, std::span<const std::string> labels);
12731274
static std::shared_ptr<arrow::Table> concatTables(std::vector<std::shared_ptr<arrow::Table>>&& tables);
12741275
};
12751276

Framework/Core/include/Framework/TableBuilder.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include "arrow/type_traits.h"
1919

2020
// Apparently needs to be on top of the arrow includes.
21-
#include <sstream>
2221

2322
#include <arrow/chunked_array.h>
2423
#include <arrow/status.h>
@@ -796,6 +795,10 @@ auto makeEmptyTable(const char* name, framework::pack<Cs...> p)
796795
std::shared_ptr<arrow::Table> spawnerHelper(std::shared_ptr<arrow::Table> const& fullTable, std::shared_ptr<arrow::Schema> newSchema, size_t nColumns,
797796
expressions::Projector* projectors, const char* name, std::shared_ptr<gandiva::Projector>& projector);
798797

798+
std::shared_ptr<arrow::Table> spawnerHelper(std::shared_ptr<arrow::Table> const& fullTable, std::shared_ptr<arrow::Schema> newSchema,
799+
const char* name, size_t nColumns,
800+
const std::shared_ptr<gandiva::Projector>& projector);
801+
799802
/// Expression-based column generator to materialize columns
800803
template <aod::is_aod_hash D>
801804
requires(soa::has_configurable_extension<typename o2::aod::MetadataTrait<D>::metadata>)

Framework/Core/src/AODReaderHelpers.cxx

Lines changed: 119 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,16 @@
1212
#include "Framework/AODReaderHelpers.h"
1313
#include "Framework/AnalysisHelpers.h"
1414
#include "Framework/AnalysisDataModelHelpers.h"
15-
#include "Framework/DataProcessingHelpers.h"
1615
#include "Framework/ExpressionHelpers.h"
16+
#include "Framework/DataProcessingHelpers.h"
1717
#include "Framework/AlgorithmSpec.h"
1818
#include "Framework/ControlService.h"
1919
#include "Framework/CallbackService.h"
2020
#include "Framework/EndOfStreamContext.h"
2121
#include "Framework/DataSpecUtils.h"
22-
#include "Framework/ExpressionJSONHelpers.h"
22+
#include "ExpressionJSONHelpers.h"
23+
#include "Framework/ConfigContext.h"
24+
#include "Framework/AnalysisContext.h"
2325

2426
#include <Monitoring/Monitoring.h>
2527

@@ -136,72 +138,141 @@ auto make_spawn(InputSpec const& input, ProcessingContext& pc)
136138
return o2::framework::spawner<D>(extractOriginals<sources.size(), sources>(pc), input.binding.c_str(), projectors.data(), projector, schema);
137139
}
138140

139-
struct Spawnable {
140-
std::vector<expressions::Projector> projectors;
141+
struct Maker
142+
{
143+
std::string binding;
141144
std::vector<std::string> labels;
145+
std::vector<std::shared_ptr<gandiva::Expression>> expressions;
146+
std::shared_ptr<gandiva::Projector> projector = nullptr;
142147
std::shared_ptr<arrow::Schema> schema;
143148

149+
header::DataOrigin origin;
150+
header::DataDescription description;
151+
header::DataHeader::SubSpecificationType version;
152+
153+
std::shared_ptr<arrow::Table> make(ProcessingContext& pc)
154+
{
155+
std::vector<std::shared_ptr<arrow::Table>> originals;
156+
for (auto const& label : labels) {
157+
originals.push_back(pc.inputs().get<TableConsumer>(label)->asArrowTable());
158+
}
159+
auto fullTable = soa::ArrowHelpers::joinTables(std::move(originals), std::span{labels.begin(), labels.size()});
160+
if (projector == nullptr) {
161+
auto s = gandiva::Projector::Make(
162+
fullTable->schema(),
163+
expressions,
164+
&projector);
165+
if (!s.ok()) {
166+
throw o2::framework::runtime_error_f("Failed to create projector: %s", s.ToString().c_str());
167+
}
168+
}
169+
170+
return spawnerHelper(fullTable, schema, binding.c_str(), schema->num_fields(), projector);
171+
}
172+
173+
};
174+
175+
struct Spawnable {
176+
std::string binding;
177+
std::vector<std::string> labels;
178+
std::vector<expressions::Projector> projectors;
179+
std::vector<std::shared_ptr<gandiva::Expression>> expressions;
180+
std::shared_ptr<arrow::Schema> outputSchema;
181+
std::shared_ptr<arrow::Schema> inputSchema;
182+
183+
header::DataOrigin origin;
184+
header::DataDescription description;
185+
header::DataHeader::SubSpecificationType version;
186+
144187
Spawnable(InputSpec const& spec)
188+
: binding{spec.binding}
145189
{
146-
auto loc = std::find_if(spec.metadata.begin(), spec.metadata.end(), [](ConfigParamSpec const& spc){ return spc.name.compare("projectors") == 0; });
190+
auto&& [origin_, description_, version_] = DataSpecUtils::asConcreteDataMatcher(spec);
191+
origin = origin_;
192+
description = description_;
193+
version = version_;
194+
auto loc = std::find_if(spec.metadata.begin(), spec.metadata.end(), [](ConfigParamSpec const& cps){ return cps.name.compare("projectors") == 0; });
147195
std::stringstream iws(loc->defaultValue.get<std::string>());
148196
projectors = ExpressionJSONHelpers::read(iws);
197+
198+
loc = std::find_if(spec.metadata.begin(), spec.metadata.end(), [](ConfigParamSpec const& cps){ return cps.name.compare("schema") == 0; });
199+
iws.clear();
200+
iws.str(loc->defaultValue.get<std::string>());
201+
outputSchema = ArrowJSONHelpers::read(iws);
202+
149203
for (auto& i : spec.metadata) {
150204
if (i.name.starts_with("input:")) {
151205
labels.emplace_back(i.name.substr(6));
152206
}
153207
}
208+
209+
std::vector<std::shared_ptr<arrow::Field>> fields;
210+
for (auto& p : projectors) {
211+
expressions::walk(p.node.get(),
212+
[&fields](expressions::Node* n) mutable {
213+
if (n->self.index() == 1) {
214+
auto& b = std::get<expressions::BindingNode>(n->self);
215+
if ( std::find_if(fields.begin(), fields.end(), [&b](std::shared_ptr<arrow::Field> const& field){ return field->name() == b.name; }) == fields.end() ) {
216+
fields.emplace_back(std::make_shared<arrow::Field>(b.name, expressions::concreteArrowType(b.type)));
217+
}
218+
}
219+
});
220+
}
221+
inputSchema = std::make_shared<arrow::Schema>(fields);
222+
223+
int i = 0;
224+
for (auto& p : projectors) {
225+
expressions.push_back(
226+
expressions::makeExpression(
227+
expressions::createExpressionTree(
228+
expressions::createOperations(p),
229+
inputSchema),
230+
outputSchema->field(i))
231+
);
232+
++i;
233+
}
234+
}
235+
236+
std::shared_ptr<gandiva::Projector> makeProjector()
237+
{
238+
return expressions::createProjectorHelper(projectors.size(), projectors.data(), inputSchema, outputSchema->fields());
239+
}
240+
241+
Maker createMaker()
242+
{
243+
return {
244+
binding,
245+
labels,
246+
expressions,
247+
nullptr,
248+
outputSchema,
249+
origin,
250+
description,
251+
version
252+
};
154253
}
254+
155255
};
256+
156257
} // namespace
157258

158-
AlgorithmSpec AODReaderHelpers::aodSpawnerCallback(std::vector<InputSpec>& requested)
259+
AlgorithmSpec AODReaderHelpers::aodSpawnerCallback(/*std::vector<InputSpec>& requested*/ ConfigContext const& ctx)
159260
{
160-
return AlgorithmSpec::InitCallback{[requested](InitContext& /*ic*/) {
261+
auto& ac = ctx.services().get<AnalysisContext>();
262+
return AlgorithmSpec::InitCallback{[requested = ac.spawnerInputs](InitContext& /*ic*/) {
161263
std::vector<Spawnable> spawnables;
264+
for (auto& i : requested) {
265+
spawnables.emplace_back(i);
266+
}
267+
std::vector<Maker> makers;
268+
for (auto& s : spawnables) {
269+
makers.push_back(s.createMaker());
270+
}
162271

163-
return [requested, spawnables](ProcessingContext& pc) {
272+
return [makers](ProcessingContext& pc) mutable {
164273
auto outputs = pc.outputs();
165-
// spawn tables
166-
for (auto& input : requested) {
167-
auto&& [origin, description, version] = DataSpecUtils::asConcreteDataMatcher(input);
168-
if (description == header::DataDescription{"EXTRACK"}) {
169-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXTRACK/0"_h>>(input, pc));
170-
} else if (description == header::DataDescription{"EXTRACK_IU"}) {
171-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXTRACK_IU/0"_h>>(input, pc));
172-
} else if (description == header::DataDescription{"EXTRACKCOV"}) {
173-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXTRACKCOV/0"_h>>(input, pc));
174-
} else if (description == header::DataDescription{"EXTRACKCOV_IU"}) {
175-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXTRACKCOV_IU/0"_h>>(input, pc));
176-
} else if (description == header::DataDescription{"EXTRACKEXTRA"}) {
177-
if (version == 0U) {
178-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXTRACKEXTRA/0"_h>>(input, pc));
179-
} else if (version == 1U) {
180-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXTRACKEXTRA/1"_h>>(input, pc));
181-
} else if (version == 2U) {
182-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXTRACKEXTRA/2"_h>>(input, pc));
183-
}
184-
} else if (description == header::DataDescription{"EXMFTTRACK"}) {
185-
if (version == 0U) {
186-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXMFTTRACK/0"_h>>(input, pc));
187-
} else if (version == 1U) {
188-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXMFTTRACK/1"_h>>(input, pc));
189-
}
190-
} else if (description == header::DataDescription{"EXMFTTRACKCOV"}) {
191-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXMFTTRACKCOV/0"_h>>(input, pc));
192-
} else if (description == header::DataDescription{"EXFWDTRACK"}) {
193-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXFWDTRACK/0"_h>>(input, pc));
194-
} else if (description == header::DataDescription{"EXFWDTRACKCOV"}) {
195-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXFWDTRACKCOV/0"_h>>(input, pc));
196-
} else if (description == header::DataDescription{"EXMCPARTICLE"}) {
197-
if (version == 0U) {
198-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXMCPARTICLE/0"_h>>(input, pc));
199-
} else if (version == 1U) {
200-
outputs.adopt(Output{origin, description, version}, make_spawn<o2::aod::Hash<"EXMCPARTICLE/1"_h>>(input, pc));
201-
}
202-
} else {
203-
throw runtime_error("Not an extended table");
204-
}
274+
for (auto& maker : makers) {
275+
outputs.adopt(Output{maker.origin, maker.description, maker.version}, maker.make(pc));
205276
}
206277
};
207278
}};

Framework/Core/src/ASoA.cxx

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,36 @@ std::shared_ptr<arrow::Table> ArrowHelpers::joinTables(std::vector<std::shared_p
9999
return arrow::Table::Make(schema, columns);
100100
}
101101

102+
/// Join several tables column-wise into a single table.
///
/// @param tables tables to join; all must have the same number of rows
/// @param labels names of the tables, index-aligned with `tables`; only used
///               in the error message
/// @return the single input table if only one is given, otherwise a new table
///         holding the fields of all inputs in order
/// @throws a framework runtime error if `tables` is empty or two neighbouring
///         tables differ in their number of rows
std::shared_ptr<arrow::Table> ArrowHelpers::joinTables(std::vector<std::shared_ptr<arrow::Table>>&& tables, std::span<const std::string> labels)
{
  // Without this guard `tables.size() - 1` below would wrap around.
  if (tables.empty()) {
    throw o2::framework::runtime_error_f("Cannot join an empty list of tables");
  }
  if (tables.size() == 1) {
    return std::move(tables[0]);
  }
  for (auto i = 0U; i < tables.size() - 1; ++i) {
    if (tables[i]->num_rows() != tables[i + 1]->num_rows()) {
      // num_rows() is a 64-bit value: pass it with a matching conversion specifier.
      throw o2::framework::runtime_error_f("Tables %s and %s have different sizes (%lld vs %lld) and cannot be joined!",
                                           labels[i].c_str(), labels[i + 1].c_str(),
                                           static_cast<long long>(tables[i]->num_rows()),
                                           static_cast<long long>(tables[i + 1]->num_rows()));
    }
  }

  size_t totalColumns = 0;
  for (auto const& t : tables) {
    totalColumns += static_cast<size_t>(t->num_columns());
  }

  std::vector<std::shared_ptr<arrow::Field>> fields;
  fields.reserve(totalColumns);
  for (auto const& t : tables) {
    auto tf = t->fields();
    std::copy(tf.begin(), tf.end(), std::back_inserter(fields));
  }

  auto schema = std::make_shared<arrow::Schema>(fields);

  // For empty inputs only the schema is assembled and no column data is copied.
  std::vector<std::shared_ptr<arrow::ChunkedArray>> columns;
  if (tables[0]->num_rows() != 0) {
    columns.reserve(totalColumns);
    for (auto const& t : tables) {
      auto tc = t->columns();
      std::copy(tc.begin(), tc.end(), std::back_inserter(columns));
    }
  }
  return arrow::Table::Make(schema, columns);
}
131+
102132
std::shared_ptr<arrow::Table> ArrowHelpers::concatTables(std::vector<std::shared_ptr<arrow::Table>>&& tables)
103133
{
104134
if (tables.size() == 1) {

Framework/Core/src/ArrowSupport.cxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ o2::framework::ServiceSpec ArrowSupport::arrowBackendSpec()
608608
spawner->inputs.clear();
609609
// replace AlgorithmSpec
610610
// FIXME: it should be made more generic, so it does not need replacement...
611-
spawner->algorithm = readers::AODReaderHelpers::aodSpawnerCallback(ac.spawnerInputs);
611+
spawner->algorithm = readers::AODReaderHelpers::aodSpawnerCallback(ctx);
612612
AnalysisSupportHelpers::addMissingOutputsToSpawner({}, ac.spawnerInputs, ac.requestedAODs, *spawner);
613613
}
614614

Framework/Core/src/TableBuilder.cxx

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,48 @@ std::shared_ptr<arrow::Table> spawnerHelper(std::shared_ptr<arrow::Table> const&
130130
return arrow::Table::Make(newSchema, arrays);
131131
}
132132

133+
/// Run an already-built projector over `fullTable`, one record batch at a
/// time, and assemble the projected columns into a new table.
///
/// @param fullTable source table the projector is evaluated on
/// @param newSchema schema of the resulting table; it is labelled with `name`
/// @param name      name of the spawned table, also used in error messages
/// @param nColumns  number of projected columns collected per batch
/// @param projector projector evaluated on every batch
/// @return table built from `newSchema` and the projected columns
/// @throws a framework runtime error if a batch cannot be read or projected
std::shared_ptr<arrow::Table> spawnerHelper(std::shared_ptr<arrow::Table> const& fullTable, std::shared_ptr<arrow::Schema> newSchema,
                                            const char* name, size_t nColumns,
                                            std::shared_ptr<gandiva::Projector> const& projector)
{
  arrow::TableBatchReader batchReader(*fullTable);
  std::shared_ptr<arrow::RecordBatch> currentBatch;
  arrow::ArrayVector projected;
  // One list of chunks per output column; every batch contributes one chunk to each.
  std::vector<arrow::ArrayVector> perColumnChunks(nColumns);

  for (;;) {
    auto status = batchReader.ReadNext(&currentBatch);
    if (!status.ok()) {
      throw runtime_error_f("Cannot read batches from source table to spawn %s: %s", name, status.ToString().c_str());
    }
    // A null batch marks the end of the table.
    if (currentBatch == nullptr) {
      break;
    }
    try {
      status = projector->Evaluate(*currentBatch, arrow::default_memory_pool(), &projected);
      if (!status.ok()) {
        throw runtime_error_f("Cannot apply projector to source table of %s: %s", name, status.ToString().c_str());
      }
    } catch (std::exception& e) {
      throw runtime_error_f("Cannot apply projector to source table of %s: exception caught: %s", name, e.what());
    }

    for (size_t column = 0; column < nColumns; ++column) {
      perColumnChunks[column].emplace_back(projected.at(column));
    }
  }

  std::vector<std::shared_ptr<arrow::ChunkedArray>> arrays;
  arrays.reserve(nColumns);
  for (auto const& columnChunks : perColumnChunks) {
    arrays.push_back(std::make_shared<arrow::ChunkedArray>(columnChunks));
  }

  addLabelToSchema(newSchema, name);
  return arrow::Table::Make(newSchema, arrays);
}
174+
133175
} // namespace o2::framework
134176

135177
template class arrow::NumericBuilder<arrow::UInt8Type>;

Framework/Core/src/WorkflowHelpers.cxx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include <utility>
3939
#include <vector>
4040
#include <climits>
41+
#include <numeric>
4142

4243
O2_DECLARE_DYNAMIC_LOG(workflow_helpers);
4344

@@ -435,7 +436,7 @@ void WorkflowHelpers::injectServiceDevices(WorkflowSpec& workflow, ConfigContext
435436
"internal-dpl-aod-spawner",
436437
{},
437438
{},
438-
readers::AODReaderHelpers::aodSpawnerCallback(ac.spawnerInputs),
439+
readers::AODReaderHelpers::aodSpawnerCallback(ctx),
439440
{}};
440441
AnalysisSupportHelpers::addMissingOutputsToSpawner({}, ac.spawnerInputs, ac.requestedAODs, aodSpawner);
441442

0 commit comments

Comments
 (0)