Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 126 additions & 139 deletions Framework/Core/include/Framework/ASoA.h
Original file line number Diff line number Diff line change
Expand Up @@ -1389,76 +1389,69 @@ consteval static bool relatedBySortedIndex()

namespace o2::framework
{
template <typename T, bool OPT = false, bool SORTED = true>
struct PresliceBase {
constexpr static bool sorted = SORTED;

struct PreslicePolicyBase {
const std::string binding;
StringPair bindingKey;

bool isMissing() const;
StringPair const& getBindingKey() const;
};

struct PreslicePolicySorted : public PreslicePolicyBase {
void updateSliceInfo(SliceInfoPtr&& si);

SliceInfoPtr sliceInfo;
std::shared_ptr<arrow::Table> getSliceFor(int value, std::shared_ptr<arrow::Table> const& input, uint64_t& offset) const;
};

struct PreslicePolicyGeneral : public PreslicePolicyBase {
void updateSliceInfo(SliceInfoUnsortedPtr&& si);

SliceInfoUnsortedPtr sliceInfo;
gsl::span<const int64_t> getSliceFor(int value) const;
};

template <typename T, typename Policy, bool OPT = false>
struct PresliceBase : public Policy {
constexpr static bool optional = OPT;
using target_t = T;
const std::string binding;

PresliceBase(expressions::BindingNode index_)
: binding{o2::soa::getLabelFromTypeForKey<T, OPT>(index_.name)},
bindingKey{binding, index_.name} {}

void updateSliceInfo(std::conditional_t<SORTED, SliceInfoPtr, SliceInfoUnsortedPtr>&& si)
: Policy{PreslicePolicyBase{{o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name})}, std::make_pair(o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name}), std::string{index_.name})}, {}}
{
sliceInfo = si;
}

std::shared_ptr<arrow::Table> getSliceFor(int value, std::shared_ptr<arrow::Table> const& input, uint64_t& offset) const
{
if constexpr (OPT) {
if (isMissing()) {
if (Policy::isMissing()) {
return nullptr;
}
}
if constexpr (SORTED) {
auto [offset_, count] = sliceInfo.getSliceFor(value);
auto output = input->Slice(offset_, count);
offset = static_cast<int64_t>(offset_);
return output;
} else {
static_assert(SORTED, "Wrong method called for unsorted cache");
}
return Policy::getSliceFor(value, input, offset);
}

gsl::span<const int64_t> getSliceFor(int value) const
{
if constexpr (OPT) {
if (isMissing()) {
if (Policy::isMissing()) {
return {};
}
}
if constexpr (!SORTED) {
return sliceInfo.getSliceFor(value);
} else {
static_assert(!SORTED, "Wrong method called for sorted cache");
}
return Policy::getSliceFor(value);
}

bool isMissing() const
{
return binding == "[MISSING]";
}

StringPair const& getBindingKey() const
{
return bindingKey;
}

std::conditional_t<SORTED, SliceInfoPtr, SliceInfoUnsortedPtr> sliceInfo;

StringPair bindingKey;
};

template <typename T>
using PresliceUnsorted = PresliceBase<T, false, false>;
using PresliceUnsorted = PresliceBase<T, PreslicePolicyGeneral, false>;
template <typename T>
using PresliceUnsortedOptional = PresliceBase<T, true, false>;
using PresliceUnsortedOptional = PresliceBase<T, PreslicePolicyGeneral, true>;
template <typename T>
using Preslice = PresliceBase<T, false, true>;
using Preslice = PresliceBase<T, PreslicePolicySorted, false>;
template <typename T>
using PresliceOptional = PresliceBase<T, true, true>;
using PresliceOptional = PresliceBase<T, PreslicePolicySorted, true>;

} // namespace o2::framework

Expand Down Expand Up @@ -1497,96 +1490,84 @@ static consteval auto extractBindings(framework::pack<Is...>)

SelectionVector selectionToVector(gandiva::Selection const& sel);

template <typename T, typename C, bool OPT, bool SORTED>
auto doSliceBy(T const* table, o2::framework::PresliceBase<C, OPT, SORTED> const& container, int value)
template <typename T, typename C, typename Policy, bool OPT>
requires std::same_as<Policy, framework::PreslicePolicySorted> && (o2::soa::is_binding_compatible_v<C, T>())
auto doSliceBy(T const* table, o2::framework::PresliceBase<C, Policy, OPT> const& container, int value)
{
if constexpr (o2::soa::is_binding_compatible_v<C, T>()) {
if constexpr (OPT) {
if (container.isMissing()) {
missingOptionalPreslice(getLabelFromType<std::decay_t<T>>().data(), container.bindingKey.second.c_str());
}
}
if constexpr (SORTED) {
uint64_t offset = 0;
auto out = container.getSliceFor(value, table->asArrowTable(), offset);
auto t = typename T::self_t({out}, offset);
table->copyIndexBindings(t);
t.bindInternalIndicesTo(table);
return t;
} else {
auto selection = container.getSliceFor(value);
if constexpr (soa::is_filtered_table<T>) {
auto t = soa::Filtered<typename T::base_t>({table->asArrowTable()}, selection);
table->copyIndexBindings(t);
t.bindInternalIndicesTo(table);
t.intersectWithSelection(table->getSelectedRows()); // intersect filters
return t;
} else {
auto t = soa::Filtered<T>({table->asArrowTable()}, selection);
table->copyIndexBindings(t);
t.bindInternalIndicesTo(table);
return t;
}
if constexpr (OPT) {
if (container.isMissing()) {
missingOptionalPreslice(getLabelFromType<std::decay_t<T>>().data(), container.bindingKey.second.c_str());
}
} else {
if constexpr (SORTED) {
static_assert(o2::framework::always_static_assert_v<C>, "Wrong Preslice<> entry used: incompatible type");
} else {
static_assert(o2::framework::always_static_assert_v<C>, "Wrong PresliceUnsorted<> entry used: incompatible type");
}
uint64_t offset = 0;
auto out = container.getSliceFor(value, table->asArrowTable(), offset);
auto t = typename T::self_t({out}, offset);
table->copyIndexBindings(t);
t.bindInternalIndicesTo(table);
return t;
}

template <soa::is_filtered_table T>
auto doSliceByHelper(T const* table, gsl::span<const int64_t> const& selection)
{
auto t = soa::Filtered<typename T::base_t>({table->asArrowTable()}, selection);
table->copyIndexBindings(t);
t.bindInternalIndicesTo(table);
t.intersectWithSelection(table->getSelectedRows()); // intersect filters
return t;
}

template <soa::is_table T>
requires(!soa::is_filtered_table<T>)
auto doSliceByHelper(T const* table, gsl::span<const int64_t> const& selection)
{
auto t = soa::Filtered<T>({table->asArrowTable()}, selection);
table->copyIndexBindings(t);
t.bindInternalIndicesTo(table);
return t;
}

template <typename T, typename C, typename Policy, bool OPT>
requires std::same_as<Policy, framework::PreslicePolicyGeneral> && (o2::soa::is_binding_compatible_v<C, T>())
auto doSliceBy(T const* table, o2::framework::PresliceBase<C, Policy, OPT> const& container, int value)
{
if constexpr (OPT) {
if (container.isMissing()) {
missingOptionalPreslice(getLabelFromType<std::decay_t<T>>().data(), container.bindingKey.second.c_str());
}
}
auto selection = container.getSliceFor(value);
return doSliceByHelper(table, selection);
}

template <typename T>
SelectionVector sliceSelection(gsl::span<int64_t const> const& mSelectedRows, int64_t nrows, uint64_t offset);

template <soa::is_filtered_table T>
auto prepareFilteredSlice(T const* table, std::shared_ptr<arrow::Table> slice, uint64_t offset)
{
if (offset >= static_cast<uint64_t>(table->tableSize())) {
if constexpr (soa::is_filtered_table<T>) {
Filtered<typename T::base_t> fresult{{{slice}}, SelectionVector{}, 0};
table->copyIndexBindings(fresult);
return fresult;
} else {
typename T::self_t fresult{{{slice}}, SelectionVector{}, 0};
table->copyIndexBindings(fresult);
return fresult;
}
}
auto start = offset;
auto end = start + slice->num_rows();
auto mSelectedRows = table->getSelectedRows();
auto start_iterator = std::lower_bound(mSelectedRows.begin(), mSelectedRows.end(), start);
auto stop_iterator = std::lower_bound(start_iterator, mSelectedRows.end(), end);
SelectionVector slicedSelection{start_iterator, stop_iterator};
std::transform(slicedSelection.begin(), slicedSelection.end(), slicedSelection.begin(),
[&start](int64_t idx) {
return idx - static_cast<int64_t>(start);
});
if constexpr (soa::is_filtered_table<T>) {
Filtered<typename T::base_t> fresult{{{slice}}, std::move(slicedSelection), start};
table->copyIndexBindings(fresult);
return fresult;
} else {
typename T::self_t fresult{{{slice}}, std::move(slicedSelection), start};
Filtered<typename T::base_t> fresult{{{slice}}, SelectionVector{}, 0};
table->copyIndexBindings(fresult);
return fresult;
}
auto slicedSelection = sliceSelection(table->getSelectedRows(), slice->num_rows(), offset);
Filtered<typename T::base_t> fresult{{{slice}}, std::move(slicedSelection), offset};
table->copyIndexBindings(fresult);
return fresult;
}

template <typename T, typename C, bool OPT>
auto doFilteredSliceBy(T const* table, o2::framework::PresliceBase<C, OPT> const& container, int value)
template <soa::is_filtered_table T, typename C, bool OPT>
requires(o2::soa::is_binding_compatible_v<C, T>())
auto doFilteredSliceBy(T const* table, o2::framework::PresliceBase<C, framework::PreslicePolicySorted, OPT> const& container, int value)
{
if constexpr (o2::soa::is_binding_compatible_v<C, T>()) {
if constexpr (OPT) {
if (container.isMissing()) {
missingOptionalPreslice(getLabelFromType<T>().data(), container.bindingKey.second.c_str());
}
if constexpr (OPT) {
if (container.isMissing()) {
missingOptionalPreslice(getLabelFromType<T>().data(), container.bindingKey.second.c_str());
}
uint64_t offset = 0;
auto slice = container.getSliceFor(value, table->asArrowTable(), offset);
return prepareFilteredSlice(table, slice, offset);
} else {
static_assert(o2::framework::always_static_assert_v<C>, "Wrong Preslice<> entry used: incompatible type");
}
uint64_t offset = 0;
auto slice = container.getSliceFor(value, table->asArrowTable(), offset);
return prepareFilteredSlice(table, slice, offset);
}

template <typename T>
Expand Down Expand Up @@ -2099,8 +2080,8 @@ class Table
return doSliceByCachedUnsorted(this, node, value, cache);
}

template <typename T1, bool OPT, bool SORTED>
auto sliceBy(o2::framework::PresliceBase<T1, OPT, SORTED> const& container, int value) const
template <typename T1, typename Policy, bool OPT>
auto sliceBy(o2::framework::PresliceBase<T1, Policy, OPT> const& container, int value) const
{
return doSliceBy(this, container, value);
}
Expand Down Expand Up @@ -3201,8 +3182,8 @@ struct JoinFull : Table<o2::aod::Hash<"JOIN"_h>, D, o2::aod::Hash<"JOIN"_h>, Ts.
return doSliceByCachedUnsorted(this, node, value, cache);
}

template <typename T1, bool OPT, bool SORTED>
auto sliceBy(o2::framework::PresliceBase<T1, OPT, SORTED> const& container, int value) const
template <typename T1, typename Policy, bool OPT>
auto sliceBy(o2::framework::PresliceBase<T1, Policy, OPT> const& container, int value) const
{
return doSliceBy(this, container, value);
}
Expand Down Expand Up @@ -3463,14 +3444,16 @@ class FilteredBase : public T
return doSliceByCachedUnsorted(this, node, value, cache);
}

template <typename T1, bool OPT, bool SORTED>
auto sliceBy(o2::framework::PresliceBase<T1, OPT, SORTED> const& container, int value) const
template <typename T1, bool OPT>
auto sliceBy(o2::framework::PresliceBase<T1, framework::PreslicePolicySorted, OPT> const& container, int value) const
{
if constexpr (SORTED) {
return doFilteredSliceBy(this, container, value);
} else {
return doSliceBy(this, container, value);
}
return doFilteredSliceBy(this, container, value);
}

template <typename T1, bool OPT>
auto sliceBy(o2::framework::PresliceBase<T1, framework::PreslicePolicyGeneral, OPT> const& container, int value) const
{
return doSliceBy(this, container, value);
}

auto select(framework::expressions::Filter const& f) const
Expand Down Expand Up @@ -3697,14 +3680,16 @@ class Filtered : public FilteredBase<T>
return doSliceByCachedUnsorted(this, node, value, cache);
}

template <typename T1, bool OPT, bool SORTED>
auto sliceBy(o2::framework::PresliceBase<T1, OPT, SORTED> const& container, int value) const
template <typename T1, bool OPT>
auto sliceBy(o2::framework::PresliceBase<T1, framework::PreslicePolicySorted, OPT> const& container, int value) const
{
if constexpr (SORTED) {
return doFilteredSliceBy(this, container, value);
} else {
return doSliceBy(this, container, value);
}
return doFilteredSliceBy(this, container, value);
}

template <typename T1, bool OPT>
auto sliceBy(o2::framework::PresliceBase<T1, framework::PreslicePolicyGeneral, OPT> const& container, int value) const
{
return doSliceBy(this, container, value);
}

auto select(framework::expressions::Filter const& f) const
Expand Down Expand Up @@ -3864,14 +3849,16 @@ class Filtered<Filtered<T>> : public FilteredBase<typename T::table_t>
return doSliceByCachedUnsorted(this, node, value, cache);
}

template <typename T1, bool OPT, bool SORTED>
auto sliceBy(o2::framework::PresliceBase<T1, OPT, SORTED> const& container, int value) const
template <typename T1, bool OPT>
auto sliceBy(o2::framework::PresliceBase<T1, framework::PreslicePolicySorted, OPT> const& container, int value) const
{
if constexpr (SORTED) {
return doFilteredSliceBy(this, container, value);
} else {
return doSliceBy(this, container, value);
}
return doFilteredSliceBy(this, container, value);
}

template <typename T1, bool OPT>
auto sliceBy(o2::framework::PresliceBase<T1, framework::PreslicePolicyGeneral, OPT> const& container, int value) const
{
return doSliceBy(this, container, value);
}

private:
Expand Down
4 changes: 2 additions & 2 deletions Framework/Core/include/Framework/AnalysisHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -652,8 +652,8 @@ struct Partition {
return mFiltered->sliceByCachedUnsorted(node, value, cache);
}

template <typename T1, bool OPT, bool SORTED>
[[nodiscard]] auto sliceBy(o2::framework::PresliceBase<T1, OPT, SORTED> const& container, int value) const
template <typename T1, typename Policy, bool OPT>
[[nodiscard]] auto sliceBy(o2::framework::PresliceBase<T1, Policy, OPT> const& container, int value) const
{
return mFiltered->sliceBy(container, value);
}
Expand Down
Loading
Loading