Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions Framework/Core/include/Framework/ASoA.h
Original file line number Diff line number Diff line change
Expand Up @@ -1245,7 +1245,7 @@ struct TableIterator : IP, C... {
};

struct ArrowHelpers {
static std::shared_ptr<arrow::Table> joinTables(std::vector<std::shared_ptr<arrow::Table>>&& tables);
static std::shared_ptr<arrow::Table> joinTables(std::vector<std::shared_ptr<arrow::Table>>&& tables, std::span<const char* const> labels);
static std::shared_ptr<arrow::Table> concatTables(std::vector<std::shared_ptr<arrow::Table>>&& tables);
};

Expand Down Expand Up @@ -1683,6 +1683,7 @@ class Table
using table_t = self_t;

static constexpr const auto originals = computeOriginals<ref, Ts...>();
static constexpr const auto originalLabels = []<size_t N, std::array<TableRef, N> refs, size_t... Is>(std::index_sequence<Is...>) { return std::array<const char*, N>{o2::aod::label<refs[Is]>()...}; }.template operator()<originals.size(), originals>(std::make_index_sequence<originals.size()>());

template <size_t N, std::array<TableRef, N> bindings>
requires(ref.origin_hash == "CONC"_h)
Expand Down Expand Up @@ -1931,7 +1932,7 @@ class Table

Table(std::vector<std::shared_ptr<arrow::Table>>&& tables, uint64_t offset = 0)
requires(ref.origin_hash != "CONC"_h)
: Table(ArrowHelpers::joinTables(std::move(tables)), offset)
: Table(ArrowHelpers::joinTables(std::move(tables), std::span{originalLabels}), offset)
{
}

Expand Down Expand Up @@ -3213,7 +3214,7 @@ struct JoinFull : Table<o2::aod::Hash<"JOIN"_h>, D, o2::aod::Hash<"JOIN"_h>, Ts.
bindInternalIndicesTo(this);
}
JoinFull(std::vector<std::shared_ptr<arrow::Table>>&& tables, uint64_t offset = 0)
: base{ArrowHelpers::joinTables(std::move(tables)), offset}
: base{ArrowHelpers::joinTables(std::move(tables), std::span{base::originalLabels}), offset}
{
bindInternalIndicesTo(this);
}
Expand All @@ -3223,6 +3224,7 @@ struct JoinFull : Table<o2::aod::Hash<"JOIN"_h>, D, o2::aod::Hash<"JOIN"_h>, Ts.
using self_t = JoinFull<D, Ts...>;
using table_t = base;
static constexpr const auto originals = base::originals;
static constexpr const auto originalLabels = base::originalLabels;
using columns_t = typename table_t::columns_t;
using persistent_columns_t = typename table_t::persistent_columns_t;
using iterator = table_t::template iterator_template<DefaultIndexPolicy, self_t, Ts...>;
Expand Down Expand Up @@ -3293,7 +3295,7 @@ using Join = JoinFull<o2::aod::Hash<"JOIN/0"_h>, Ts...>;
template <typename... Ts>
constexpr auto join(Ts const&... t)
{
return Join<Ts...>(ArrowHelpers::joinTables({t.asArrowTable()...}));
return Join<Ts...>(ArrowHelpers::joinTables({t.asArrowTable()...}, std::span{Join<Ts...>::base::originalLabels}));
}

template <typename T>
Expand Down
8 changes: 4 additions & 4 deletions Framework/Core/include/Framework/AnalysisManagers.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,14 +282,14 @@ template <is_spawns T>
bool prepareOutput(ProcessingContext& context, T& spawns)
{
using metadata = o2::aod::MetadataTrait<o2::aod::Hash<T::spawnable_t::ref.desc_hash>>::metadata;
auto originalTable = soa::ArrowHelpers::joinTables(extractOriginals<metadata::sources.size(), metadata::sources>(context));
auto originalTable = soa::ArrowHelpers::joinTables(extractOriginals<metadata::sources.size(), metadata::sources>(context), std::span{metadata::base_table_t::originalLabels});
if (originalTable->schema()->fields().empty() == true) {
using base_table_t = typename T::base_table_t::table_t;
originalTable = makeEmptyTable<base_table_t>(o2::aod::label<metadata::extension_table_t::ref>());
}

spawns.extension = std::make_shared<typename T::extension_t>(o2::framework::spawner<o2::aod::Hash<metadata::extension_table_t::ref.desc_hash>>(originalTable, o2::aod::label<metadata::extension_table_t::ref>(), spawns.projector));
spawns.table = std::make_shared<typename T::spawnable_t::table_t>(soa::ArrowHelpers::joinTables({spawns.extension->asArrowTable(), originalTable}));
spawns.table = std::make_shared<typename T::spawnable_t::table_t>(soa::ArrowHelpers::joinTables({spawns.extension->asArrowTable(), originalTable}, std::span{T::spawnable_t::table_t::originalLabels}));
return true;
}

Expand All @@ -304,14 +304,14 @@ template <is_defines T>
bool prepareOutput(ProcessingContext& context, T& defines)
{
using metadata = o2::aod::MetadataTrait<o2::aod::Hash<T::spawnable_t::ref.desc_hash>>::metadata;
auto originalTable = soa::ArrowHelpers::joinTables(extractOriginals<metadata::sources.size(), metadata::sources>(context));
auto originalTable = soa::ArrowHelpers::joinTables(extractOriginals<metadata::sources.size(), metadata::sources>(context), std::span{metadata::base_table_t::originalLabels});
if (originalTable->schema()->fields().empty() == true) {
using base_table_t = typename T::base_table_t::table_t;
originalTable = makeEmptyTable<base_table_t>(o2::aod::label<metadata::extension_table_t::ref>());
}

defines.extension = std::make_shared<typename T::extension_t>(o2::framework::spawner<o2::aod::Hash<metadata::extension_table_t::ref.desc_hash>>(originalTable, o2::aod::label<metadata::extension_table_t::ref>(), defines.projectors.data(), defines.projector));
defines.table = std::make_shared<typename T::spawnable_t::table_t>(soa::ArrowHelpers::joinTables({defines.extension->asArrowTable(), originalTable}));
defines.table = std::make_shared<typename T::spawnable_t::table_t>(soa::ArrowHelpers::joinTables({defines.extension->asArrowTable(), originalTable}, std::span{T::spawnable_t::table_t::originalLabels}));
return true;
}

Expand Down
4 changes: 2 additions & 2 deletions Framework/Core/include/Framework/AnalysisTask.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,9 @@ struct AnalysisDataProcessorBuilder {
std::shared_ptr<arrow::Table> table = nullptr;
auto joiner = [&record]<size_t N, std::array<soa::TableRef, N> refs, size_t... Is>(std::index_sequence<Is...>) { return std::vector{extractTableFromRecord<refs[Is]>(record)...}; };
if constexpr (soa::is_iterator<T>) {
table = o2::soa::ArrowHelpers::joinTables(joiner.template operator()<T::parent_t::originals.size(), T::parent_t::originals>(std::make_index_sequence<T::parent_t::originals.size()>()));
table = o2::soa::ArrowHelpers::joinTables(joiner.template operator()<T::parent_t::originals.size(), T::parent_t::originals>(std::make_index_sequence<T::parent_t::originals.size()>()), std::span{T::parent_t::originalLabels});
} else {
table = o2::soa::ArrowHelpers::joinTables(joiner.template operator()<T::originals.size(), T::originals>(std::make_index_sequence<T::originals.size()>()));
table = o2::soa::ArrowHelpers::joinTables(joiner.template operator()<T::originals.size(), T::originals>(std::make_index_sequence<T::originals.size()>()), std::span{T::originalLabels});
}
expressions::updateFilterInfo(info, table);
if constexpr (!o2::soa::is_smallgroups<std::decay_t<T>>) {
Expand Down
7 changes: 4 additions & 3 deletions Framework/Core/include/Framework/TableBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,7 @@ template <aod::is_aod_hash D>
auto spawner(std::vector<std::shared_ptr<arrow::Table>>&& tables, const char* name, o2::framework::expressions::Projector* projectors, std::shared_ptr<gandiva::Projector>& projector)
{
using placeholders_pack_t = typename o2::aod::MetadataTrait<D>::metadata::placeholders_pack_t;
auto fullTable = soa::ArrowHelpers::joinTables(std::move(tables));
auto fullTable = soa::ArrowHelpers::joinTables(std::move(tables), std::span{o2::aod::MetadataTrait<D>::metadata::base_table_t::originalLabels});
if (fullTable->num_rows() == 0) {
return makeEmptyTable(name, placeholders_pack_t{});
}
Expand Down Expand Up @@ -892,7 +892,7 @@ template <aod::is_aod_hash D>
auto spawner(std::vector<std::shared_ptr<arrow::Table>>&& tables, const char* name, std::shared_ptr<gandiva::Projector>& projector)
{
using expression_pack_t = typename o2::aod::MetadataTrait<D>::metadata::expression_pack_t;
auto fullTable = soa::ArrowHelpers::joinTables(std::move(tables));
auto fullTable = soa::ArrowHelpers::joinTables(std::move(tables), std::span{o2::aod::MetadataTrait<D>::metadata::base_table_t::originalLabels});
if (fullTable->num_rows() == 0) {
return makeEmptyTable(name, expression_pack_t{});
}
Expand Down Expand Up @@ -929,7 +929,8 @@ auto spawner(std::shared_ptr<arrow::Table> const& fullTable, const char* name, s
template <typename... C>
auto spawner(framework::pack<C...> columns, std::vector<std::shared_ptr<arrow::Table>>&& tables, const char* name, std::shared_ptr<gandiva::Projector>& projector)
{
auto fullTable = soa::ArrowHelpers::joinTables(std::move(tables));
std::array<const char*, 1> labels{"original"};
auto fullTable = soa::ArrowHelpers::joinTables(std::move(tables), std::span<const char* const>{labels});
if (fullTable->num_rows() == 0) {
return makeEmptyTable(name, framework::pack<C...>{});
}
Expand Down
7 changes: 2 additions & 5 deletions Framework/Core/src/ASoA.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,15 @@ SelectionVector sliceSelection(gsl::span<int64_t const> const& mSelectedRows, in
return slicedSelection;
}

std::shared_ptr<arrow::Table> ArrowHelpers::joinTables(std::vector<std::shared_ptr<arrow::Table>>&& tables)
std::shared_ptr<arrow::Table> ArrowHelpers::joinTables(std::vector<std::shared_ptr<arrow::Table>>&& tables, std::span<const char* const> labels)
{
if (tables.size() == 1) {
return tables[0];
}
for (auto i = 0U; i < tables.size() - 1; ++i) {
if (tables[i]->num_rows() != tables[i + 1]->num_rows()) {
throw o2::framework::runtime_error_f("Tables %s and %s have different sizes (%d vs %d) and cannot be joined!",
tables[i]->schema()->metadata()->Get("label").ValueOrDie().c_str(),
tables[i + 1]->schema()->metadata()->Get("label").ValueOrDie().c_str(),
tables[i]->num_rows(),
tables[i + 1]->num_rows());
labels[i], labels[i + 1], tables[i]->num_rows(), tables[i + 1]->num_rows());
}
}
std::vector<std::shared_ptr<arrow::Field>> fields;
Expand Down
15 changes: 15 additions & 0 deletions Framework/Core/test/test_ASoA.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ namespace test
DECLARE_SOA_COLUMN(X, x, int);
DECLARE_SOA_COLUMN(Y, y, int);
DECLARE_SOA_COLUMN(Z, z, int);
DECLARE_SOA_COLUMN(W, w, int);
DECLARE_SOA_DYNAMIC_COLUMN(Sum, sum, [](int x, int y) { return x + y; });
DECLARE_SOA_EXPRESSION_COLUMN(ESum, esum, int, test::x + test::y);
} // namespace test
Expand Down Expand Up @@ -268,9 +269,17 @@ TEST_CASE("TestJoinedTables")
rowWriterZ(0, 8);
auto tableZ = builderZ.finalize();

TableBuilder builderW;
auto rowWriterW = builderW.persist<int32_t>({"fW"});
rowWriterW(0, 8);
rowWriterW(0, 8);
rowWriterW(0, 8);
auto tableW = builderW.finalize();

using TestX = InPlaceTable<"A0"_h, o2::aod::test::X>;
using TestY = InPlaceTable<"A1"_h, o2::aod::test::Y>;
using TestZ = InPlaceTable<"A2"_h, o2::aod::test::Z>;
using TestW = InPlaceTable<"A3"_h, o2::aod::test::W>;
using Test = Join<TestX, TestY>;

REQUIRE(Test::contains<TestX>());
Expand Down Expand Up @@ -303,6 +312,12 @@ TEST_CASE("TestJoinedTables")
for (auto& test : tests4) {
REQUIRE(15 == test.x() + test.y() + test.z());
}

try {
auto testF = join(TestZ{tableZ}, TestW{tableW});
} catch (RuntimeErrorRef ref) {
REQUIRE(std::string{error_from_ref(ref).what} == "Tables TEST and TEST have different sizes (8 vs 3) and cannot be joined!");
}
}

TEST_CASE("TestConcatTables")
Expand Down