Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions Framework/Core/include/Framework/ASoA.h
Original file line number Diff line number Diff line change
Expand Up @@ -1400,10 +1400,10 @@ namespace o2::framework

struct PreslicePolicyBase {
const std::string binding;
Entry bindingKey;
StringPair bindingKey;

bool isMissing() const;
Entry const& getBindingKey() const;
StringPair const& getBindingKey() const;
};

struct PreslicePolicySorted : public PreslicePolicyBase {
Expand All @@ -1428,7 +1428,7 @@ struct PresliceBase : public Policy {
const std::string binding;

PresliceBase(expressions::BindingNode index_)
: Policy{PreslicePolicyBase{{o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name})}, Entry(o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name}), std::string{index_.name})}, {}}
: Policy{PreslicePolicyBase{{o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name})}, std::make_pair(o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name}), std::string{index_.name})}, {}}
{
}

Expand Down Expand Up @@ -1508,7 +1508,7 @@ auto doSliceBy(T const* table, o2::framework::PresliceBase<C, Policy, OPT> const
{
if constexpr (OPT) {
if (container.isMissing()) {
missingOptionalPreslice(getLabelFromType<std::decay_t<T>>().data(), container.bindingKey.key.c_str());
missingOptionalPreslice(getLabelFromType<std::decay_t<T>>().data(), container.bindingKey.second.c_str());
}
}
uint64_t offset = 0;
Expand Down Expand Up @@ -1545,7 +1545,7 @@ auto doSliceBy(T const* table, o2::framework::PresliceBase<C, Policy, OPT> const
{
if constexpr (OPT) {
if (container.isMissing()) {
missingOptionalPreslice(getLabelFromType<std::decay_t<T>>().data(), container.bindingKey.key.c_str());
missingOptionalPreslice(getLabelFromType<std::decay_t<T>>().data(), container.bindingKey.second.c_str());
}
}
auto selection = container.getSliceFor(value);
Expand Down Expand Up @@ -1574,7 +1574,7 @@ auto doFilteredSliceBy(T const* table, o2::framework::PresliceBase<C, framework:
{
if constexpr (OPT) {
if (container.isMissing()) {
missingOptionalPreslice(getLabelFromType<T>().data(), container.bindingKey.key.c_str());
missingOptionalPreslice(getLabelFromType<T>().data(), container.bindingKey.second.c_str());
}
}
uint64_t offset = 0;
Expand Down
14 changes: 5 additions & 9 deletions Framework/Core/include/Framework/AnalysisManagers.h
Original file line number Diff line number Diff line change
Expand Up @@ -534,43 +534,39 @@ static void setGroupedCombination(C& comb, TG& grouping, std::tuple<Ts...>& asso
/// Preslice handling
template <typename T>
requires(!is_preslice<T>)
bool registerCache(T&, Cache&, Cache&)
bool registerCache(T&, std::vector<StringPair>&, std::vector<StringPair>&)
{
return false;
}

template <is_preslice T>
requires std::same_as<typename T::policy_t, framework::PreslicePolicySorted>
bool registerCache(T& preslice, Cache& bsks, Cache&)
bool registerCache(T& preslice, std::vector<StringPair>& bsks, std::vector<StringPair>&)
{
if constexpr (T::optional) {
if (preslice.binding == "[MISSING]") {
return true;
}
}
auto locate = std::find_if(bsks.begin(), bsks.end(), [&](auto const& entry) { return (entry.binding == preslice.bindingKey.binding) && (entry.key == preslice.bindingKey.key); });
auto locate = std::find_if(bsks.begin(), bsks.end(), [&](auto const& entry) { return (entry.first == preslice.bindingKey.first) && (entry.second == preslice.bindingKey.second); });
if (locate == bsks.end()) {
bsks.emplace_back(preslice.getBindingKey());
} else if (locate->enabled == false) {
locate->enabled = true;
}
return true;
}

template <is_preslice T>
requires std::same_as<typename T::policy_t, framework::PreslicePolicyGeneral>
bool registerCache(T& preslice, Cache&, Cache& bsksU)
bool registerCache(T& preslice, std::vector<StringPair>&, std::vector<StringPair>& bsksU)
{
if constexpr (T::optional) {
if (preslice.binding == "[MISSING]") {
return true;
}
}
auto locate = std::find_if(bsksU.begin(), bsksU.end(), [&](auto const& entry) { return (entry.binding == preslice.bindingKey.binding) && (entry.key == preslice.bindingKey.key); });
auto locate = std::find_if(bsksU.begin(), bsksU.end(), [&](auto const& entry) { return (entry.first == preslice.bindingKey.first) && (entry.second == preslice.bindingKey.second); });
if (locate == bsksU.end()) {
bsksU.emplace_back(preslice.getBindingKey());
} else if (locate->enabled == false) {
locate->enabled = true;
}
return true;
}
Expand Down
22 changes: 11 additions & 11 deletions Framework/Core/include/Framework/AnalysisTask.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,20 +66,20 @@ concept is_enumeration = is_enumeration_v<std::decay_t<T>>;
namespace {
struct AnalysisDataProcessorBuilder {
template <typename G, typename... Args>
static void addGroupingCandidates(Cache& bk, Cache& bku, bool enabled)
static void addGroupingCandidates(std::vector<StringPair>& bk, std::vector<StringPair>& bku)
{
[&bk, &bku, enabled]<typename... As>(framework::pack<As...>) mutable {
[&bk, &bku]<typename... As>(framework::pack<As...>) mutable {
std::string key;
if constexpr (soa::is_iterator<std::decay_t<G>>) {
key = std::string{"fIndex"} + o2::framework::cutString(soa::getLabelFromType<std::decay_t<G>>());
}
([&bk, &bku, &key, enabled]() mutable {
([&bk, &bku, &key]() mutable {
if constexpr (soa::relatedByIndex<std::decay_t<G>, std::decay_t<As>>()) {
auto binding = soa::getLabelFromTypeForKey<std::decay_t<As>>(key);
if constexpr (o2::soa::is_smallgroups<std::decay_t<As>>) {
framework::updatePairList(bku, binding, key, enabled);
framework::updatePairList(bku, binding, key);
} else {
framework::updatePairList(bk, binding, key, enabled);
framework::updatePairList(bk, binding, key);
}
}
}(),
Expand Down Expand Up @@ -147,7 +147,7 @@ struct AnalysisDataProcessorBuilder {
/// helper to parse the process arguments
/// 1. enumeration (must be the only argument)
template <typename R, typename C, is_enumeration A>
static void inputsFromArgs(R (C::*)(A), const char* /*name*/, bool /*value*/, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>&, Cache&, Cache&)
static void inputsFromArgs(R (C::*)(A), const char* /*name*/, bool /*value*/, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>&, std::vector<StringPair>&, std::vector<StringPair>&)
{
std::vector<ConfigParamSpec> inputMetadata;
// FIXME: for the moment we do not support begin, end and step.
Expand All @@ -156,17 +156,17 @@ struct AnalysisDataProcessorBuilder {

/// 2. grouping case - 1st argument is an iterator
template <typename R, typename C, soa::is_iterator A, soa::is_table... Args>
static void inputsFromArgs(R (C::*)(A, Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos, Cache& bk, Cache& bku)
static void inputsFromArgs(R (C::*)(A, Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos, std::vector<StringPair>& bk, std::vector<StringPair>& bku)
requires(std::is_lvalue_reference_v<A> && (std::is_lvalue_reference_v<Args> && ...))
{
addGroupingCandidates<A, Args...>(bk, bku, value);
addGroupingCandidates<A, Args...>(bk, bku);
constexpr auto hash = o2::framework::TypeIdHelpers::uniqueId<R (C::*)(A, Args...)>();
addInputsAndExpressions<typename std::decay_t<A>::parent_t, Args...>(hash, name, value, inputs, eInfos);
}

/// 3. generic case
template <typename R, typename C, soa::is_table... Args>
static void inputsFromArgs(R (C::*)(Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos, Cache&, Cache&)
static void inputsFromArgs(R (C::*)(Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos, std::vector<StringPair>&, std::vector<StringPair>&)
requires(std::is_lvalue_reference_v<Args> && ...)
{
constexpr auto hash = o2::framework::TypeIdHelpers::uniqueId<R (C::*)(Args...)>();
Expand Down Expand Up @@ -480,8 +480,8 @@ DataProcessorSpec adaptAnalysisTask(ConfigContext const& ctx, Args&&... args)
std::vector<InputSpec> inputs;
std::vector<ConfigParamSpec> options;
std::vector<ExpressionInfo> expressionInfos;
Cache bindingsKeys;
Cache bindingsKeysUnsorted;
std::vector<StringPair> bindingsKeys;
std::vector<StringPair> bindingsKeysUnsorted;

/// make sure options and configurables are set before expression infos are created
homogeneous_apply_refs([&options, &hash](auto& element) { return analysis_task_parsers::appendOption(options, element); }, *task.get());
Expand Down
45 changes: 16 additions & 29 deletions Framework/Core/include/Framework/ArrowTableSlicingCache.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,64 +34,51 @@ struct SliceInfoUnsortedPtr {
gsl::span<int64_t const> getSliceFor(int value) const;
};

struct Entry {
std::string binding;
std::string key;
bool enabled;

Entry(std::string b, std::string k, bool e = true)
: binding{b},
key{k},
enabled{e}
{
}
};

using Cache = std::vector<Entry>;
using StringPair = std::pair<std::string, std::string>;

void updatePairList(Cache& list, std::string const& binding, std::string const& key, bool enabled);
void updatePairList(std::vector<StringPair>& list, std::string const& binding, std::string const& key);

struct ArrowTableSlicingCacheDef {
constexpr static ServiceKind service_kind = ServiceKind::Global;
Cache bindingsKeys;
Cache bindingsKeysUnsorted;
std::vector<StringPair> bindingsKeys;
std::vector<StringPair> bindingsKeysUnsorted;

void setCaches(Cache&& bsks);
void setCachesUnsorted(Cache&& bsks);
void setCaches(std::vector<StringPair>&& bsks);
void setCachesUnsorted(std::vector<StringPair>&& bsks);
};

struct ArrowTableSlicingCache {
constexpr static ServiceKind service_kind = ServiceKind::Stream;

Cache bindingsKeys;
std::vector<StringPair> bindingsKeys;
std::vector<std::shared_ptr<arrow::NumericArray<arrow::Int32Type>>> values;
std::vector<std::shared_ptr<arrow::NumericArray<arrow::Int64Type>>> counts;

Cache bindingsKeysUnsorted;
std::vector<StringPair> bindingsKeysUnsorted;
std::vector<std::vector<int>> valuesUnsorted;
std::vector<ListVector> groups;

ArrowTableSlicingCache(Cache&& bsks, Cache&& bsksUnsorted = {});
ArrowTableSlicingCache(std::vector<StringPair>&& bsks, std::vector<StringPair>&& bsksUnsorted = {});

// set caching information externally
void setCaches(Cache&& bsks, Cache&& bsksUnsorted = {});
void setCaches(std::vector<StringPair>&& bsks, std::vector<StringPair>&& bsksUnsorted = {});

// update slicing info cache entry (assumes it is already present)
arrow::Status updateCacheEntry(int pos, std::shared_ptr<arrow::Table> const& table);
arrow::Status updateCacheEntryUnsorted(int pos, std::shared_ptr<arrow::Table> const& table);

// helper to locate cache position
std::pair<int, bool> getCachePos(Entry const& bindingKey) const;
int getCachePosSortedFor(Entry const& bindingKey) const;
int getCachePosUnsortedFor(Entry const& bindingKey) const;
std::pair<int, bool> getCachePos(StringPair const& bindingKey) const;
int getCachePosSortedFor(StringPair const& bindingKey) const;
int getCachePosUnsortedFor(StringPair const& bindingKey) const;

// get slice from cache for a given value
SliceInfoPtr getCacheFor(Entry const& bindingKey) const;
SliceInfoUnsortedPtr getCacheUnsortedFor(Entry const& bindingKey) const;
SliceInfoPtr getCacheFor(StringPair const& bindingKey) const;
SliceInfoUnsortedPtr getCacheUnsortedFor(StringPair const& bindingKey) const;
SliceInfoPtr getCacheForPos(int pos) const;
SliceInfoUnsortedPtr getCacheUnsortedForPos(int pos) const;

static void validateOrder(Entry const& bindingKey, std::shared_ptr<arrow::Table> const& input);
static void validateOrder(StringPair const& bindingKey, std::shared_ptr<arrow::Table> const& input);
};
} // namespace o2::framework

Expand Down
2 changes: 1 addition & 1 deletion Framework/Core/include/Framework/GroupSlicer.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ struct GroupSlicer {
{
constexpr auto index = framework::has_type_at_v<std::decay_t<T>>(associated_pack_t{});
auto binding = o2::soa::getLabelFromTypeForKey<std::decay_t<T>>(mIndexColumnName);
auto bk = Entry(binding, mIndexColumnName);
auto bk = std::make_pair(binding, mIndexColumnName);
if constexpr (!o2::soa::is_smallgroups<std::decay_t<T>>) {
if (table.size() == 0) {
return;
Expand Down
2 changes: 1 addition & 1 deletion Framework/Core/src/ASoA.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ bool PreslicePolicyBase::isMissing() const
return binding == "[MISSING]";
}

Entry const& PreslicePolicyBase::getBindingKey() const
StringPair const& PreslicePolicyBase::getBindingKey() const
{
return bindingKey;
}
Expand Down
19 changes: 9 additions & 10 deletions Framework/Core/src/ArrowSupport.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -567,27 +567,26 @@ o2::framework::ServiceSpec ArrowSupport::arrowTableSlicingCacheSpec()
.name = "arrow-slicing-cache",
.uniqueId = CommonServices::simpleServiceId<ArrowTableSlicingCache>(),
.init = [](ServiceRegistryRef services, DeviceState&, fair::mq::ProgOptions&) { return ServiceHandle{TypeIdHelpers::uniqueId<ArrowTableSlicingCache>(),
new ArrowTableSlicingCache(Cache{services.get<ArrowTableSlicingCacheDef>().bindingsKeys},
Cache{services.get<ArrowTableSlicingCacheDef>().bindingsKeysUnsorted}),
new ArrowTableSlicingCache(std::vector<std::pair<std::string, std::string>>{services.get<ArrowTableSlicingCacheDef>().bindingsKeys}, std::vector{services.get<ArrowTableSlicingCacheDef>().bindingsKeysUnsorted}),
ServiceKind::Stream, typeid(ArrowTableSlicingCache).name()}; },
.configure = CommonServices::noConfiguration(),
.preProcessing = [](ProcessingContext& pc, void* service_ptr) {
auto* service = static_cast<ArrowTableSlicingCache*>(service_ptr);
auto& caches = service->bindingsKeys;
for (auto i = 0u; i < caches.size(); ++i) {
if (caches[i].enabled && pc.inputs().getPos(caches[i].binding.c_str()) >= 0) {
auto status = service->updateCacheEntry(i, pc.inputs().get<TableConsumer>(caches[i].binding.c_str())->asArrowTable());
for (auto i = 0; i < caches.size(); ++i) {
if (pc.inputs().getPos(caches[i].first.c_str()) >= 0) {
auto status = service->updateCacheEntry(i, pc.inputs().get<TableConsumer>(caches[i].first.c_str())->asArrowTable());
if (!status.ok()) {
throw runtime_error_f("Failed to update slice cache for %s/%s", caches[i].binding.c_str(), caches[i].key.c_str());
throw runtime_error_f("Failed to update slice cache for %s/%s", caches[i].first.c_str(), caches[i].second.c_str());
}
}
}
auto& unsortedCaches = service->bindingsKeysUnsorted;
for (auto i = 0u; i < unsortedCaches.size(); ++i) {
if (unsortedCaches[i].enabled && pc.inputs().getPos(unsortedCaches[i].binding.c_str()) >= 0) {
auto status = service->updateCacheEntryUnsorted(i, pc.inputs().get<TableConsumer>(unsortedCaches[i].binding.c_str())->asArrowTable());
for (auto i = 0; i < unsortedCaches.size(); ++i) {
if (pc.inputs().getPos(unsortedCaches[i].first.c_str()) >= 0) {
auto status = service->updateCacheEntryUnsorted(i, pc.inputs().get<TableConsumer>(unsortedCaches[i].first.c_str())->asArrowTable());
if (!status.ok()) {
throw runtime_error_f("failed to update slice cache (unsorted) for %s/%s", unsortedCaches[i].binding.c_str(), unsortedCaches[i].key.c_str());
throw runtime_error_f("failed to update slice cache (unsorted) for %s/%s", unsortedCaches[i].first.c_str(), unsortedCaches[i].second.c_str());
}
}
} },
Expand Down
Loading