Skip to content

Commit 15b4f5f

Browse files
authored
DPL Analysis: prevent slice cache from updating when not required by enabled process functions (#14057)
1 parent 69f1fd1 commit 15b4f5f

File tree

9 files changed: +96 −73 lines changed

Framework/Core/include/Framework/ASoA.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,10 +1400,10 @@ namespace o2::framework
14001400

14011401
struct PreslicePolicyBase {
14021402
const std::string binding;
1403-
StringPair bindingKey;
1403+
Entry bindingKey;
14041404

14051405
bool isMissing() const;
1406-
StringPair const& getBindingKey() const;
1406+
Entry const& getBindingKey() const;
14071407
};
14081408

14091409
struct PreslicePolicySorted : public PreslicePolicyBase {
@@ -1428,7 +1428,7 @@ struct PresliceBase : public Policy {
14281428
const std::string binding;
14291429

14301430
PresliceBase(expressions::BindingNode index_)
1431-
: Policy{PreslicePolicyBase{{o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name})}, std::make_pair(o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name}), std::string{index_.name})}, {}}
1431+
: Policy{PreslicePolicyBase{{o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name})}, Entry(o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name}), std::string{index_.name})}, {}}
14321432
{
14331433
}
14341434

@@ -1508,7 +1508,7 @@ auto doSliceBy(T const* table, o2::framework::PresliceBase<C, Policy, OPT> const
15081508
{
15091509
if constexpr (OPT) {
15101510
if (container.isMissing()) {
1511-
missingOptionalPreslice(getLabelFromType<std::decay_t<T>>().data(), container.bindingKey.second.c_str());
1511+
missingOptionalPreslice(getLabelFromType<std::decay_t<T>>().data(), container.bindingKey.key.c_str());
15121512
}
15131513
}
15141514
uint64_t offset = 0;
@@ -1545,7 +1545,7 @@ auto doSliceBy(T const* table, o2::framework::PresliceBase<C, Policy, OPT> const
15451545
{
15461546
if constexpr (OPT) {
15471547
if (container.isMissing()) {
1548-
missingOptionalPreslice(getLabelFromType<std::decay_t<T>>().data(), container.bindingKey.second.c_str());
1548+
missingOptionalPreslice(getLabelFromType<std::decay_t<T>>().data(), container.bindingKey.key.c_str());
15491549
}
15501550
}
15511551
auto selection = container.getSliceFor(value);
@@ -1574,7 +1574,7 @@ auto doFilteredSliceBy(T const* table, o2::framework::PresliceBase<C, framework:
15741574
{
15751575
if constexpr (OPT) {
15761576
if (container.isMissing()) {
1577-
missingOptionalPreslice(getLabelFromType<T>().data(), container.bindingKey.second.c_str());
1577+
missingOptionalPreslice(getLabelFromType<T>().data(), container.bindingKey.key.c_str());
15781578
}
15791579
}
15801580
uint64_t offset = 0;

Framework/Core/include/Framework/AnalysisManagers.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -534,39 +534,43 @@ static void setGroupedCombination(C& comb, TG& grouping, std::tuple<Ts...>& asso
534534
/// Preslice handling
535535
template <typename T>
536536
requires(!is_preslice<T>)
537-
bool registerCache(T&, std::vector<StringPair>&, std::vector<StringPair>&)
537+
bool registerCache(T&, Cache&, Cache&)
538538
{
539539
return false;
540540
}
541541

542542
template <is_preslice T>
543543
requires std::same_as<typename T::policy_t, framework::PreslicePolicySorted>
544-
bool registerCache(T& preslice, std::vector<StringPair>& bsks, std::vector<StringPair>&)
544+
bool registerCache(T& preslice, Cache& bsks, Cache&)
545545
{
546546
if constexpr (T::optional) {
547547
if (preslice.binding == "[MISSING]") {
548548
return true;
549549
}
550550
}
551-
auto locate = std::find_if(bsks.begin(), bsks.end(), [&](auto const& entry) { return (entry.first == preslice.bindingKey.first) && (entry.second == preslice.bindingKey.second); });
551+
auto locate = std::find_if(bsks.begin(), bsks.end(), [&](auto const& entry) { return (entry.binding == preslice.bindingKey.binding) && (entry.key == preslice.bindingKey.key); });
552552
if (locate == bsks.end()) {
553553
bsks.emplace_back(preslice.getBindingKey());
554+
} else if (locate->enabled == false) {
555+
locate->enabled = true;
554556
}
555557
return true;
556558
}
557559

558560
template <is_preslice T>
559561
requires std::same_as<typename T::policy_t, framework::PreslicePolicyGeneral>
560-
bool registerCache(T& preslice, std::vector<StringPair>&, std::vector<StringPair>& bsksU)
562+
bool registerCache(T& preslice, Cache&, Cache& bsksU)
561563
{
562564
if constexpr (T::optional) {
563565
if (preslice.binding == "[MISSING]") {
564566
return true;
565567
}
566568
}
567-
auto locate = std::find_if(bsksU.begin(), bsksU.end(), [&](auto const& entry) { return (entry.first == preslice.bindingKey.first) && (entry.second == preslice.bindingKey.second); });
569+
auto locate = std::find_if(bsksU.begin(), bsksU.end(), [&](auto const& entry) { return (entry.binding == preslice.bindingKey.binding) && (entry.key == preslice.bindingKey.key); });
568570
if (locate == bsksU.end()) {
569571
bsksU.emplace_back(preslice.getBindingKey());
572+
} else if (locate->enabled == false) {
573+
locate->enabled = true;
570574
}
571575
return true;
572576
}

Framework/Core/include/Framework/AnalysisTask.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,20 +66,20 @@ concept is_enumeration = is_enumeration_v<std::decay_t<T>>;
6666
namespace {
6767
struct AnalysisDataProcessorBuilder {
6868
template <typename G, typename... Args>
69-
static void addGroupingCandidates(std::vector<StringPair>& bk, std::vector<StringPair>& bku)
69+
static void addGroupingCandidates(Cache& bk, Cache& bku, bool enabled)
7070
{
71-
[&bk, &bku]<typename... As>(framework::pack<As...>) mutable {
71+
[&bk, &bku, enabled]<typename... As>(framework::pack<As...>) mutable {
7272
std::string key;
7373
if constexpr (soa::is_iterator<std::decay_t<G>>) {
7474
key = std::string{"fIndex"} + o2::framework::cutString(soa::getLabelFromType<std::decay_t<G>>());
7575
}
76-
([&bk, &bku, &key]() mutable {
76+
([&bk, &bku, &key, enabled]() mutable {
7777
if constexpr (soa::relatedByIndex<std::decay_t<G>, std::decay_t<As>>()) {
7878
auto binding = soa::getLabelFromTypeForKey<std::decay_t<As>>(key);
7979
if constexpr (o2::soa::is_smallgroups<std::decay_t<As>>) {
80-
framework::updatePairList(bku, binding, key);
80+
framework::updatePairList(bku, binding, key, enabled);
8181
} else {
82-
framework::updatePairList(bk, binding, key);
82+
framework::updatePairList(bk, binding, key, enabled);
8383
}
8484
}
8585
}(),
@@ -147,7 +147,7 @@ struct AnalysisDataProcessorBuilder {
147147
/// helper to parse the process arguments
148148
/// 1. enumeration (must be the only argument)
149149
template <typename R, typename C, is_enumeration A>
150-
static void inputsFromArgs(R (C::*)(A), const char* /*name*/, bool /*value*/, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>&, std::vector<StringPair>&, std::vector<StringPair>&)
150+
static void inputsFromArgs(R (C::*)(A), const char* /*name*/, bool /*value*/, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>&, Cache&, Cache&)
151151
{
152152
std::vector<ConfigParamSpec> inputMetadata;
153153
// FIXME: for the moment we do not support begin, end and step.
@@ -156,17 +156,17 @@ struct AnalysisDataProcessorBuilder {
156156

157157
/// 2. grouping case - 1st argument is an iterator
158158
template <typename R, typename C, soa::is_iterator A, soa::is_table... Args>
159-
static void inputsFromArgs(R (C::*)(A, Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos, std::vector<StringPair>& bk, std::vector<StringPair>& bku)
159+
static void inputsFromArgs(R (C::*)(A, Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos, Cache& bk, Cache& bku)
160160
requires(std::is_lvalue_reference_v<A> && (std::is_lvalue_reference_v<Args> && ...))
161161
{
162-
addGroupingCandidates<A, Args...>(bk, bku);
162+
addGroupingCandidates<A, Args...>(bk, bku, value);
163163
constexpr auto hash = o2::framework::TypeIdHelpers::uniqueId<R (C::*)(A, Args...)>();
164164
addInputsAndExpressions<typename std::decay_t<A>::parent_t, Args...>(hash, name, value, inputs, eInfos);
165165
}
166166

167167
/// 3. generic case
168168
template <typename R, typename C, soa::is_table... Args>
169-
static void inputsFromArgs(R (C::*)(Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos, std::vector<StringPair>&, std::vector<StringPair>&)
169+
static void inputsFromArgs(R (C::*)(Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos, Cache&, Cache&)
170170
requires(std::is_lvalue_reference_v<Args> && ...)
171171
{
172172
constexpr auto hash = o2::framework::TypeIdHelpers::uniqueId<R (C::*)(Args...)>();
@@ -480,8 +480,8 @@ DataProcessorSpec adaptAnalysisTask(ConfigContext const& ctx, Args&&... args)
480480
std::vector<InputSpec> inputs;
481481
std::vector<ConfigParamSpec> options;
482482
std::vector<ExpressionInfo> expressionInfos;
483-
std::vector<StringPair> bindingsKeys;
484-
std::vector<StringPair> bindingsKeysUnsorted;
483+
Cache bindingsKeys;
484+
Cache bindingsKeysUnsorted;
485485

486486
/// make sure options and configurables are set before expression infos are created
487487
homogeneous_apply_refs([&options, &hash](auto& element) { return analysis_task_parsers::appendOption(options, element); }, *task.get());

Framework/Core/include/Framework/ArrowTableSlicingCache.h

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,51 +34,64 @@ struct SliceInfoUnsortedPtr {
3434
gsl::span<int64_t const> getSliceFor(int value) const;
3535
};
3636

37-
using StringPair = std::pair<std::string, std::string>;
37+
/// Descriptor of one slicing-cache entry.
///
/// Identifies a cache by the pair (binding, key) and carries a flag telling
/// whether the cache should be refreshed. Entries are compared elsewhere by
/// `binding` and `key`; `enabled` is consulted before a cache entry is updated.
struct Entry {
  std::string binding; ///< input binding (table label) the cache is built for
  std::string key;     ///< name of the index column used for slicing
  bool enabled;        ///< when false, the cache update for this entry is skipped

  /// @param b binding label; taken by value and moved into the member
  /// @param k index column name; taken by value and moved into the member
  /// @param e whether the entry is enabled (defaults to true)
  Entry(std::string b, std::string k, bool e = true)
    : binding{std::move(b)}, // sink parameter: move instead of copying a second time
      key{std::move(k)},
      enabled{e}
  {
  }
};
49+
50+
using Cache = std::vector<Entry>;
3851

39-
void updatePairList(std::vector<StringPair>& list, std::string const& binding, std::string const& key);
52+
void updatePairList(Cache& list, std::string const& binding, std::string const& key, bool enabled);
4053

4154
struct ArrowTableSlicingCacheDef {
4255
constexpr static ServiceKind service_kind = ServiceKind::Global;
43-
std::vector<StringPair> bindingsKeys;
44-
std::vector<StringPair> bindingsKeysUnsorted;
56+
Cache bindingsKeys;
57+
Cache bindingsKeysUnsorted;
4558

46-
void setCaches(std::vector<StringPair>&& bsks);
47-
void setCachesUnsorted(std::vector<StringPair>&& bsks);
59+
void setCaches(Cache&& bsks);
60+
void setCachesUnsorted(Cache&& bsks);
4861
};
4962

5063
struct ArrowTableSlicingCache {
5164
constexpr static ServiceKind service_kind = ServiceKind::Stream;
5265

53-
std::vector<StringPair> bindingsKeys;
66+
Cache bindingsKeys;
5467
std::vector<std::shared_ptr<arrow::NumericArray<arrow::Int32Type>>> values;
5568
std::vector<std::shared_ptr<arrow::NumericArray<arrow::Int64Type>>> counts;
5669

57-
std::vector<StringPair> bindingsKeysUnsorted;
70+
Cache bindingsKeysUnsorted;
5871
std::vector<std::vector<int>> valuesUnsorted;
5972
std::vector<ListVector> groups;
6073

61-
ArrowTableSlicingCache(std::vector<StringPair>&& bsks, std::vector<StringPair>&& bsksUnsorted = {});
74+
ArrowTableSlicingCache(Cache&& bsks, Cache&& bsksUnsorted = {});
6275

6376
// set caching information externally
64-
void setCaches(std::vector<StringPair>&& bsks, std::vector<StringPair>&& bsksUnsorted = {});
77+
void setCaches(Cache&& bsks, Cache&& bsksUnsorted = {});
6578

6679
// update slicing info cache entry (assumes it is already present)
6780
arrow::Status updateCacheEntry(int pos, std::shared_ptr<arrow::Table> const& table);
6881
arrow::Status updateCacheEntryUnsorted(int pos, std::shared_ptr<arrow::Table> const& table);
6982

7083
// helper to locate cache position
71-
std::pair<int, bool> getCachePos(StringPair const& bindingKey) const;
72-
int getCachePosSortedFor(StringPair const& bindingKey) const;
73-
int getCachePosUnsortedFor(StringPair const& bindingKey) const;
84+
std::pair<int, bool> getCachePos(Entry const& bindingKey) const;
85+
int getCachePosSortedFor(Entry const& bindingKey) const;
86+
int getCachePosUnsortedFor(Entry const& bindingKey) const;
7487

7588
// get slice from cache for a given value
76-
SliceInfoPtr getCacheFor(StringPair const& bindingKey) const;
77-
SliceInfoUnsortedPtr getCacheUnsortedFor(StringPair const& bindingKey) const;
89+
SliceInfoPtr getCacheFor(Entry const& bindingKey) const;
90+
SliceInfoUnsortedPtr getCacheUnsortedFor(Entry const& bindingKey) const;
7891
SliceInfoPtr getCacheForPos(int pos) const;
7992
SliceInfoUnsortedPtr getCacheUnsortedForPos(int pos) const;
8093

81-
static void validateOrder(StringPair const& bindingKey, std::shared_ptr<arrow::Table> const& input);
94+
static void validateOrder(Entry const& bindingKey, std::shared_ptr<arrow::Table> const& input);
8295
};
8396
} // namespace o2::framework
8497

Framework/Core/include/Framework/GroupSlicer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ struct GroupSlicer {
5555
{
5656
constexpr auto index = framework::has_type_at_v<std::decay_t<T>>(associated_pack_t{});
5757
auto binding = o2::soa::getLabelFromTypeForKey<std::decay_t<T>>(mIndexColumnName);
58-
auto bk = std::make_pair(binding, mIndexColumnName);
58+
auto bk = Entry(binding, mIndexColumnName);
5959
if constexpr (!o2::soa::is_smallgroups<std::decay_t<T>>) {
6060
if (table.size() == 0) {
6161
return;

Framework/Core/src/ASoA.cxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ bool PreslicePolicyBase::isMissing() const
197197
return binding == "[MISSING]";
198198
}
199199

200-
StringPair const& PreslicePolicyBase::getBindingKey() const
200+
Entry const& PreslicePolicyBase::getBindingKey() const
201201
{
202202
return bindingKey;
203203
}

Framework/Core/src/ArrowSupport.cxx

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -567,26 +567,27 @@ o2::framework::ServiceSpec ArrowSupport::arrowTableSlicingCacheSpec()
567567
.name = "arrow-slicing-cache",
568568
.uniqueId = CommonServices::simpleServiceId<ArrowTableSlicingCache>(),
569569
.init = [](ServiceRegistryRef services, DeviceState&, fair::mq::ProgOptions&) { return ServiceHandle{TypeIdHelpers::uniqueId<ArrowTableSlicingCache>(),
570-
new ArrowTableSlicingCache(std::vector<std::pair<std::string, std::string>>{services.get<ArrowTableSlicingCacheDef>().bindingsKeys}, std::vector{services.get<ArrowTableSlicingCacheDef>().bindingsKeysUnsorted}),
570+
new ArrowTableSlicingCache(Cache{services.get<ArrowTableSlicingCacheDef>().bindingsKeys},
571+
Cache{services.get<ArrowTableSlicingCacheDef>().bindingsKeysUnsorted}),
571572
ServiceKind::Stream, typeid(ArrowTableSlicingCache).name()}; },
572573
.configure = CommonServices::noConfiguration(),
573574
.preProcessing = [](ProcessingContext& pc, void* service_ptr) {
574575
auto* service = static_cast<ArrowTableSlicingCache*>(service_ptr);
575576
auto& caches = service->bindingsKeys;
576-
for (auto i = 0; i < caches.size(); ++i) {
577-
if (pc.inputs().getPos(caches[i].first.c_str()) >= 0) {
578-
auto status = service->updateCacheEntry(i, pc.inputs().get<TableConsumer>(caches[i].first.c_str())->asArrowTable());
577+
for (auto i = 0u; i < caches.size(); ++i) {
578+
if (caches[i].enabled && pc.inputs().getPos(caches[i].binding.c_str()) >= 0) {
579+
auto status = service->updateCacheEntry(i, pc.inputs().get<TableConsumer>(caches[i].binding.c_str())->asArrowTable());
579580
if (!status.ok()) {
580-
throw runtime_error_f("Failed to update slice cache for %s/%s", caches[i].first.c_str(), caches[i].second.c_str());
581+
throw runtime_error_f("Failed to update slice cache for %s/%s", caches[i].binding.c_str(), caches[i].key.c_str());
581582
}
582583
}
583584
}
584585
auto& unsortedCaches = service->bindingsKeysUnsorted;
585-
for (auto i = 0; i < unsortedCaches.size(); ++i) {
586-
if (pc.inputs().getPos(unsortedCaches[i].first.c_str()) >= 0) {
587-
auto status = service->updateCacheEntryUnsorted(i, pc.inputs().get<TableConsumer>(unsortedCaches[i].first.c_str())->asArrowTable());
586+
for (auto i = 0u; i < unsortedCaches.size(); ++i) {
587+
if (unsortedCaches[i].enabled && pc.inputs().getPos(unsortedCaches[i].binding.c_str()) >= 0) {
588+
auto status = service->updateCacheEntryUnsorted(i, pc.inputs().get<TableConsumer>(unsortedCaches[i].binding.c_str())->asArrowTable());
588589
if (!status.ok()) {
589-
throw runtime_error_f("failed to update slice cache (unsorted) for %s/%s", unsortedCaches[i].first.c_str(), unsortedCaches[i].second.c_str());
590+
throw runtime_error_f("failed to update slice cache (unsorted) for %s/%s", unsortedCaches[i].binding.c_str(), unsortedCaches[i].key.c_str());
590591
}
591592
}
592593
} },

0 commit comments

Comments
 (0)