Skip to content

Commit 65275d9

Browse files
authored
DPL Analysis: use offset cache for sorted grouping (#14571)
1 parent 3cafea0 commit 65275d9

File tree

4 files changed

+55
-40
lines changed

4 files changed

+55
-40
lines changed

Framework/Core/include/Framework/ArrowTableSlicingCache.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ namespace o2::framework
2121
using ListVector = std::vector<std::vector<int64_t>>;
2222

2323
struct SliceInfoPtr {
24-
gsl::span<int const> values;
25-
gsl::span<int64_t const> counts;
24+
gsl::span<int64_t const> offsets;
25+
gsl::span<int64_t const> sizes;
2626

2727
std::pair<int64_t, int64_t> getSliceFor(int value) const;
2828
};
@@ -66,6 +66,8 @@ struct ArrowTableSlicingCache {
6666
Cache bindingsKeys;
6767
std::vector<std::shared_ptr<arrow::NumericArray<arrow::Int32Type>>> values;
6868
std::vector<std::shared_ptr<arrow::NumericArray<arrow::Int64Type>>> counts;
69+
std::vector<std::vector<int64_t>> offsets;
70+
std::vector<std::vector<int64_t>> sizes;
6971

7072
Cache bindingsKeysUnsorted;
7173
std::vector<std::vector<int>> valuesUnsorted;

Framework/Core/include/Framework/GroupSlicer.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,9 +246,7 @@ struct GroupSlicer {
246246
pos = position;
247247
}
248248
// optimized split
249-
auto oc = sliceInfos[index].getSliceFor(pos);
250-
uint64_t offset = oc.first;
251-
auto count = oc.second;
249+
auto [offset, count] = sliceInfos[index].getSliceFor(pos);
252250
auto groupedElementsTable = originalTable.rawSlice(offset, offset + count - 1);
253251
groupedElementsTable.bindInternalIndicesTo(&originalTable);
254252
return groupedElementsTable;

Framework/Core/src/ArrowTableSlicingCache.cxx

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -32,28 +32,14 @@ void updatePairList(Cache& list, std::string const& binding, std::string const&
3232
std::pair<int64_t, int64_t> SliceInfoPtr::getSliceFor(int value) const
3333
{
3434
int64_t offset = 0;
35-
if (values.empty()) {
35+
if (offsets.empty()) {
3636
return {offset, 0};
3737
}
38-
int64_t p = static_cast<int64_t>(values.size()) - 1;
39-
while (values[p] < 0) {
40-
--p;
41-
if (p < 0) {
42-
return {offset, 0};
43-
}
44-
}
45-
46-
if (value > values[p]) {
38+
if ((size_t)value >= offsets.size()) {
4739
return {offset, 0};
4840
}
4941

50-
for (auto i = 0U; i < values.size(); ++i) {
51-
if (values[i] == value) {
52-
return {offset, counts[i]};
53-
}
54-
offset += counts[i];
55-
}
56-
return {offset, 0};
42+
return {offsets[value], sizes[value]};
5743
}
5844

5945
gsl::span<const int64_t> SliceInfoUnsortedPtr::getSliceFor(int value) const
@@ -84,6 +70,8 @@ ArrowTableSlicingCache::ArrowTableSlicingCache(Cache&& bsks, Cache&& bsksUnsorte
8470
{
8571
values.resize(bindingsKeys.size());
8672
counts.resize(bindingsKeys.size());
73+
offsets.resize(bindingsKeys.size());
74+
sizes.resize(bindingsKeys.size());
8775

8876
valuesUnsorted.resize(bindingsKeysUnsorted.size());
8977
groups.resize(bindingsKeysUnsorted.size());
@@ -97,6 +85,10 @@ void ArrowTableSlicingCache::setCaches(Cache&& bsks, Cache&& bsksUnsorted)
9785
values.resize(bindingsKeys.size());
9886
counts.clear();
9987
counts.resize(bindingsKeys.size());
88+
offsets.clear();
89+
offsets.resize(bindingsKeys.size());
90+
sizes.clear();
91+
sizes.resize(bindingsKeys.size());
10092
valuesUnsorted.clear();
10193
valuesUnsorted.resize(bindingsKeysUnsorted.size());
10294
groups.clear();
@@ -105,9 +97,11 @@ void ArrowTableSlicingCache::setCaches(Cache&& bsks, Cache&& bsksUnsorted)
10597

10698
arrow::Status ArrowTableSlicingCache::updateCacheEntry(int pos, std::shared_ptr<arrow::Table> const& table)
10799
{
100+
values[pos].reset();
101+
counts[pos].reset();
102+
offsets[pos].clear();
103+
sizes[pos].clear();
108104
if (table->num_rows() == 0) {
109-
values[pos].reset();
110-
counts[pos].reset();
111105
return arrow::Status::OK();
112106
}
113107
auto& [b, k, e] = bindingsKeys[pos];
@@ -125,6 +119,31 @@ arrow::Status ArrowTableSlicingCache::updateCacheEntry(int pos, std::shared_ptr<
125119
counts[pos].reset();
126120
values[pos] = std::make_shared<arrow::NumericArray<arrow::Int32Type>>(pair.field(0)->data());
127121
counts[pos] = std::make_shared<arrow::NumericArray<arrow::Int64Type>>(pair.field(1)->data());
122+
123+
int maxValue = -1;
124+
for (auto i = values[pos]->length() - 1; i >= 0; --i) {
125+
if (values[pos]->Value(i) < 0) {
126+
continue;
127+
} else {
128+
maxValue = values[pos]->Value(i);
129+
break;
130+
}
131+
}
132+
133+
offsets[pos].resize(maxValue + 1);
134+
sizes[pos].resize(maxValue + 1);
135+
std::fill(offsets[pos].begin(), offsets[pos].end(), 0);
136+
std::fill(sizes[pos].begin(), sizes[pos].end(), 0);
137+
int64_t offset = 0;
138+
for (auto i = 0U; i < values[pos]->length(); ++i) {
139+
auto value = values[pos]->Value(i);
140+
auto count = counts[pos]->Value(i);
141+
if (value >= 0) {
142+
offsets[pos][value] = offset;
143+
sizes[pos][value] = count;
144+
}
145+
offset += count;
146+
}
128147
return arrow::Status::OK();
129148
}
130149

@@ -221,14 +240,14 @@ SliceInfoPtr ArrowTableSlicingCache::getCacheForPos(int pos) const
221240
{
222241
if (values[pos] == nullptr && counts[pos] == nullptr) {
223242
return {
224-
{},
225-
{} //
243+
{}, //
244+
{} //
226245
};
227246
}
228247

229248
return {
230-
{reinterpret_cast<int const*>(values[pos]->values()->data()), static_cast<size_t>(values[pos]->length())},
231-
{reinterpret_cast<int64_t const*>(counts[pos]->values()->data()), static_cast<size_t>(counts[pos]->length())} //
249+
gsl::span{offsets[pos].data(), offsets[pos].size()}, //
250+
gsl::span(sizes[pos].data(), sizes[pos].size()) //
232251
};
233252
}
234253

Framework/Core/test/test_GroupSlicer.cxx

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,8 @@ TEST_CASE("GroupSlicerMismatchedGroups")
245245
if (i == 3 || i == 10 || i == 12 || i == 16 || i == 19) {
246246
continue;
247247
}
248-
for (auto j = 0.f; j < 5; j += 0.5f) {
249-
trksWriter(0, i, 0.5f * j);
248+
for (auto j = 0; j < 10; ++j) {
249+
trksWriter(0, i, 0.5f * (j / 2.));
250250
}
251251
}
252252
auto trkTable = builderT.finalize();
@@ -260,21 +260,19 @@ TEST_CASE("GroupSlicerMismatchedGroups")
260260
auto s = slices.updateCacheEntry(0, trkTable);
261261
o2::framework::GroupSlicer g(e, tt, slices);
262262

263-
auto count = 0;
264263
for (auto& slice : g) {
265264
auto as = slice.associatedTables();
266265
auto gg = slice.groupingElement();
267-
REQUIRE(gg.globalIndex() == count);
266+
REQUIRE(gg.globalIndex() == (int64_t)slice.position);
268267
auto trks = std::get<aod::TrksX>(as);
269-
if (count == 3 || count == 10 || count == 12 || count == 16 || count == 19) {
268+
if (slice.position == 3 || slice.position == 10 || slice.position == 12 || slice.position == 16 || slice.position == 19) {
270269
REQUIRE(trks.size() == 0);
271270
} else {
272271
REQUIRE(trks.size() == 10);
273272
}
274273
for (auto& trk : trks) {
275-
REQUIRE(trk.eventId() == count);
274+
REQUIRE(trk.eventId() == (int64_t)slice.position);
276275
}
277-
++count;
278276
}
279277
}
280278

@@ -299,8 +297,8 @@ TEST_CASE("GroupSlicerMismatchedUnassignedGroups")
299297
++skip;
300298
continue;
301299
}
302-
for (auto j = 0.f; j < 5; j += 0.5f) {
303-
trksWriter(0, i, 0.5f * j);
300+
for (auto j = 0; j < 10; ++j) {
301+
trksWriter(0, i, 0.5f * (j / 2.));
304302
}
305303
}
306304
for (auto i = 0; i < 5; ++i) {
@@ -510,7 +508,7 @@ TEST_CASE("GroupSlicerMismatchedUnsortedFilteredGroupsWithSelfIndex")
510508
{
511509
TableBuilder builderE;
512510
auto evtsWriter = builderE.cursor<aod::Events>();
513-
for (auto i = 0; i < 20; ++i) {
511+
for (auto i = 0; i < 10; ++i) {
514512
evtsWriter(0, i, 0.5f * i, 2.f * i, 3.f * i);
515513
}
516514
auto evtTable = builderE.finalize();
@@ -523,7 +521,6 @@ TEST_CASE("GroupSlicerMismatchedUnsortedFilteredGroupsWithSelfIndex")
523521
std::uniform_int_distribution<> distrib(0, 99);
524522

525523
for (auto i = 0; i < 100; ++i) {
526-
527524
filler[0] = distrib(gen);
528525
filler[1] = distrib(gen);
529526
if (filler[0] > filler[1]) {
@@ -541,7 +538,6 @@ TEST_CASE("GroupSlicerMismatchedUnsortedFilteredGroupsWithSelfIndex")
541538
auto thingsTable = builderT.finalize();
542539

543540
aod::Events e{evtTable};
544-
// aod::Parts p{partsTable};
545541
aod::Things t{thingsTable};
546542
using FilteredParts = soa::Filtered<aod::Parts>;
547543
auto size = distrib(gen);

0 commit comments

Comments
 (0)