Skip to content

Commit 90267bb

Browse files
authored
DPL: enable early forwarding for AODs (#14088)
Should improve parallelism for long trains. Requires FairMQ 1.9.2 and one needs to pass `--early-forwarding-policy always` for this to take effect.
1 parent 9046e70 commit 90267bb

File tree

10 files changed

+88
-10
lines changed

10 files changed

+88
-10
lines changed

Framework/Core/include/Framework/CompletionPolicyHelpers.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ struct CompletionPolicyHelpers {
4343

4444
/// When any of the parts of the record have been received, consume them.
4545
static CompletionPolicy consumeWhenAny(const char* name, CompletionPolicy::Matcher matcher);
46+
47+
#if __has_include(<fairmq/shmem/Message.h>)
48+
/// When any of the parts which has arrived has a refcount of 1.
49+
static CompletionPolicy consumeWhenAnyZeroCount(const char* name, CompletionPolicy::Matcher matcher);
50+
#endif
4651
/// Default matcher applies for all devices
4752
static CompletionPolicy consumeWhenAny(CompletionPolicy::Matcher matcher = [](auto const&) -> bool { return true; })
4853
{

Framework/Core/include/Framework/InputSpan.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class InputSpan
4646
/// index and the buffer associated.
4747
/// @nofPartsGetter is the getter for the number of parts associated with an index
4848
/// @a size is the number of elements in the span.
49-
InputSpan(std::function<DataRef(size_t, size_t)> getter, std::function<size_t(size_t)> nofPartsGetter, size_t size);
49+
InputSpan(std::function<DataRef(size_t, size_t)> getter, std::function<size_t(size_t)> nofPartsGetter, std::function<int(size_t)> refCountGetter, size_t size);
5050

5151
/// @a i-th element of the InputSpan
5252
[[nodiscard]] DataRef get(size_t i, size_t partidx = 0) const
@@ -66,6 +66,18 @@ class InputSpan
6666
return mNofPartsGetter(i);
6767
}
6868

69+
// Get the refcount for a given part
70+
[[nodiscard]] int getRefCount(size_t i) const
71+
{
72+
if (i >= mSize) {
73+
return 0;
74+
}
75+
if (!mRefCountGetter) {
76+
return -1;
77+
}
78+
return mRefCountGetter(i);
79+
}
80+
6981
/// Number of elements in the InputSpan
7082
[[nodiscard]] size_t size() const
7183
{
@@ -236,6 +248,7 @@ class InputSpan
236248
private:
237249
std::function<DataRef(size_t, size_t)> mGetter;
238250
std::function<size_t(size_t)> mNofPartsGetter;
251+
std::function<int(size_t)> mRefCountGetter;
239252
size_t mSize;
240253
};
241254

Framework/Core/src/CompletionPolicy.cxx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ std::vector<CompletionPolicy>
2626
{
2727
return {
2828
CompletionPolicyHelpers::consumeWhenAllOrdered("internal-dpl-aod-writer"),
29+
#if __has_include(<fairmq/shmem/Message.h>)
30+
CompletionPolicyHelpers::consumeWhenAnyZeroCount("internal-dpl-injected-dummy-sink", [](DeviceSpec const& s) { return s.name.find("internal-dpl-injected-dummy-sink") != std::string::npos; }),
31+
#else
2932
CompletionPolicyHelpers::consumeWhenAny("internal-dpl-injected-dummy-sink", [](DeviceSpec const& s) { return s.name.find("internal-dpl-injected-dummy-sink") != std::string::npos; }),
33+
#endif
3034
CompletionPolicyHelpers::consumeWhenAll()};
3135
}
3236

Framework/Core/src/CompletionPolicyHelpers.cxx

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
#include "Framework/TimingInfo.h"
2020
#include "DecongestionService.h"
2121
#include "Framework/Signpost.h"
22+
#if __has_include(<fairmq/shmem/Message.h>)
23+
#include <fairmq/shmem/Message.h>
24+
#endif
2225

2326
#include <cassert>
2427
#include <regex>
@@ -249,6 +252,21 @@ CompletionPolicy CompletionPolicyHelpers::consumeExistingWhenAny(const char* nam
249252
}};
250253
}
251254

255+
#if __has_include(<fairmq/shmem/Message.h>)
256+
CompletionPolicy CompletionPolicyHelpers::consumeWhenAnyZeroCount(const char* name, CompletionPolicy::Matcher matcher)
257+
{
258+
auto callback = [](InputSpan const& inputs, std::vector<InputSpec> const&, ServiceRegistryRef& ref) -> CompletionPolicy::CompletionOp {
259+
for (size_t i = 0; i < inputs.size(); ++i) {
260+
if (inputs.get(i).header != nullptr && inputs.getRefCount(i) == 1) {
261+
return CompletionPolicy::CompletionOp::Consume;
262+
}
263+
}
264+
return CompletionPolicy::CompletionOp::Wait;
265+
};
266+
return CompletionPolicy{name, matcher, callback, false};
267+
}
268+
#endif
269+
252270
CompletionPolicy CompletionPolicyHelpers::consumeWhenAny(const char* name, CompletionPolicy::Matcher matcher)
253271
{
254272
auto callback = [](InputSpan const& inputs, std::vector<InputSpec> const&, ServiceRegistryRef& ref) -> CompletionPolicy::CompletionOp {

Framework/Core/src/DataProcessingDevice.cxx

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@
5757
#include <fairmq/Parts.h>
5858
#include <fairmq/Socket.h>
5959
#include <fairmq/ProgOptions.h>
60+
#if __has_include(<fairmq/shmem/Message.h>)
61+
#include <fairmq/shmem/Message.h>
62+
#endif
6063
#include <Configuration/ConfigurationInterface.h>
6164
#include <Configuration/ConfigurationFactory.h>
6265
#include <Monitoring/Monitoring.h>
@@ -1214,12 +1217,14 @@ void DataProcessingDevice::fillContext(DataProcessorContext& context, DeviceCont
12141217
if (forwarded.matcher.lifetime != Lifetime::Condition) {
12151218
onlyConditions = false;
12161219
}
1220+
#if !__has_include(<fairmq/shmem/Message.h>)
12171221
if (strncmp(DataSpecUtils::asConcreteOrigin(forwarded.matcher).str, "AOD", 3) == 0) {
12181222
context.canForwardEarly = false;
12191223
overriddenEarlyForward = true;
12201224
LOG(detail) << "Cannot forward early because of AOD input: " << DataSpecUtils::describe(forwarded.matcher);
12211225
break;
12221226
}
1227+
#endif
12231228
if (DataSpecUtils::partialMatch(forwarded.matcher, o2::header::DataDescription{"RAWDATA"}) && mProcessingPolicies.earlyForward == EarlyForwardPolicy::NORAW) {
12241229
context.canForwardEarly = false;
12251230
overriddenEarlyForward = true;
@@ -2230,7 +2235,15 @@ bool DataProcessingDevice::tryDispatchComputation(ServiceRegistryRef ref, std::v
22302235
auto nofPartsGetter = [&currentSetOfInputs](size_t i) -> size_t {
22312236
return currentSetOfInputs[i].getNumberOfPairs();
22322237
};
2233-
return InputSpan{getter, nofPartsGetter, currentSetOfInputs.size()};
2238+
#if __has_include(<fairmq/shmem/Message.h>)
2239+
auto refCountGetter = [&currentSetOfInputs](size_t idx) -> int {
2240+
auto& header = static_cast<const fair::mq::shmem::Message&>(*currentSetOfInputs[idx].header(0));
2241+
return header.GetRefCount();
2242+
};
2243+
#else
2244+
std::function<int(size_t)> refCountGetter = nullptr;
2245+
#endif
2246+
return InputSpan{getter, nofPartsGetter, refCountGetter, currentSetOfInputs.size()};
22342247
};
22352248

22362249
auto markInputsAsDone = [ref](TimesliceSlot slot) -> void {

Framework/Core/src/DataRelayer.cxx

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@
4444
#include <Monitoring/Monitoring.h>
4545

4646
#include <fairmq/Channel.h>
47+
#include <functional>
48+
#if __has_include(<fairmq/shmem/Message.h>)
49+
#include <fairmq/shmem/Message.h>
50+
#endif
4751
#include <fmt/format.h>
4852
#include <fmt/ostream.h>
4953
#include <gsl/span>
@@ -209,7 +213,15 @@ DataRelayer::ActivityStats DataRelayer::processDanglingInputs(std::vector<Expira
209213
auto nPartsGetter = [&partial](size_t idx) {
210214
return partial[idx].size();
211215
};
212-
InputSpan span{getter, nPartsGetter, static_cast<size_t>(partial.size())};
216+
#if __has_include(<fairmq/shmem/Message.h>)
217+
auto refCountGetter = [&partial](size_t idx) -> int {
218+
auto& header = static_cast<const fair::mq::shmem::Message&>(*partial[idx].header(0));
219+
return header.GetRefCount();
220+
};
221+
#else
222+
std::function<int(size_t)> refCountGetter = nullptr;
223+
#endif
224+
InputSpan span{getter, nPartsGetter, refCountGetter, static_cast<size_t>(partial.size())};
213225
// Setup the input span
214226

215227
if (expirator.checker(services, timestamp.value, span) == false) {
@@ -755,7 +767,15 @@ void DataRelayer::getReadyToProcess(std::vector<DataRelayer::RecordAction>& comp
755767
auto nPartsGetter = [&partial](size_t idx) {
756768
return partial[idx].size();
757769
};
758-
InputSpan span{getter, nPartsGetter, static_cast<size_t>(partial.size())};
770+
#if __has_include(<fairmq/shmem/Message.h>)
771+
auto refCountGetter = [&partial](size_t idx) -> int {
772+
auto& header = static_cast<const fair::mq::shmem::Message&>(*partial[idx].header(0));
773+
return header.GetRefCount();
774+
};
775+
#else
776+
std::function<int(size_t)> refCountGetter = nullptr;
777+
#endif
778+
InputSpan span{getter, nPartsGetter, refCountGetter, static_cast<size_t>(partial.size())};
759779
CompletionPolicy::CompletionOp action = mCompletionPolicy.callbackFull(span, mInputs, mContext);
760780

761781
auto& variables = mTimesliceIndex.getVariablesForSlot(slot);

Framework/Core/src/InputSpan.cxx

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,11 @@ InputSpan::InputSpan(std::function<DataRef(size_t, size_t)> getter, size_t size)
2929
{
3030
}
3131

32-
InputSpan::InputSpan(std::function<DataRef(size_t, size_t)> getter, std::function<size_t(size_t)> nofPartsGetter, size_t size)
33-
: mGetter{getter}, mNofPartsGetter{nofPartsGetter}, mSize{size}
32+
InputSpan::InputSpan(std::function<DataRef(size_t, size_t)> getter,
33+
std::function<size_t(size_t)> nofPartsGetter,
34+
std::function<int(size_t)> refCountGetter,
35+
size_t size)
36+
: mGetter{getter}, mNofPartsGetter{nofPartsGetter}, mRefCountGetter(refCountGetter), mSize{size}
3437
{
3538
}
3639

Framework/Core/test/test_InputRecordWalker.cxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ struct DataSet {
4242
auto payload = static_cast<char const*>(this->messages[i].second.at(2 * part + 1)->data());
4343
return DataRef{nullptr, header, payload};
4444
},
45-
[this](size_t i) { return i < this->messages.size() ? messages[i].second.size() / 2 : 0; }, this->messages.size()},
45+
[this](size_t i) { return i < this->messages.size() ? messages[i].second.size() / 2 : 0; }, nullptr, this->messages.size()},
4646
record{schema, span, registry},
4747
values{std::move(v)}
4848
{

Framework/Core/test/test_InputSpan.cxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ TEST_CASE("TestInputSpan")
3737
return inputs[i].size() / 2;
3838
};
3939

40-
InputSpan span{getter, nPartsGetter, inputs.size()};
40+
InputSpan span{getter, nPartsGetter, nullptr, inputs.size()};
4141
REQUIRE(span.size() == inputs.size());
4242
routeNo = 0;
4343
for (; routeNo < span.size(); ++routeNo) {

Framework/Utils/test/RawPageTestData.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ struct DataSet {
4747
auto payload = static_cast<char const*>(this->messages[i].at(2 * part + 1)->data());
4848
return DataRef{nullptr, header, payload};
4949
},
50-
[this](size_t i) { return i < this->messages.size() ? messages[i].size() / 2 : 0; }, this->messages.size()},
50+
[this](size_t i) { return i < this->messages.size() ? messages[i].size() / 2 : 0; },
51+
nullptr,
52+
this->messages.size()},
5153
record{schema, span, registry},
5254
values{std::move(v)}
5355
{
@@ -63,5 +65,5 @@ struct DataSet {
6365
using AmendRawDataHeader = std::function<void(RAWDataHeader&)>;
6466
DataSet createData(std::vector<InputSpec> const& inputspecs, std::vector<DataHeader> const& dataheaders, AmendRawDataHeader amendRdh = nullptr);
6567

66-
} // namespace o2::framework
68+
} // namespace o2::framework::test
6769
#endif // FRAMEWORK_UTILS_RAWPAGETESTDATA_H

0 commit comments

Comments
 (0)