Skip to content

Commit 35e208b

Browse files
authored
DPL Analysis: prevent slice cache from updating unnecessarily (#14257)
* Cache setup now only happens after init when process configurables' values are final * Add inline contrained functions to avoid using "overloaded" * add error messages for unexpected situations
1 parent 4d20c8d commit 35e208b

File tree

9 files changed

+167
-93
lines changed

9 files changed

+167
-93
lines changed

Framework/Core/include/Framework/ASoA.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,10 +1400,10 @@ namespace o2::framework
14001400

14011401
struct PreslicePolicyBase {
14021402
const std::string binding;
1403-
StringPair bindingKey;
1403+
Entry bindingKey;
14041404

14051405
bool isMissing() const;
1406-
StringPair const& getBindingKey() const;
1406+
Entry const& getBindingKey() const;
14071407
};
14081408

14091409
struct PreslicePolicySorted : public PreslicePolicyBase {
@@ -1428,7 +1428,7 @@ struct PresliceBase : public Policy {
14281428
const std::string binding;
14291429

14301430
PresliceBase(expressions::BindingNode index_)
1431-
: Policy{PreslicePolicyBase{{o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name})}, std::make_pair(o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name}), std::string{index_.name})}, {}}
1431+
: Policy{PreslicePolicyBase{{o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name})}, Entry(o2::soa::getLabelFromTypeForKey<T, OPT>(std::string{index_.name}), std::string{index_.name})}, {}}
14321432
{
14331433
}
14341434

@@ -1508,7 +1508,7 @@ auto doSliceBy(T const* table, o2::framework::PresliceBase<C, Policy, OPT> const
15081508
{
15091509
if constexpr (OPT) {
15101510
if (container.isMissing()) {
1511-
missingOptionalPreslice(getLabelFromType<std::decay_t<T>>().data(), container.bindingKey.second.c_str());
1511+
missingOptionalPreslice(getLabelFromType<std::decay_t<T>>().data(), container.bindingKey.key.c_str());
15121512
}
15131513
}
15141514
uint64_t offset = 0;
@@ -1545,7 +1545,7 @@ auto doSliceBy(T const* table, o2::framework::PresliceBase<C, Policy, OPT> const
15451545
{
15461546
if constexpr (OPT) {
15471547
if (container.isMissing()) {
1548-
missingOptionalPreslice(getLabelFromType<std::decay_t<T>>().data(), container.bindingKey.second.c_str());
1548+
missingOptionalPreslice(getLabelFromType<std::decay_t<T>>().data(), container.bindingKey.key.c_str());
15491549
}
15501550
}
15511551
auto selection = container.getSliceFor(value);
@@ -1574,7 +1574,7 @@ auto doFilteredSliceBy(T const* table, o2::framework::PresliceBase<C, framework:
15741574
{
15751575
if constexpr (OPT) {
15761576
if (container.isMissing()) {
1577-
missingOptionalPreslice(getLabelFromType<T>().data(), container.bindingKey.second.c_str());
1577+
missingOptionalPreslice(getLabelFromType<T>().data(), container.bindingKey.key.c_str());
15781578
}
15791579
}
15801580
uint64_t offset = 0;

Framework/Core/include/Framework/AnalysisManagers.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -534,39 +534,43 @@ static void setGroupedCombination(C& comb, TG& grouping, std::tuple<Ts...>& asso
534534
/// Preslice handling
535535
template <typename T>
536536
requires(!is_preslice<T>)
537-
bool registerCache(T&, std::vector<StringPair>&, std::vector<StringPair>&)
537+
bool registerCache(T&, Cache&, Cache&)
538538
{
539539
return false;
540540
}
541541

542542
template <is_preslice T>
543543
requires std::same_as<typename T::policy_t, framework::PreslicePolicySorted>
544-
bool registerCache(T& preslice, std::vector<StringPair>& bsks, std::vector<StringPair>&)
544+
bool registerCache(T& preslice, Cache& bsks, Cache&)
545545
{
546546
if constexpr (T::optional) {
547547
if (preslice.binding == "[MISSING]") {
548548
return true;
549549
}
550550
}
551-
auto locate = std::find_if(bsks.begin(), bsks.end(), [&](auto const& entry) { return (entry.first == preslice.bindingKey.first) && (entry.second == preslice.bindingKey.second); });
551+
auto locate = std::find_if(bsks.begin(), bsks.end(), [&](auto const& entry) { return (entry.binding == preslice.bindingKey.binding) && (entry.key == preslice.bindingKey.key); });
552552
if (locate == bsks.end()) {
553553
bsks.emplace_back(preslice.getBindingKey());
554+
} else if (locate->enabled == false) {
555+
locate->enabled = true;
554556
}
555557
return true;
556558
}
557559

558560
template <is_preslice T>
559561
requires std::same_as<typename T::policy_t, framework::PreslicePolicyGeneral>
560-
bool registerCache(T& preslice, std::vector<StringPair>&, std::vector<StringPair>& bsksU)
562+
bool registerCache(T& preslice, Cache&, Cache& bsksU)
561563
{
562564
if constexpr (T::optional) {
563565
if (preslice.binding == "[MISSING]") {
564566
return true;
565567
}
566568
}
567-
auto locate = std::find_if(bsksU.begin(), bsksU.end(), [&](auto const& entry) { return (entry.first == preslice.bindingKey.first) && (entry.second == preslice.bindingKey.second); });
569+
auto locate = std::find_if(bsksU.begin(), bsksU.end(), [&](auto const& entry) { return (entry.binding == preslice.bindingKey.binding) && (entry.key == preslice.bindingKey.key); });
568570
if (locate == bsksU.end()) {
569571
bsksU.emplace_back(preslice.getBindingKey());
572+
} else if (locate->enabled == false) {
573+
locate->enabled = true;
570574
}
571575
return true;
572576
}

Framework/Core/include/Framework/AnalysisTask.h

Lines changed: 71 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -65,21 +65,18 @@ concept is_enumeration = is_enumeration_v<std::decay_t<T>>;
6565
// the contents of an AnalysisTask...
6666
namespace {
6767
struct AnalysisDataProcessorBuilder {
68-
template <typename G, typename... Args>
69-
static void addGroupingCandidates(std::vector<StringPair>& bk, std::vector<StringPair>& bku)
68+
template <soa::is_iterator G, typename... Args>
69+
static void addGroupingCandidates(Cache& bk, Cache& bku, bool enabled)
7070
{
71-
[&bk, &bku]<typename... As>(framework::pack<As...>) mutable {
72-
std::string key;
73-
if constexpr (soa::is_iterator<std::decay_t<G>>) {
74-
key = std::string{"fIndex"} + o2::framework::cutString(soa::getLabelFromType<std::decay_t<G>>());
75-
}
76-
([&bk, &bku, &key]() mutable {
71+
[&bk, &bku, enabled]<typename... As>(framework::pack<As...>) mutable {
72+
auto key = std::string{"fIndex"} + o2::framework::cutString(soa::getLabelFromType<std::decay_t<G>>());
73+
([&bk, &bku, &key, enabled]() mutable {
7774
if constexpr (soa::relatedByIndex<std::decay_t<G>, std::decay_t<As>>()) {
7875
auto binding = soa::getLabelFromTypeForKey<std::decay_t<As>>(key);
7976
if constexpr (o2::soa::is_smallgroups<std::decay_t<As>>) {
80-
framework::updatePairList(bku, binding, key);
77+
framework::updatePairList(bku, binding, key, enabled);
8178
} else {
82-
framework::updatePairList(bk, binding, key);
79+
framework::updatePairList(bk, binding, key, enabled);
8380
}
8481
}
8582
}(),
@@ -145,34 +142,72 @@ struct AnalysisDataProcessorBuilder {
145142
}
146143

147144
/// helper to parse the process arguments
145+
template <typename T>
146+
inline static bool requestInputsFromArgs(T&, std::string const&, std::vector<InputSpec>&, std::vector<ExpressionInfo>&)
147+
{
148+
return false;
149+
}
150+
template <is_process_configurable T>
151+
inline static bool requestInputsFromArgs(T& pc, std::string const& name, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eis)
152+
{
153+
AnalysisDataProcessorBuilder::inputsFromArgs(pc.process, (name + "/" + pc.name).c_str(), pc.value, inputs, eis);
154+
return true;
155+
}
156+
template <typename T>
157+
inline static bool requestCacheFromArgs(T&, Cache&, Cache&)
158+
{
159+
return false;
160+
}
161+
template <is_process_configurable T>
162+
inline static bool requestCacheFromArgs(T& pc, Cache& bk, Cache& bku)
163+
{
164+
AnalysisDataProcessorBuilder::cacheFromArgs(pc.process, pc.value, bk, bku);
165+
return true;
166+
}
148167
/// 1. enumeration (must be the only argument)
149168
template <typename R, typename C, is_enumeration A>
150-
static void inputsFromArgs(R (C::*)(A), const char* /*name*/, bool /*value*/, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>&, std::vector<StringPair>&, std::vector<StringPair>&)
169+
static void inputsFromArgs(R (C::*)(A), const char* /*name*/, bool /*value*/, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>&) //, Cache&, Cache&)
151170
{
152171
std::vector<ConfigParamSpec> inputMetadata;
153172
// FIXME: for the moment we do not support begin, end and step.
154173
DataSpecUtils::updateInputList(inputs, InputSpec{"enumeration", "DPL", "ENUM", 0, Lifetime::Enumeration, inputMetadata});
155174
}
156175

157-
/// 2. grouping case - 1st argument is an iterator
176+
/// 2. 1st argument is an iterator
158177
template <typename R, typename C, soa::is_iterator A, soa::is_table... Args>
159-
static void inputsFromArgs(R (C::*)(A, Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos, std::vector<StringPair>& bk, std::vector<StringPair>& bku)
178+
static void inputsFromArgs(R (C::*)(A, Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos) //, Cache& bk, Cache& bku)
160179
requires(std::is_lvalue_reference_v<A> && (std::is_lvalue_reference_v<Args> && ...))
161180
{
162-
addGroupingCandidates<A, Args...>(bk, bku);
163181
constexpr auto hash = o2::framework::TypeIdHelpers::uniqueId<R (C::*)(A, Args...)>();
164182
addInputsAndExpressions<typename std::decay_t<A>::parent_t, Args...>(hash, name, value, inputs, eInfos);
165183
}
166184

167185
/// 3. generic case
168186
template <typename R, typename C, soa::is_table... Args>
169-
static void inputsFromArgs(R (C::*)(Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos, std::vector<StringPair>&, std::vector<StringPair>&)
187+
static void inputsFromArgs(R (C::*)(Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos) //, Cache&, Cache&)
170188
requires(std::is_lvalue_reference_v<Args> && ...)
171189
{
172190
constexpr auto hash = o2::framework::TypeIdHelpers::uniqueId<R (C::*)(Args...)>();
173191
addInputsAndExpressions<Args...>(hash, name, value, inputs, eInfos);
174192
}
175193

194+
/// 1. enumeration (no grouping)
195+
template <typename R, typename C, is_enumeration A>
196+
static void cacheFromArgs(R (C::*)(A), bool, Cache&, Cache&)
197+
{
198+
}
199+
/// 2. iterator (the only grouping case)
200+
template <typename R, typename C, soa::is_iterator A, soa::is_table... Args>
201+
static void cacheFromArgs(R (C::*)(A, Args...), bool value, Cache& bk, Cache& bku)
202+
{
203+
addGroupingCandidates<A, Args...>(bk, bku, value);
204+
}
205+
/// 3. generic case (no grouping)
206+
template <typename R, typename C, soa::is_table A, soa::is_table... Args>
207+
static void cacheFromArgs(R (C::*)(A, Args...), bool, Cache&, Cache&)
208+
{
209+
}
210+
176211
template <soa::TableRef R>
177212
static auto extractTableFromRecord(InputRecord& record)
178213
{
@@ -480,8 +515,6 @@ DataProcessorSpec adaptAnalysisTask(ConfigContext const& ctx, Args&&... args)
480515
std::vector<InputSpec> inputs;
481516
std::vector<ConfigParamSpec> options;
482517
std::vector<ExpressionInfo> expressionInfos;
483-
std::vector<StringPair> bindingsKeys;
484-
std::vector<StringPair> bindingsKeysUnsorted;
485518

486519
/// make sure options and configurables are set before expression infos are created
487520
homogeneous_apply_refs([&options, &hash](auto& element) { return analysis_task_parsers::appendOption(options, element); }, *task.get());
@@ -490,23 +523,15 @@ DataProcessorSpec adaptAnalysisTask(ConfigContext const& ctx, Args&&... args)
490523

491524
/// parse process functions defined by corresponding configurables
492525
if constexpr (requires { &T::process; }) {
493-
AnalysisDataProcessorBuilder::inputsFromArgs(&T::process, "default", true, inputs, expressionInfos, bindingsKeys, bindingsKeysUnsorted);
526+
AnalysisDataProcessorBuilder::inputsFromArgs(&T::process, "default", true, inputs, expressionInfos);
494527
}
495528
homogeneous_apply_refs(
496-
overloaded{
497-
[name = name_str, &expressionInfos, &inputs, &bindingsKeys, &bindingsKeysUnsorted](framework::is_process_configurable auto& x) mutable {
498-
// this pushes (argumentIndex,processHash,schemaPtr,nullptr) into expressionInfos for arguments that are Filtered/filtered_iterators
499-
AnalysisDataProcessorBuilder::inputsFromArgs(x.process, (name + "/" + x.name).c_str(), x.value, inputs, expressionInfos, bindingsKeys, bindingsKeysUnsorted);
500-
return true;
501-
},
502-
[](auto&) {
503-
return false;
504-
}},
529+
[name = name_str, &expressionInfos, &inputs](auto& x) mutable {
530+
// this pushes (argumentIndex, processHash, schemaPtr, nullptr) into expressionInfos for arguments that are Filtered/filtered_iterators
531+
return AnalysisDataProcessorBuilder::requestInputsFromArgs(x, name, inputs, expressionInfos);
532+
},
505533
*task.get());
506534

507-
// add preslice declarations to slicing cache definition
508-
homogeneous_apply_refs([&bindingsKeys, &bindingsKeysUnsorted](auto& element) { return analysis_task_parsers::registerCache(element, bindingsKeys, bindingsKeysUnsorted); }, *task.get());
509-
510535
// request base tables for spawnable extended tables and indices to be built
511536
// this checks for duplications
512537
homogeneous_apply_refs([&inputs](auto& element) {
@@ -526,7 +551,12 @@ DataProcessorSpec adaptAnalysisTask(ConfigContext const& ctx, Args&&... args)
526551
requiredServices.insert(requiredServices.end(), arrowServices.begin(), arrowServices.end());
527552
homogeneous_apply_refs([&requiredServices](auto& element) { return analysis_task_parsers::addService(requiredServices, element); }, *task.get());
528553

529-
auto algo = AlgorithmSpec::InitCallback{[task = task, expressionInfos, bindingsKeys, bindingsKeysUnsorted](InitContext& ic) mutable {
554+
auto algo = AlgorithmSpec::InitCallback{[task = task, expressionInfos](InitContext& ic) mutable {
555+
Cache bindingsKeys;
556+
Cache bindingsKeysUnsorted;
557+
// add preslice declarations to slicing cache definition
558+
homogeneous_apply_refs([&bindingsKeys, &bindingsKeysUnsorted](auto& element) { return analysis_task_parsers::registerCache(element, bindingsKeys, bindingsKeysUnsorted); }, *task.get());
559+
530560
homogeneous_apply_refs([&ic](auto&& element) { return analysis_task_parsers::prepareOption(ic, element); }, *task.get());
531561
homogeneous_apply_refs([&ic](auto&& element) { return analysis_task_parsers::prepareService(ic, element); }, *task.get());
532562

@@ -556,6 +586,16 @@ DataProcessorSpec adaptAnalysisTask(ConfigContext const& ctx, Args&&... args)
556586
task->init(ic);
557587
}
558588

589+
/// parse process functions to enable requested grouping caches - note that at this state process configurables have their final values
590+
if constexpr (requires { &T::process; }) {
591+
AnalysisDataProcessorBuilder::cacheFromArgs(&T::process, true, bindingsKeys, bindingsKeysUnsorted);
592+
}
593+
homogeneous_apply_refs(
594+
[&bindingsKeys, &bindingsKeysUnsorted](auto& x) mutable {
595+
return AnalysisDataProcessorBuilder::requestCacheFromArgs(x, bindingsKeys, bindingsKeysUnsorted);
596+
},
597+
*task.get());
598+
559599
ic.services().get<ArrowTableSlicingCacheDef>().setCaches(std::move(bindingsKeys));
560600
ic.services().get<ArrowTableSlicingCacheDef>().setCachesUnsorted(std::move(bindingsKeysUnsorted));
561601
// initialize global caches

Framework/Core/include/Framework/ArrowTableSlicingCache.h

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,51 +34,64 @@ struct SliceInfoUnsortedPtr {
3434
gsl::span<int64_t const> getSliceFor(int value) const;
3535
};
3636

37-
using StringPair = std::pair<std::string, std::string>;
37+
struct Entry {
38+
std::string binding;
39+
std::string key;
40+
bool enabled;
41+
42+
Entry(std::string b, std::string k, bool e = true)
43+
: binding{b},
44+
key{k},
45+
enabled{e}
46+
{
47+
}
48+
};
49+
50+
using Cache = std::vector<Entry>;
3851

39-
void updatePairList(std::vector<StringPair>& list, std::string const& binding, std::string const& key);
52+
void updatePairList(Cache& list, std::string const& binding, std::string const& key, bool enabled);
4053

4154
struct ArrowTableSlicingCacheDef {
4255
constexpr static ServiceKind service_kind = ServiceKind::Global;
43-
std::vector<StringPair> bindingsKeys;
44-
std::vector<StringPair> bindingsKeysUnsorted;
56+
Cache bindingsKeys;
57+
Cache bindingsKeysUnsorted;
4558

46-
void setCaches(std::vector<StringPair>&& bsks);
47-
void setCachesUnsorted(std::vector<StringPair>&& bsks);
59+
void setCaches(Cache&& bsks);
60+
void setCachesUnsorted(Cache&& bsks);
4861
};
4962

5063
struct ArrowTableSlicingCache {
5164
constexpr static ServiceKind service_kind = ServiceKind::Stream;
5265

53-
std::vector<StringPair> bindingsKeys;
66+
Cache bindingsKeys;
5467
std::vector<std::shared_ptr<arrow::NumericArray<arrow::Int32Type>>> values;
5568
std::vector<std::shared_ptr<arrow::NumericArray<arrow::Int64Type>>> counts;
5669

57-
std::vector<StringPair> bindingsKeysUnsorted;
70+
Cache bindingsKeysUnsorted;
5871
std::vector<std::vector<int>> valuesUnsorted;
5972
std::vector<ListVector> groups;
6073

61-
ArrowTableSlicingCache(std::vector<StringPair>&& bsks, std::vector<StringPair>&& bsksUnsorted = {});
74+
ArrowTableSlicingCache(Cache&& bsks, Cache&& bsksUnsorted = {});
6275

6376
// set caching information externally
64-
void setCaches(std::vector<StringPair>&& bsks, std::vector<StringPair>&& bsksUnsorted = {});
77+
void setCaches(Cache&& bsks, Cache&& bsksUnsorted = {});
6578

6679
// update slicing info cache entry (assumes it is already present)
6780
arrow::Status updateCacheEntry(int pos, std::shared_ptr<arrow::Table> const& table);
6881
arrow::Status updateCacheEntryUnsorted(int pos, std::shared_ptr<arrow::Table> const& table);
6982

7083
// helper to locate cache position
71-
std::pair<int, bool> getCachePos(StringPair const& bindingKey) const;
72-
int getCachePosSortedFor(StringPair const& bindingKey) const;
73-
int getCachePosUnsortedFor(StringPair const& bindingKey) const;
84+
std::pair<int, bool> getCachePos(Entry const& bindingKey) const;
85+
int getCachePosSortedFor(Entry const& bindingKey) const;
86+
int getCachePosUnsortedFor(Entry const& bindingKey) const;
7487

7588
// get slice from cache for a given value
76-
SliceInfoPtr getCacheFor(StringPair const& bindingKey) const;
77-
SliceInfoUnsortedPtr getCacheUnsortedFor(StringPair const& bindingKey) const;
89+
SliceInfoPtr getCacheFor(Entry const& bindingKey) const;
90+
SliceInfoUnsortedPtr getCacheUnsortedFor(Entry const& bindingKey) const;
7891
SliceInfoPtr getCacheForPos(int pos) const;
7992
SliceInfoUnsortedPtr getCacheUnsortedForPos(int pos) const;
8093

81-
static void validateOrder(StringPair const& bindingKey, std::shared_ptr<arrow::Table> const& input);
94+
static void validateOrder(Entry const& bindingKey, std::shared_ptr<arrow::Table> const& input);
8295
};
8396
} // namespace o2::framework
8497

Framework/Core/include/Framework/GroupSlicer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ struct GroupSlicer {
5555
{
5656
constexpr auto index = framework::has_type_at_v<std::decay_t<T>>(associated_pack_t{});
5757
auto binding = o2::soa::getLabelFromTypeForKey<std::decay_t<T>>(mIndexColumnName);
58-
auto bk = std::make_pair(binding, mIndexColumnName);
58+
auto bk = Entry(binding, mIndexColumnName);
5959
if constexpr (!o2::soa::is_smallgroups<std::decay_t<T>>) {
6060
if (table.size() == 0) {
6161
return;

Framework/Core/src/ASoA.cxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ bool PreslicePolicyBase::isMissing() const
194194
return binding == "[MISSING]";
195195
}
196196

197-
StringPair const& PreslicePolicyBase::getBindingKey() const
197+
Entry const& PreslicePolicyBase::getBindingKey() const
198198
{
199199
return bindingKey;
200200
}

0 commit comments

Comments
 (0)