Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Framework/Core/include/Framework/CompletionPolicyHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ struct CompletionPolicyHelpers {

/// When any of the parts of the record have been received, consume them.
static CompletionPolicy consumeWhenAny(const char* name, CompletionPolicy::Matcher matcher);

#if __has_include(<fairmq/shmem/Message.h>)
/// When any of the parts which has arrived has a refcount of 1.
static CompletionPolicy consumeWhenAnyZeroCount(const char* name, CompletionPolicy::Matcher matcher);
#endif
/// Default matcher applies for all devices
static CompletionPolicy consumeWhenAny(CompletionPolicy::Matcher matcher = [](auto const&) -> bool { return true; })
{
Expand Down
15 changes: 14 additions & 1 deletion Framework/Core/include/Framework/InputSpan.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class InputSpan
/// index and the buffer associated.
/// @nofPartsGetter is the getter for the number of parts associated with an index
/// @a size is the number of elements in the span.
InputSpan(std::function<DataRef(size_t, size_t)> getter, std::function<size_t(size_t)> nofPartsGetter, size_t size);
InputSpan(std::function<DataRef(size_t, size_t)> getter, std::function<size_t(size_t)> nofPartsGetter, std::function<int(size_t)> refCountGetter, size_t size);

/// @a i-th element of the InputSpan
[[nodiscard]] DataRef get(size_t i, size_t partidx = 0) const
Expand All @@ -66,6 +66,18 @@ class InputSpan
return mNofPartsGetter(i);
}

// Get the refcount for a given part
[[nodiscard]] int getRefCount(size_t i) const
{
if (i >= mSize) {
return 0;
}
if (!mRefCountGetter) {
return -1;
}
return mRefCountGetter(i);
}

/// Number of elements in the InputSpan
[[nodiscard]] size_t size() const
{
Expand Down Expand Up @@ -236,6 +248,7 @@ class InputSpan
private:
std::function<DataRef(size_t, size_t)> mGetter;
std::function<size_t(size_t)> mNofPartsGetter;
std::function<int(size_t)> mRefCountGetter;
size_t mSize;
};

Expand Down
4 changes: 4 additions & 0 deletions Framework/Core/src/CompletionPolicy.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ std::vector<CompletionPolicy>
{
return {
CompletionPolicyHelpers::consumeWhenAllOrdered("internal-dpl-aod-writer"),
#if __has_include(<fairmq/shmem/Message.h>)
CompletionPolicyHelpers::consumeWhenAnyZeroCount("internal-dpl-injected-dummy-sink", [](DeviceSpec const& s) { return s.name.find("internal-dpl-injected-dummy-sink") != std::string::npos; }),
#else
CompletionPolicyHelpers::consumeWhenAny("internal-dpl-injected-dummy-sink", [](DeviceSpec const& s) { return s.name.find("internal-dpl-injected-dummy-sink") != std::string::npos; }),
#endif
CompletionPolicyHelpers::consumeWhenAll()};
}

Expand Down
18 changes: 18 additions & 0 deletions Framework/Core/src/CompletionPolicyHelpers.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
#include "Framework/TimingInfo.h"
#include "DecongestionService.h"
#include "Framework/Signpost.h"
#if __has_include(<fairmq/shmem/Message.h>)
#include <fairmq/shmem/Message.h>
#endif

#include <cassert>
#include <regex>
Expand Down Expand Up @@ -249,6 +252,21 @@ CompletionPolicy CompletionPolicyHelpers::consumeExistingWhenAny(const char* nam
}};
}

#if __has_include(<fairmq/shmem/Message.h>)
CompletionPolicy CompletionPolicyHelpers::consumeWhenAnyZeroCount(const char* name, CompletionPolicy::Matcher matcher)
{
auto callback = [](InputSpan const& inputs, std::vector<InputSpec> const&, ServiceRegistryRef& ref) -> CompletionPolicy::CompletionOp {
for (size_t i = 0; i < inputs.size(); ++i) {
if (inputs.get(i).header != nullptr && inputs.getRefCount(i) == 1) {
return CompletionPolicy::CompletionOp::Consume;
}
}
return CompletionPolicy::CompletionOp::Wait;
};
return CompletionPolicy{name, matcher, callback, false};
}
#endif

CompletionPolicy CompletionPolicyHelpers::consumeWhenAny(const char* name, CompletionPolicy::Matcher matcher)
{
auto callback = [](InputSpan const& inputs, std::vector<InputSpec> const&, ServiceRegistryRef& ref) -> CompletionPolicy::CompletionOp {
Expand Down
15 changes: 14 additions & 1 deletion Framework/Core/src/DataProcessingDevice.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@
#include <fairmq/Parts.h>
#include <fairmq/Socket.h>
#include <fairmq/ProgOptions.h>
#if __has_include(<fairmq/shmem/Message.h>)
#include <fairmq/shmem/Message.h>
#endif
#include <Configuration/ConfigurationInterface.h>
#include <Configuration/ConfigurationFactory.h>
#include <Monitoring/Monitoring.h>
Expand Down Expand Up @@ -1214,12 +1217,14 @@ void DataProcessingDevice::fillContext(DataProcessorContext& context, DeviceCont
if (forwarded.matcher.lifetime != Lifetime::Condition) {
onlyConditions = false;
}
#if !__has_include(<fairmq/shmem/Message.h>)
if (strncmp(DataSpecUtils::asConcreteOrigin(forwarded.matcher).str, "AOD", 3) == 0) {
context.canForwardEarly = false;
overriddenEarlyForward = true;
LOG(detail) << "Cannot forward early because of AOD input: " << DataSpecUtils::describe(forwarded.matcher);
break;
}
#endif
if (DataSpecUtils::partialMatch(forwarded.matcher, o2::header::DataDescription{"RAWDATA"}) && mProcessingPolicies.earlyForward == EarlyForwardPolicy::NORAW) {
context.canForwardEarly = false;
overriddenEarlyForward = true;
Expand Down Expand Up @@ -2230,7 +2235,15 @@ bool DataProcessingDevice::tryDispatchComputation(ServiceRegistryRef ref, std::v
auto nofPartsGetter = [&currentSetOfInputs](size_t i) -> size_t {
return currentSetOfInputs[i].getNumberOfPairs();
};
return InputSpan{getter, nofPartsGetter, currentSetOfInputs.size()};
#if __has_include(<fairmq/shmem/Message.h>)
auto refCountGetter = [&currentSetOfInputs](size_t idx) -> int {
auto& header = static_cast<const fair::mq::shmem::Message&>(*currentSetOfInputs[idx].header(0));
return header.GetRefCount();
};
#else
std::function<int(size_t)> refCountGetter = nullptr;
#endif
return InputSpan{getter, nofPartsGetter, refCountGetter, currentSetOfInputs.size()};
};

auto markInputsAsDone = [ref](TimesliceSlot slot) -> void {
Expand Down
24 changes: 22 additions & 2 deletions Framework/Core/src/DataRelayer.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@
#include <Monitoring/Monitoring.h>

#include <fairmq/Channel.h>
#include <functional>
#if __has_include(<fairmq/shmem/Message.h>)
#include <fairmq/shmem/Message.h>
#endif
#include <fmt/format.h>
#include <fmt/ostream.h>
#include <gsl/span>
Expand Down Expand Up @@ -209,7 +213,15 @@ DataRelayer::ActivityStats DataRelayer::processDanglingInputs(std::vector<Expira
auto nPartsGetter = [&partial](size_t idx) {
return partial[idx].size();
};
InputSpan span{getter, nPartsGetter, static_cast<size_t>(partial.size())};
#if __has_include(<fairmq/shmem/Message.h>)
auto refCountGetter = [&partial](size_t idx) -> int {
auto& header = static_cast<const fair::mq::shmem::Message&>(*partial[idx].header(0));
return header.GetRefCount();
};
#else
std::function<int(size_t)> refCountGetter = nullptr;
#endif
InputSpan span{getter, nPartsGetter, refCountGetter, static_cast<size_t>(partial.size())};
// Setup the input span

if (expirator.checker(services, timestamp.value, span) == false) {
Expand Down Expand Up @@ -755,7 +767,15 @@ void DataRelayer::getReadyToProcess(std::vector<DataRelayer::RecordAction>& comp
auto nPartsGetter = [&partial](size_t idx) {
return partial[idx].size();
};
InputSpan span{getter, nPartsGetter, static_cast<size_t>(partial.size())};
#if __has_include(<fairmq/shmem/Message.h>)
auto refCountGetter = [&partial](size_t idx) -> int {
auto& header = static_cast<const fair::mq::shmem::Message&>(*partial[idx].header(0));
return header.GetRefCount();
};
#else
std::function<int(size_t)> refCountGetter = nullptr;
#endif
InputSpan span{getter, nPartsGetter, refCountGetter, static_cast<size_t>(partial.size())};
CompletionPolicy::CompletionOp action = mCompletionPolicy.callbackFull(span, mInputs, mContext);

auto& variables = mTimesliceIndex.getVariablesForSlot(slot);
Expand Down
7 changes: 5 additions & 2 deletions Framework/Core/src/InputSpan.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@ InputSpan::InputSpan(std::function<DataRef(size_t, size_t)> getter, size_t size)
{
}

InputSpan::InputSpan(std::function<DataRef(size_t, size_t)> getter, std::function<size_t(size_t)> nofPartsGetter, size_t size)
: mGetter{getter}, mNofPartsGetter{nofPartsGetter}, mSize{size}
InputSpan::InputSpan(std::function<DataRef(size_t, size_t)> getter,
std::function<size_t(size_t)> nofPartsGetter,
std::function<int(size_t)> refCountGetter,
size_t size)
: mGetter{getter}, mNofPartsGetter{nofPartsGetter}, mRefCountGetter(refCountGetter), mSize{size}
{
}

Expand Down
2 changes: 1 addition & 1 deletion Framework/Core/test/test_InputRecordWalker.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct DataSet {
auto payload = static_cast<char const*>(this->messages[i].second.at(2 * part + 1)->data());
return DataRef{nullptr, header, payload};
},
[this](size_t i) { return i < this->messages.size() ? messages[i].second.size() / 2 : 0; }, this->messages.size()},
[this](size_t i) { return i < this->messages.size() ? messages[i].second.size() / 2 : 0; }, nullptr, this->messages.size()},
record{schema, span, registry},
values{std::move(v)}
{
Expand Down
2 changes: 1 addition & 1 deletion Framework/Core/test/test_InputSpan.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ TEST_CASE("TestInputSpan")
return inputs[i].size() / 2;
};

InputSpan span{getter, nPartsGetter, inputs.size()};
InputSpan span{getter, nPartsGetter, nullptr, inputs.size()};
REQUIRE(span.size() == inputs.size());
routeNo = 0;
for (; routeNo < span.size(); ++routeNo) {
Expand Down
6 changes: 4 additions & 2 deletions Framework/Utils/test/RawPageTestData.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ struct DataSet {
auto payload = static_cast<char const*>(this->messages[i].at(2 * part + 1)->data());
return DataRef{nullptr, header, payload};
},
[this](size_t i) { return i < this->messages.size() ? messages[i].size() / 2 : 0; }, this->messages.size()},
[this](size_t i) { return i < this->messages.size() ? messages[i].size() / 2 : 0; },
nullptr,
this->messages.size()},
record{schema, span, registry},
values{std::move(v)}
{
Expand All @@ -63,5 +65,5 @@ struct DataSet {
using AmendRawDataHeader = std::function<void(RAWDataHeader&)>;
DataSet createData(std::vector<InputSpec> const& inputspecs, std::vector<DataHeader> const& dataheaders, AmendRawDataHeader amendRdh = nullptr);

} // namespace o2::framework
} // namespace o2::framework::test
#endif // FRAMEWORK_UTILS_RAWPAGETESTDATA_H