Skip to content

Commit e24ee88

Browse files
authored
DPL: implement distinction between physical and dataset schema (#13917)
This will come in handy for doing zero-copy reads, actually.
1 parent b00bfe5 commit e24ee88

File tree

2 files changed

+164
-93
lines changed

2 files changed

+164
-93
lines changed

Framework/AnalysisSupport/src/TTreePlugin.cxx

Lines changed: 119 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include "Framework/Signpost.h"
1515
#include "Framework/Endian.h"
1616
#include <arrow/dataset/file_base.h>
17+
#include <arrow/extension_type.h>
18+
#include <arrow/type.h>
1719
#include <arrow/util/key_value_metadata.h>
1820
#include <arrow/array/array_nested.h>
1921
#include <arrow/array/array_primitive.h>
@@ -23,6 +25,8 @@
2325
#include <TBranch.h>
2426
#include <TFile.h>
2527
#include <TLeaf.h>
28+
#include <memory>
29+
#include <iostream>
2630

2731
O2_DECLARE_DYNAMIC_LOG(root_arrow_fs);
2832

@@ -91,6 +95,7 @@ arrow::Result<arrow::fs::FileInfo> SingleTreeFileSystem::GetFileInfo(std::string
9195
return result;
9296
}
9397

98+
// A fragment which holds a tree
9499
class TTreeFileFragment : public arrow::dataset::FileFragment
95100
{
96101
public:
@@ -101,6 +106,13 @@ class TTreeFileFragment : public arrow::dataset::FileFragment
101106
: FileFragment(std::move(source), std::move(format), std::move(partition_expression), std::move(physical_schema))
102107
{
103108
}
109+
110+
std::unique_ptr<TTree>& GetTree()
111+
{
112+
auto topFs = std::dynamic_pointer_cast<VirtualRootFileSystemBase>(source().filesystem());
113+
auto treeFs = std::dynamic_pointer_cast<TTreeFileSystem>(topFs->GetSubFilesystem(source()));
114+
return treeFs->GetTree(source());
115+
}
104116
};
105117

106118
class TTreeFileFormat : public arrow::dataset::FileFormat
@@ -158,9 +170,9 @@ class TTreeFileFormat : public arrow::dataset::FileFormat
158170
class TTreeOutputStream : public arrow::io::OutputStream
159171
{
160172
public:
161-
// Using a pointer means that the tree itself is owned by another
173+
// Using a pointer means that the tree itself is owned by another
162174
// class
163-
TTreeOutputStream(TTree *, std::string branchPrefix);
175+
TTreeOutputStream(TTree*, std::string branchPrefix);
164176

165177
arrow::Status Close() override;
166178

@@ -245,33 +257,70 @@ struct TTreeObjectReadingImplementation : public RootArrowFactoryPlugin {
245257
}
246258
};
247259

260+
struct BranchFieldMapping {
261+
int mainBranchIdx;
262+
int vlaIdx;
263+
int datasetFieldIdx;
264+
};
265+
248266
arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
249267
const std::shared_ptr<arrow::dataset::ScanOptions>& options,
250268
const std::shared_ptr<arrow::dataset::FileFragment>& fragment) const
251269
{
252-
// Get the fragment as a TTreeFragment. This might be PART of a TTree.
253-
auto treeFragment = std::dynamic_pointer_cast<TTreeFileFragment>(fragment);
254270
// This is the schema we want to read
255271
auto dataset_schema = options->dataset_schema;
256272

257-
auto generator = [pool = options->pool, treeFragment, dataset_schema, &totalCompressedSize = mTotCompressedSize,
273+
auto generator = [pool = options->pool, fragment, dataset_schema, &totalCompressedSize = mTotCompressedSize,
258274
&totalUncompressedSize = mTotUncompressedSize]() -> arrow::Future<std::shared_ptr<arrow::RecordBatch>> {
259-
auto schema = treeFragment->format()->Inspect(treeFragment->source());
260-
261275
std::vector<std::shared_ptr<arrow::Array>> columns;
262276
std::vector<std::shared_ptr<arrow::Field>> fields = dataset_schema->fields();
263-
auto physical_schema = *treeFragment->ReadPhysicalSchema();
277+
auto physical_schema = *fragment->ReadPhysicalSchema();
278+
279+
auto fs = std::dynamic_pointer_cast<VirtualRootFileSystemBase>(fragment->source().filesystem());
280+
// Actually get the TTree from the ROOT file.
281+
auto treeFs = std::dynamic_pointer_cast<TTreeFileSystem>(fs->GetSubFilesystem(fragment->source()));
282+
283+
if (dataset_schema->num_fields() > physical_schema->num_fields()) {
284+
throw runtime_error_f("One TTree must have all the fields requested in a table");
285+
}
286+
287+
// Register physical fields into the cache
288+
std::vector<BranchFieldMapping> mappings;
289+
290+
for (int fi = 0; fi < dataset_schema->num_fields(); ++fi) {
291+
auto dataset_field = dataset_schema->field(fi);
292+
int physicalFieldIdx = physical_schema->GetFieldIndex(dataset_field->name());
293+
294+
if (physicalFieldIdx < 0) {
295+
throw runtime_error_f("Cannot find physical field associated to %s", dataset_field->name().c_str());
296+
}
297+
if (physicalFieldIdx > 1 && physical_schema->field(physicalFieldIdx - 1)->name().ends_with("_size")) {
298+
mappings.push_back({physicalFieldIdx, physicalFieldIdx - 1, fi});
299+
} else {
300+
mappings.push_back({physicalFieldIdx, -1, fi});
301+
}
302+
}
303+
304+
auto& tree = treeFs->GetTree(fragment->source());
305+
tree->SetCacheSize(25000000);
306+
auto branches = tree->GetListOfBranches();
307+
for (auto& mapping : mappings) {
308+
tree->AddBranchToCache((TBranch*)branches->At(mapping.mainBranchIdx), false);
309+
if (mapping.vlaIdx != -1) {
310+
tree->AddBranchToCache((TBranch*)branches->At(mapping.vlaIdx), false);
311+
}
312+
}
313+
tree->StopCacheLearningPhase();
264314

265315
static TBufferFile buffer{TBuffer::EMode::kWrite, 4 * 1024 * 1024};
266-
auto containerFS = std::dynamic_pointer_cast<VirtualRootFileSystemBase>(treeFragment->source().filesystem());
267-
auto fs = std::dynamic_pointer_cast<TTreeFileSystem>(containerFS->GetSubFilesystem(treeFragment->source()));
268316

269317
int64_t rows = -1;
270-
auto& tree = fs->GetTree(treeFragment->source());
271-
for (auto& field : fields) {
318+
for (size_t mi = 0; mi < mappings.size(); ++mi) {
319+
BranchFieldMapping mapping = mappings[mi];
272320
// The field actually on disk
273-
auto physicalField = physical_schema->GetFieldByName(field->name());
274-
TBranch* branch = tree->GetBranch(physicalField->name().c_str());
321+
auto datasetField = dataset_schema->field(mapping.datasetFieldIdx);
322+
auto physicalField = physical_schema->field(mapping.mainBranchIdx);
323+
auto* branch = (TBranch*)branches->At(mapping.mainBranchIdx);
275324
assert(branch);
276325
buffer.Reset();
277326
auto totalEntries = branch->GetEntries();
@@ -284,12 +333,12 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
284333
arrow::Status status;
285334
int readEntries = 0;
286335
std::shared_ptr<arrow::Array> array;
287-
auto listType = std::dynamic_pointer_cast<arrow::FixedSizeListType>(physicalField->type());
288-
if (physicalField->type() == arrow::boolean() ||
289-
(listType && physicalField->type()->field(0)->type() == arrow::boolean())) {
336+
auto listType = std::dynamic_pointer_cast<arrow::FixedSizeListType>(datasetField->type());
337+
if (datasetField->type() == arrow::boolean() ||
338+
(listType && datasetField->type()->field(0)->type() == arrow::boolean())) {
290339
if (listType) {
291340
std::unique_ptr<arrow::ArrayBuilder> builder = nullptr;
292-
auto status = arrow::MakeBuilder(pool, physicalField->type()->field(0)->type(), &builder);
341+
auto status = arrow::MakeBuilder(pool, datasetField->type()->field(0)->type(), &builder);
293342
if (!status.ok()) {
294343
throw runtime_error("Cannot create value builder");
295344
}
@@ -316,7 +365,7 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
316365
}
317366
} else if (listType == nullptr) {
318367
std::unique_ptr<arrow::ArrayBuilder> builder = nullptr;
319-
auto status = arrow::MakeBuilder(pool, physicalField->type(), &builder);
368+
auto status = arrow::MakeBuilder(pool, datasetField->type(), &builder);
320369
if (!status.ok()) {
321370
throw runtime_error("Cannot create builder");
322371
}
@@ -340,16 +389,14 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
340389
}
341390
}
342391
} else {
343-
// other types: use serialized read to build arrays directly.
344-
auto typeSize = physicalField->type()->byte_width();
345392
// This is needed for branches which have not been persisted.
346393
auto bytes = branch->GetTotBytes();
347394
auto branchSize = bytes ? bytes : 1000000;
348395
auto&& result = arrow::AllocateResizableBuffer(branchSize, pool);
349396
if (!result.ok()) {
350397
throw runtime_error("Cannot allocate values buffer");
351398
}
352-
std::shared_ptr<arrow::Buffer> arrowValuesBuffer = std::move(result).ValueUnsafe();
399+
std::shared_ptr<arrow::Buffer> arrowValuesBuffer = result.MoveValueUnsafe();
353400
auto ptr = arrowValuesBuffer->mutable_data();
354401
if (ptr == nullptr) {
355402
throw runtime_error("Invalid buffer");
@@ -363,23 +410,14 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
363410
std::span<int> offsets;
364411
int size = 0;
365412
uint32_t totalSize = 0;
366-
TBranch* mSizeBranch = nullptr;
367-
int64_t listSize = 1;
368-
if (auto fixedSizeList = std::dynamic_pointer_cast<arrow::FixedSizeListType>(physicalField->type())) {
369-
listSize = fixedSizeList->list_size();
370-
typeSize = fixedSizeList->field(0)->type()->byte_width();
371-
} else if (auto vlaListType = std::dynamic_pointer_cast<arrow::ListType>(physicalField->type())) {
372-
listSize = -1;
373-
typeSize = vlaListType->field(0)->type()->byte_width();
374-
}
375-
if (listSize == -1) {
376-
mSizeBranch = branch->GetTree()->GetBranch((std::string{branch->GetName()} + "_size").c_str());
413+
if (mapping.vlaIdx != -1) {
414+
auto* mSizeBranch = (TBranch*)branches->At(mapping.vlaIdx);
377415
offsetBuffer = std::make_unique<TBufferFile>(TBuffer::EMode::kWrite, 4 * 1024 * 1024);
378416
result = arrow::AllocateResizableBuffer((totalEntries + 1) * (int64_t)sizeof(int), pool);
379417
if (!result.ok()) {
380418
throw runtime_error("Cannot allocate offset buffer");
381419
}
382-
arrowOffsetBuffer = std::move(result).ValueUnsafe();
420+
arrowOffsetBuffer = result.MoveValueUnsafe();
383421
unsigned char* ptrOffset = arrowOffsetBuffer->mutable_data();
384422
auto* tPtrOffset = reinterpret_cast<int*>(ptrOffset);
385423
offsets = std::span<int>{tPtrOffset, tPtrOffset + totalEntries + 1};
@@ -398,9 +436,19 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
398436
readEntries = 0;
399437
}
400438

439+
int typeSize = physicalField->type()->byte_width();
440+
int64_t listSize = 1;
441+
if (auto fixedSizeList = std::dynamic_pointer_cast<arrow::FixedSizeListType>(datasetField->type())) {
442+
listSize = fixedSizeList->list_size();
443+
typeSize = physicalField->type()->field(0)->type()->byte_width();
444+
} else if (mapping.vlaIdx != -1) {
445+
typeSize = physicalField->type()->field(0)->type()->byte_width();
446+
listSize = -1;
447+
}
448+
401449
while (readEntries < totalEntries) {
402450
auto readLast = branch->GetBulkRead().GetEntriesSerialized(readEntries, buffer);
403-
if (listSize == -1) {
451+
if (mapping.vlaIdx != -1) {
404452
size = offsets[readEntries + readLast] - offsets[readEntries];
405453
} else {
406454
size = readLast * listSize;
@@ -412,18 +460,15 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
412460
if (listSize >= 1) {
413461
totalSize = readEntries * listSize;
414462
}
415-
std::shared_ptr<arrow::PrimitiveArray> varray;
416-
switch (listSize) {
417-
case -1:
418-
varray = std::make_shared<arrow::PrimitiveArray>(physicalField->type()->field(0)->type(), totalSize, arrowValuesBuffer);
419-
array = std::make_shared<arrow::ListArray>(physicalField->type(), readEntries, arrowOffsetBuffer, varray);
420-
break;
421-
case 1:
422-
array = std::make_shared<arrow::PrimitiveArray>(physicalField->type(), readEntries, arrowValuesBuffer);
423-
break;
424-
default:
425-
varray = std::make_shared<arrow::PrimitiveArray>(physicalField->type()->field(0)->type(), totalSize, arrowValuesBuffer);
426-
array = std::make_shared<arrow::FixedSizeListArray>(physicalField->type(), readEntries, varray);
463+
if (listSize == 1) {
464+
array = std::make_shared<arrow::PrimitiveArray>(datasetField->type(), readEntries, arrowValuesBuffer);
465+
} else {
466+
auto varray = std::make_shared<arrow::PrimitiveArray>(datasetField->type()->field(0)->type(), totalSize, arrowValuesBuffer);
467+
if (mapping.vlaIdx != -1) {
468+
array = std::make_shared<arrow::ListArray>(datasetField->type(), readEntries, arrowOffsetBuffer, varray);
469+
} else {
470+
array = std::make_shared<arrow::FixedSizeListArray>(datasetField->type(), readEntries, varray);
471+
}
427472
}
428473
}
429474

@@ -534,9 +579,12 @@ auto arrowTypeFromROOT(EDataType type, int size)
534579
}
535580
}
536581

582+
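// This is a (currently empty) extension datatype for branches. TODO: the original comment ended mid-sentence ("which implies"); state what it implies.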
583+
struct RootTransientIndexType : arrow::ExtensionType {
584+
};
585+
537586
arrow::Result<std::shared_ptr<arrow::Schema>> TTreeFileFormat::Inspect(const arrow::dataset::FileSource& source) const
538587
{
539-
arrow::Schema schema{{}};
540588
auto fs = std::dynamic_pointer_cast<VirtualRootFileSystemBase>(source.filesystem());
541589
// Actually get the TTree from the ROOT file.
542590
auto treeFs = std::dynamic_pointer_cast<TTreeFileSystem>(fs->GetSubFilesystem(source));
@@ -548,51 +596,37 @@ arrow::Result<std::shared_ptr<arrow::Schema>> TTreeFileFormat::Inspect(const arr
548596
auto branches = tree->GetListOfBranches();
549597
auto n = branches->GetEntries();
550598

551-
std::vector<BranchInfo> branchInfos;
599+
std::vector<std::shared_ptr<arrow::Field>> fields;
600+
601+
bool prevIsSize = false;
552602
for (auto i = 0; i < n; ++i) {
553603
auto branch = static_cast<TBranch*>(branches->At(i));
554-
auto name = std::string{branch->GetName()};
555-
auto pos = name.find("_size");
556-
if (pos != std::string::npos) {
557-
name.erase(pos);
558-
branchInfos.emplace_back(BranchInfo{name, (TBranch*)nullptr, true});
604+
std::string name = branch->GetName();
605+
if (prevIsSize && fields.back()->name() != name + "_size") {
606+
throw runtime_error_f("Unexpected layout for VLA container %s.", branch->GetName());
607+
}
608+
609+
if (name.ends_with("_size")) {
610+
fields.emplace_back(std::make_shared<arrow::Field>(name, arrow::int32()));
611+
prevIsSize = true;
559612
} else {
560-
auto lookup = std::find_if(branchInfos.begin(), branchInfos.end(), [&](BranchInfo const& bi) {
561-
return bi.name == name;
562-
});
563-
if (lookup == branchInfos.end()) {
564-
branchInfos.emplace_back(BranchInfo{name, branch, false});
613+
static TClass* cls;
614+
EDataType type;
615+
branch->GetExpectedType(cls, type);
616+
617+
if (prevIsSize) {
618+
fields.emplace_back(std::make_shared<arrow::Field>(name, arrowTypeFromROOT(type, -1)));
565619
} else {
566-
lookup->ptr = branch;
620+
auto listSize = static_cast<TLeaf*>(branch->GetListOfLeaves()->At(0))->GetLenStatic();
621+
fields.emplace_back(std::make_shared<arrow::Field>(name, arrowTypeFromROOT(type, listSize)));
567622
}
623+
prevIsSize = false;
568624
}
569625
}
570626

571-
std::vector<std::shared_ptr<arrow::Field>> fields;
572-
tree->SetCacheSize(25000000);
573-
for (auto& bi : branchInfos) {
574-
static TClass* cls;
575-
EDataType type;
576-
bi.ptr->GetExpectedType(cls, type);
577-
auto listSize = -1;
578-
if (!bi.mVLA) {
579-
listSize = static_cast<TLeaf*>(bi.ptr->GetListOfLeaves()->At(0))->GetLenStatic();
580-
}
581-
auto field = std::make_shared<arrow::Field>(bi.ptr->GetName(), arrowTypeFromROOT(type, listSize));
582-
fields.push_back(field);
583-
584-
tree->AddBranchToCache(bi.ptr);
585-
if (strncmp(bi.ptr->GetName(), "fIndexArray", strlen("fIndexArray")) == 0) {
586-
std::string sizeBranchName = bi.ptr->GetName();
587-
sizeBranchName += "_size";
588-
auto* sizeBranch = (TBranch*)tree->GetBranch(sizeBranchName.c_str());
589-
if (sizeBranch) {
590-
tree->AddBranchToCache(sizeBranch);
591-
}
592-
}
627+
if (fields.back()->name().ends_with("_size")) {
628+
throw runtime_error_f("Missing values for VLA indices %s.", fields.back()->name().c_str());
593629
}
594-
tree->StopCacheLearningPhase();
595-
596630
return std::make_shared<arrow::Schema>(fields);
597631
}
598632

@@ -601,9 +635,8 @@ arrow::Result<std::shared_ptr<arrow::dataset::FileFragment>> TTreeFileFormat::Ma
601635
arrow::dataset::FileSource source, arrow::compute::Expression partition_expression,
602636
std::shared_ptr<arrow::Schema> physical_schema)
603637
{
604-
std::shared_ptr<arrow::dataset::FileFormat> format = std::make_shared<TTreeFileFormat>(mTotCompressedSize, mTotUncompressedSize);
605638

606-
auto fragment = std::make_shared<TTreeFileFragment>(std::move(source), std::move(format),
639+
auto fragment = std::make_shared<TTreeFileFragment>(std::move(source), std::dynamic_pointer_cast<arrow::dataset::FileFormat>(shared_from_this()),
607640
std::move(partition_expression),
608641
std::move(physical_schema));
609642
return std::dynamic_pointer_cast<arrow::dataset::FileFragment>(fragment);

0 commit comments

Comments
 (0)