Skip to content

Commit 1a75f88

Browse files
committed
DPL Analysis: prevent slice cache from updating when not required by
enabled process functions
1 parent fe6cd7c commit 1a75f88

File tree

9 files changed

+96
-72
lines changed

9 files changed

+96
-72
lines changed

Framework/Core/include/Framework/ASoA.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1397,10 +1397,10 @@ namespace o2::framework
13971397

13981398
struct PreslicePolicyBase {
13991399
const std::string binding;
1400-
StringPair bindingKey;
1400+
Entry bindingKey;
14011401

14021402
bool isMissing() const;
1403-
StringPair const& getBindingKey() const;
1403+
Entry const& getBindingKey() const;
14041404
};
14051405

14061406
struct PreslicePolicySorted : public PreslicePolicyBase {
@@ -1425,7 +1425,7 @@ struct PresliceBase : public Policy {
14251425
const std::string binding;
14261426

14271427
PresliceBase(expressions::BindingNode index_)
1428-
: Policy{PreslicePolicyBase{{o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name})}, std::make_pair(o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name}), std::string{index_.name})}, {}}
1428+
: Policy{PreslicePolicyBase{{o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name})}, Entry(o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name}), std::string{index_.name})}, {}}
14291429
{
14301430
}
14311431

Framework/Core/include/Framework/AnalysisManagers.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -512,39 +512,43 @@ static void setGroupedCombination(C& comb, TG& grouping, std::tuple<Ts...>& asso
512512
/// Preslice handling
513513
template <typename T>
514514
requires(!is_preslice<T>)
515-
bool registerCache(T&, std::vector<StringPair>&, std::vector<StringPair>&)
515+
bool registerCache(T&, Cache&, Cache&)
516516
{
517517
return false;
518518
}
519519

520520
template <is_preslice T>
521521
requires std::same_as<typename T::policy_t, framework::PreslicePolicySorted>
522-
bool registerCache(T& preslice, std::vector<StringPair>& bsks, std::vector<StringPair>&)
522+
bool registerCache(T& preslice, Cache& bsks, Cache&)
523523
{
524524
if constexpr (T::optional) {
525525
if (preslice.binding == "[MISSING]") {
526526
return true;
527527
}
528528
}
529-
auto locate = std::find_if(bsks.begin(), bsks.end(), [&](auto const& entry) { return (entry.first == preslice.bindingKey.first) && (entry.second == preslice.bindingKey.second); });
529+
auto locate = std::find_if(bsks.begin(), bsks.end(), [&](auto const& entry) { return (entry.binding == preslice.bindingKey.binding) && (entry.key == preslice.bindingKey.key); });
530530
if (locate == bsks.end()) {
531531
bsks.emplace_back(preslice.getBindingKey());
532+
} else if (locate->enabled == false) {
533+
locate->enabled = true;
532534
}
533535
return true;
534536
}
535537

536538
template <is_preslice T>
537539
requires std::same_as<typename T::policy_t, framework::PreslicePolicyGeneral>
538-
bool registerCache(T& preslice, std::vector<StringPair>&, std::vector<StringPair>& bsksU)
540+
bool registerCache(T& preslice, Cache&, Cache& bsksU)
539541
{
540542
if constexpr (T::optional) {
541543
if (preslice.binding == "[MISSING]") {
542544
return true;
543545
}
544546
}
545-
auto locate = std::find_if(bsksU.begin(), bsksU.end(), [&](auto const& entry) { return (entry.first == preslice.bindingKey.first) && (entry.second == preslice.bindingKey.second); });
547+
auto locate = std::find_if(bsksU.begin(), bsksU.end(), [&](auto const& entry) { return (entry.binding == preslice.bindingKey.binding) && (entry.key == preslice.bindingKey.key); });
546548
if (locate == bsksU.end()) {
547549
bsksU.emplace_back(preslice.getBindingKey());
550+
} else if (locate->enabled == false) {
551+
locate->enabled = true;
548552
}
549553
return true;
550554
}

Framework/Core/include/Framework/AnalysisTask.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -106,20 +106,20 @@ struct AnalysisDataProcessorBuilder {
106106
}
107107

108108
template <typename G, typename... Args>
109-
static void addGroupingCandidates(std::vector<StringPair>& bk, std::vector<StringPair>& bku)
109+
static void addGroupingCandidates(Cache& bk, Cache& bku, bool enabled)
110110
{
111-
[&bk, &bku]<typename... As>(framework::pack<As...>) mutable {
111+
[&bk, &bku, enabled]<typename... As>(framework::pack<As...>) mutable {
112112
std::string key;
113113
if constexpr (soa::is_iterator<std::decay_t<G>>) {
114114
key = std::string{"fIndex"} + o2::framework::cutString(soa::getLabelFromType<std::decay_t<G>>());
115115
}
116-
([&bk, &bku, &key]() mutable {
116+
([&bk, &bku, &key, enabled]() mutable {
117117
if constexpr (soa::relatedByIndex<std::decay_t<G>, std::decay_t<As>>()) {
118118
auto binding = soa::getLabelFromTypeForKey<std::decay_t<As>>(key);
119119
if constexpr (o2::soa::is_smallgroups<std::decay_t<As>>) {
120-
framework::updatePairList(bku, binding, key);
120+
framework::updatePairList(bku, binding, key, enabled);
121121
} else {
122-
framework::updatePairList(bk, binding, key);
122+
framework::updatePairList(bk, binding, key, enabled);
123123
}
124124
}
125125
}(),
@@ -192,7 +192,7 @@ struct AnalysisDataProcessorBuilder {
192192
/// helper to parse the process arguments
193193
/// 1. enumeration (must be the only argument)
194194
template <typename R, typename C, is_enumeration A>
195-
static void inputsFromArgs(R (C::*)(A), const char* /*name*/, bool /*value*/, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>&, std::vector<StringPair>&, std::vector<StringPair>&)
195+
static void inputsFromArgs(R (C::*)(A), const char* /*name*/, bool /*value*/, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>&, Cache&, Cache&)
196196
{
197197
std::vector<ConfigParamSpec> inputMetadata;
198198
// FIXME: for the moment we do not support begin, end and step.
@@ -201,17 +201,17 @@ struct AnalysisDataProcessorBuilder {
201201

202202
/// 2. grouping case - 1st argument is an iterator
203203
template <typename R, typename C, soa::is_iterator A, soa::is_table... Args>
204-
static void inputsFromArgs(R (C::*)(A, Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos, std::vector<StringPair>& bk, std::vector<StringPair>& bku)
204+
static void inputsFromArgs(R (C::*)(A, Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos, Cache& bk, Cache& bku)
205205
requires(std::is_lvalue_reference_v<A> && (std::is_lvalue_reference_v<Args> && ...))
206206
{
207-
addGroupingCandidates<A, Args...>(bk, bku);
207+
addGroupingCandidates<A, Args...>(bk, bku, value);
208208
constexpr auto hash = o2::framework::TypeIdHelpers::uniqueId<R (C::*)(A, Args...)>();
209209
addInputsAndExpressions<typename std::decay_t<A>::parent_t, Args...>(hash, name, value, inputs, eInfos);
210210
}
211211

212212
/// 3. generic case
213213
template <typename R, typename C, soa::is_table... Args>
214-
static void inputsFromArgs(R (C::*)(Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos, std::vector<StringPair>&, std::vector<StringPair>&)
214+
static void inputsFromArgs(R (C::*)(Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos, Cache&, Cache&)
215215
requires(std::is_lvalue_reference_v<Args> && ...)
216216
{
217217
constexpr auto hash = o2::framework::TypeIdHelpers::uniqueId<R (C::*)(Args...)>();
@@ -525,8 +525,8 @@ DataProcessorSpec adaptAnalysisTask(ConfigContext const& ctx, Args&&... args)
525525
std::vector<InputSpec> inputs;
526526
std::vector<ConfigParamSpec> options;
527527
std::vector<ExpressionInfo> expressionInfos;
528-
std::vector<StringPair> bindingsKeys;
529-
std::vector<StringPair> bindingsKeysUnsorted;
528+
Cache bindingsKeys;
529+
Cache bindingsKeysUnsorted;
530530

531531
/// make sure options and configurables are set before expression infos are created
532532
homogeneous_apply_refs([&options, &hash](auto& element) { return analysis_task_parsers::appendOption(options, element); }, *task.get());

Framework/Core/include/Framework/ArrowTableSlicingCache.h

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,51 +34,64 @@ struct SliceInfoUnsortedPtr {
3434
gsl::span<int64_t const> getSliceFor(int value) const;
3535
};
3636

37-
using StringPair = std::pair<std::string, std::string>;
37+
struct Entry {
38+
std::string binding;
39+
std::string key;
40+
bool enabled;
41+
42+
Entry(std::string b, std::string k, bool e = true)
43+
: binding{b},
44+
key{k},
45+
enabled{e}
46+
{
47+
}
48+
};
49+
50+
using Cache = std::vector<Entry>;
3851

39-
void updatePairList(std::vector<StringPair>& list, std::string const& binding, std::string const& key);
52+
void updatePairList(Cache& list, std::string const& binding, std::string const& key, bool enabled);
4053

4154
struct ArrowTableSlicingCacheDef {
4255
constexpr static ServiceKind service_kind = ServiceKind::Global;
43-
std::vector<StringPair> bindingsKeys;
44-
std::vector<StringPair> bindingsKeysUnsorted;
56+
Cache bindingsKeys;
57+
Cache bindingsKeysUnsorted;
4558

46-
void setCaches(std::vector<StringPair>&& bsks);
47-
void setCachesUnsorted(std::vector<StringPair>&& bsks);
59+
void setCaches(Cache&& bsks);
60+
void setCachesUnsorted(Cache&& bsks);
4861
};
4962

5063
struct ArrowTableSlicingCache {
5164
constexpr static ServiceKind service_kind = ServiceKind::Stream;
5265

53-
std::vector<StringPair> bindingsKeys;
66+
Cache bindingsKeys;
5467
std::vector<std::shared_ptr<arrow::NumericArray<arrow::Int32Type>>> values;
5568
std::vector<std::shared_ptr<arrow::NumericArray<arrow::Int64Type>>> counts;
5669

57-
std::vector<StringPair> bindingsKeysUnsorted;
70+
Cache bindingsKeysUnsorted;
5871
std::vector<std::vector<int>> valuesUnsorted;
5972
std::vector<ListVector> groups;
6073

61-
ArrowTableSlicingCache(std::vector<StringPair>&& bsks, std::vector<StringPair>&& bsksUnsorted = {});
74+
ArrowTableSlicingCache(Cache&& bsks, Cache&& bsksUnsorted = {});
6275

6376
// set caching information externally
64-
void setCaches(std::vector<StringPair>&& bsks, std::vector<StringPair>&& bsksUnsorted = {});
77+
void setCaches(Cache&& bsks, Cache&& bsksUnsorted = {});
6578

6679
// update slicing info cache entry (assumes it is already present)
6780
arrow::Status updateCacheEntry(int pos, std::shared_ptr<arrow::Table> const& table);
6881
arrow::Status updateCacheEntryUnsorted(int pos, std::shared_ptr<arrow::Table> const& table);
6982

7083
// helper to locate cache position
71-
std::pair<int, bool> getCachePos(StringPair const& bindingKey) const;
72-
int getCachePosSortedFor(StringPair const& bindingKey) const;
73-
int getCachePosUnsortedFor(StringPair const& bindingKey) const;
84+
std::pair<int, bool> getCachePos(Entry const& bindingKey) const;
85+
int getCachePosSortedFor(Entry const& bindingKey) const;
86+
int getCachePosUnsortedFor(Entry const& bindingKey) const;
7487

7588
// get slice from cache for a given value
76-
SliceInfoPtr getCacheFor(StringPair const& bindingKey) const;
77-
SliceInfoUnsortedPtr getCacheUnsortedFor(StringPair const& bindingKey) const;
89+
SliceInfoPtr getCacheFor(Entry const& bindingKey) const;
90+
SliceInfoUnsortedPtr getCacheUnsortedFor(Entry const& bindingKey) const;
7891
SliceInfoPtr getCacheForPos(int pos) const;
7992
SliceInfoUnsortedPtr getCacheUnsortedForPos(int pos) const;
8093

81-
static void validateOrder(StringPair const& bindingKey, std::shared_ptr<arrow::Table> const& input);
94+
static void validateOrder(Entry const& bindingKey, std::shared_ptr<arrow::Table> const& input);
8295
};
8396
} // namespace o2::framework
8497

Framework/Core/include/Framework/GroupSlicer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ struct GroupSlicer {
5555
{
5656
constexpr auto index = framework::has_type_at_v<std::decay_t<T>>(associated_pack_t{});
5757
auto binding = o2::soa::getLabelFromTypeForKey<std::decay_t<T>>(mIndexColumnName);
58-
auto bk = std::make_pair(binding, mIndexColumnName);
58+
auto bk = Entry(binding, mIndexColumnName);
5959
if constexpr (!o2::soa::is_smallgroups<std::decay_t<T>>) {
6060
if (table.size() == 0) {
6161
return;

Framework/Core/src/ASoA.cxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ bool PreslicePolicyBase::isMissing() const
197197
return binding == "[MISSING]";
198198
}
199199

200-
StringPair const& PreslicePolicyBase::getBindingKey() const
200+
Entry const& PreslicePolicyBase::getBindingKey() const
201201
{
202202
return bindingKey;
203203
}

Framework/Core/src/ArrowSupport.cxx

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -566,27 +566,29 @@ o2::framework::ServiceSpec ArrowSupport::arrowTableSlicingCacheSpec()
566566
return ServiceSpec{
567567
.name = "arrow-slicing-cache",
568568
.uniqueId = CommonServices::simpleServiceId<ArrowTableSlicingCache>(),
569-
.init = [](ServiceRegistryRef services, DeviceState&, fair::mq::ProgOptions&) { return ServiceHandle{TypeIdHelpers::uniqueId<ArrowTableSlicingCache>(),
570-
new ArrowTableSlicingCache(std::vector<std::pair<std::string, std::string>>{services.get<ArrowTableSlicingCacheDef>().bindingsKeys}, std::vector{services.get<ArrowTableSlicingCacheDef>().bindingsKeysUnsorted}),
571-
ServiceKind::Stream, typeid(ArrowTableSlicingCache).name()}; },
569+
.init = [](ServiceRegistryRef services, DeviceState&, fair::mq::ProgOptions&) {
570+
return ServiceHandle{TypeIdHelpers::uniqueId<ArrowTableSlicingCache>(),
571+
new ArrowTableSlicingCache(Cache{services.get<ArrowTableSlicingCacheDef>().bindingsKeys},
572+
Cache{services.get<ArrowTableSlicingCacheDef>().bindingsKeysUnsorted}),
573+
ServiceKind::Stream, typeid(ArrowTableSlicingCache).name()}; },
572574
.configure = CommonServices::noConfiguration(),
573575
.preProcessing = [](ProcessingContext& pc, void* service_ptr) {
574576
auto* service = static_cast<ArrowTableSlicingCache*>(service_ptr);
575577
auto& caches = service->bindingsKeys;
576-
for (auto i = 0; i < caches.size(); ++i) {
577-
if (pc.inputs().getPos(caches[i].first.c_str()) >= 0) {
578-
auto status = service->updateCacheEntry(i, pc.inputs().get<TableConsumer>(caches[i].first.c_str())->asArrowTable());
578+
for (auto i = 0u; i < caches.size(); ++i) {
579+
if (caches[i].enabled && pc.inputs().getPos(caches[i].binding.c_str()) >= 0) {
580+
auto status = service->updateCacheEntry(i, pc.inputs().get<TableConsumer>(caches[i].binding.c_str())->asArrowTable());
579581
if (!status.ok()) {
580-
throw runtime_error_f("Failed to update slice cache for %s/%s", caches[i].first.c_str(), caches[i].second.c_str());
582+
throw runtime_error_f("Failed to update slice cache for %s/%s", caches[i].binding.c_str(), caches[i].key.c_str());
581583
}
582584
}
583585
}
584586
auto& unsortedCaches = service->bindingsKeysUnsorted;
585-
for (auto i = 0; i < unsortedCaches.size(); ++i) {
586-
if (pc.inputs().getPos(unsortedCaches[i].first.c_str()) >= 0) {
587-
auto status = service->updateCacheEntryUnsorted(i, pc.inputs().get<TableConsumer>(unsortedCaches[i].first.c_str())->asArrowTable());
587+
for (auto i = 0u; i < unsortedCaches.size(); ++i) {
588+
if (unsortedCaches[i].enabled && pc.inputs().getPos(unsortedCaches[i].binding.c_str()) >= 0) {
589+
auto status = service->updateCacheEntryUnsorted(i, pc.inputs().get<TableConsumer>(unsortedCaches[i].binding.c_str())->asArrowTable());
588590
if (!status.ok()) {
589-
throw runtime_error_f("failed to update slice cache (unsorted) for %s/%s", unsortedCaches[i].first.c_str(), unsortedCaches[i].second.c_str());
591+
throw runtime_error_f("failed to update slice cache (unsorted) for %s/%s", unsortedCaches[i].binding.c_str(), unsortedCaches[i].key.c_str());
590592
}
591593
}
592594
} },

0 commit comments

Comments
 (0)