205 changes: 119 additions & 86 deletions Framework/AnalysisSupport/src/TTreePlugin.cxx
@@ -14,6 +14,8 @@
#include "Framework/Signpost.h"
#include "Framework/Endian.h"
#include <arrow/dataset/file_base.h>
#include <arrow/extension_type.h>
#include <arrow/type.h>
#include <arrow/util/key_value_metadata.h>
#include <arrow/array/array_nested.h>
#include <arrow/array/array_primitive.h>
@@ -23,6 +25,8 @@
#include <TBranch.h>
#include <TFile.h>
#include <TLeaf.h>
#include <memory>
#include <iostream>

O2_DECLARE_DYNAMIC_LOG(root_arrow_fs);

@@ -91,6 +95,7 @@ arrow::Result<arrow::fs::FileInfo> SingleTreeFileSystem::GetFileInfo(std::string
return result;
}

// A fragment which holds a tree
class TTreeFileFragment : public arrow::dataset::FileFragment
{
public:
@@ -101,6 +106,13 @@ class TTreeFileFragment : public arrow::dataset::FileFragment
: FileFragment(std::move(source), std::move(format), std::move(partition_expression), std::move(physical_schema))
{
}

std::unique_ptr<TTree>& GetTree()
{
auto topFs = std::dynamic_pointer_cast<VirtualRootFileSystemBase>(source().filesystem());
auto treeFs = std::dynamic_pointer_cast<TTreeFileSystem>(topFs->GetSubFilesystem(source()));
return treeFs->GetTree(source());
}
};
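The accessor added above resolves the backing TTree by walking the fragment's source filesystem down to its TTreeFileSystem. A minimal usage sketch, not part of the patch, assuming the fragment was produced by TTreeFileFormat::MakeFragment further below:

// Standalone sketch (not from the patch): counting rows through the new GetTree() accessor.
// Assumes "fragment" comes from TTreeFileFormat::MakeFragment, so its filesystem chain
// ends in a TTreeFileSystem and GetTree() returns a live tree owned by that filesystem.
#include <TTree.h>
#include <memory>

Long64_t countTreeEntries(TTreeFileFragment& fragment)
{
  std::unique_ptr<TTree>& tree = fragment.GetTree(); // still owned by the TTreeFileSystem
  return tree->GetEntries();
}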

class TTreeFileFormat : public arrow::dataset::FileFormat
@@ -158,9 +170,9 @@ class TTreeFileFormat : public arrow::dataset::FileFormat
class TTreeOutputStream : public arrow::io::OutputStream
{
public:
// Using a pointer means that the tree itself is owned by another
// class
TTreeOutputStream(TTree *, std::string branchPrefix);
TTreeOutputStream(TTree*, std::string branchPrefix);

arrow::Status Close() override;

@@ -245,33 +257,70 @@ struct TTreeObjectReadingImplementation : public RootArrowFactoryPlugin {
}
};

struct BranchFieldMapping {
int mainBranchIdx;
int vlaIdx;
int datasetFieldIdx;
};
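For readers of this hunk, a self-contained sketch of how the new BranchFieldMapping is filled: each requested dataset field is paired with its main branch index in the physical schema and, for variable-length arrays (VLAs), with the index of the "<name>_size" branch that precedes it. The schemas and field names below are invented for illustration and are not part of the patch.

// Standalone sketch (not from the patch): pairing dataset fields with physical branches.
// "fX" is a plain column; "fIndexArray" is a VLA whose size branch "fIndexArray_size"
// sits right before it in the physical schema. All names here are illustrative.
#include <arrow/api.h>

#include <iostream>
#include <vector>

struct BranchFieldMapping {
  int mainBranchIdx;   // index of the data branch in the physical schema
  int vlaIdx;          // index of the "<name>_size" branch, or -1 for non-VLA columns
  int datasetFieldIdx; // index of the field in the requested dataset schema
};

int main()
{
  auto physical_schema = arrow::schema({arrow::field("fX", arrow::float32()),
                                        arrow::field("fIndexArray_size", arrow::int32()),
                                        arrow::field("fIndexArray", arrow::list(arrow::int32()))});
  auto dataset_schema = arrow::schema({arrow::field("fX", arrow::float32()),
                                       arrow::field("fIndexArray", arrow::list(arrow::int32()))});

  std::vector<BranchFieldMapping> mappings;
  for (int fi = 0; fi < dataset_schema->num_fields(); ++fi) {
    int idx = physical_schema->GetFieldIndex(dataset_schema->field(fi)->name());
    bool hasSizeBranch = idx > 0 && physical_schema->field(idx - 1)->name().ends_with("_size");
    mappings.push_back({idx, hasSizeBranch ? idx - 1 : -1, fi});
  }
  for (auto& m : mappings) {
    std::cout << m.mainBranchIdx << " " << m.vlaIdx << " " << m.datasetFieldIdx << "\n";
  }
  // Prints "0 -1 0" for fX and "2 1 1" for fIndexArray.
}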

arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
const std::shared_ptr<arrow::dataset::ScanOptions>& options,
const std::shared_ptr<arrow::dataset::FileFragment>& fragment) const
{
// Get the fragment as a TTreeFileFragment. This might be PART of a TTree.
auto treeFragment = std::dynamic_pointer_cast<TTreeFileFragment>(fragment);
// This is the schema we want to read
auto dataset_schema = options->dataset_schema;

auto generator = [pool = options->pool, treeFragment, dataset_schema, &totalCompressedSize = mTotCompressedSize,
auto generator = [pool = options->pool, fragment, dataset_schema, &totalCompressedSize = mTotCompressedSize,
&totalUncompressedSize = mTotUncompressedSize]() -> arrow::Future<std::shared_ptr<arrow::RecordBatch>> {
auto schema = treeFragment->format()->Inspect(treeFragment->source());

std::vector<std::shared_ptr<arrow::Array>> columns;
std::vector<std::shared_ptr<arrow::Field>> fields = dataset_schema->fields();
auto physical_schema = *treeFragment->ReadPhysicalSchema();
auto physical_schema = *fragment->ReadPhysicalSchema();

auto fs = std::dynamic_pointer_cast<VirtualRootFileSystemBase>(fragment->source().filesystem());
// Actually get the TTree from the ROOT file.
auto treeFs = std::dynamic_pointer_cast<TTreeFileSystem>(fs->GetSubFilesystem(fragment->source()));

if (dataset_schema->num_fields() > physical_schema->num_fields()) {
throw runtime_error_f("One TTree must have all the fields requested in a table");
}

// Register physical fields into the cache
std::vector<BranchFieldMapping> mappings;

for (int fi = 0; fi < dataset_schema->num_fields(); ++fi) {
auto dataset_field = dataset_schema->field(fi);
int physicalFieldIdx = physical_schema->GetFieldIndex(dataset_field->name());

if (physicalFieldIdx < 0) {
throw runtime_error_f("Cannot find physical field associated to %s", dataset_field->name().c_str());
}
if (physicalFieldIdx > 1 && physical_schema->field(physicalFieldIdx - 1)->name().ends_with("_size")) {
mappings.push_back({physicalFieldIdx, physicalFieldIdx - 1, fi});
} else {
mappings.push_back({physicalFieldIdx, -1, fi});
}
}

auto& tree = treeFs->GetTree(fragment->source());
tree->SetCacheSize(25000000);
auto branches = tree->GetListOfBranches();
for (auto& mapping : mappings) {
tree->AddBranchToCache((TBranch*)branches->At(mapping.mainBranchIdx), false);
if (mapping.vlaIdx != -1) {
tree->AddBranchToCache((TBranch*)branches->At(mapping.vlaIdx), false);
}
}
tree->StopCacheLearningPhase();

static TBufferFile buffer{TBuffer::EMode::kWrite, 4 * 1024 * 1024};
auto containerFS = std::dynamic_pointer_cast<VirtualRootFileSystemBase>(treeFragment->source().filesystem());
auto fs = std::dynamic_pointer_cast<TTreeFileSystem>(containerFS->GetSubFilesystem(treeFragment->source()));

int64_t rows = -1;
auto& tree = fs->GetTree(treeFragment->source());
for (auto& field : fields) {
for (size_t mi = 0; mi < mappings.size(); ++mi) {
BranchFieldMapping mapping = mappings[mi];
// The field actually on disk
auto physicalField = physical_schema->GetFieldByName(field->name());
TBranch* branch = tree->GetBranch(physicalField->name().c_str());
auto datasetField = dataset_schema->field(mapping.datasetFieldIdx);
auto physicalField = physical_schema->field(mapping.mainBranchIdx);
auto* branch = (TBranch*)branches->At(mapping.mainBranchIdx);
assert(branch);
buffer.Reset();
auto totalEntries = branch->GetEntries();
@@ -284,12 +333,12 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
arrow::Status status;
int readEntries = 0;
std::shared_ptr<arrow::Array> array;
auto listType = std::dynamic_pointer_cast<arrow::FixedSizeListType>(physicalField->type());
if (physicalField->type() == arrow::boolean() ||
(listType && physicalField->type()->field(0)->type() == arrow::boolean())) {
auto listType = std::dynamic_pointer_cast<arrow::FixedSizeListType>(datasetField->type());
if (datasetField->type() == arrow::boolean() ||
(listType && datasetField->type()->field(0)->type() == arrow::boolean())) {
if (listType) {
std::unique_ptr<arrow::ArrayBuilder> builder = nullptr;
auto status = arrow::MakeBuilder(pool, physicalField->type()->field(0)->type(), &builder);
auto status = arrow::MakeBuilder(pool, datasetField->type()->field(0)->type(), &builder);
if (!status.ok()) {
throw runtime_error("Cannot create value builder");
}
@@ -316,7 +365,7 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
}
} else if (listType == nullptr) {
std::unique_ptr<arrow::ArrayBuilder> builder = nullptr;
auto status = arrow::MakeBuilder(pool, physicalField->type(), &builder);
auto status = arrow::MakeBuilder(pool, datasetField->type(), &builder);
if (!status.ok()) {
throw runtime_error("Cannot create builder");
}
@@ -340,16 +389,14 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
}
}
} else {
// other types: use serialized read to build arrays directly.
auto typeSize = physicalField->type()->byte_width();
// This is needed for branches which have not been persisted.
auto bytes = branch->GetTotBytes();
auto branchSize = bytes ? bytes : 1000000;
auto&& result = arrow::AllocateResizableBuffer(branchSize, pool);
if (!result.ok()) {
throw runtime_error("Cannot allocate values buffer");
}
std::shared_ptr<arrow::Buffer> arrowValuesBuffer = std::move(result).ValueUnsafe();
std::shared_ptr<arrow::Buffer> arrowValuesBuffer = result.MoveValueUnsafe();
auto ptr = arrowValuesBuffer->mutable_data();
if (ptr == nullptr) {
throw runtime_error("Invalid buffer");
@@ -363,23 +410,14 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
std::span<int> offsets;
int size = 0;
uint32_t totalSize = 0;
TBranch* mSizeBranch = nullptr;
int64_t listSize = 1;
if (auto fixedSizeList = std::dynamic_pointer_cast<arrow::FixedSizeListType>(physicalField->type())) {
listSize = fixedSizeList->list_size();
typeSize = fixedSizeList->field(0)->type()->byte_width();
} else if (auto vlaListType = std::dynamic_pointer_cast<arrow::ListType>(physicalField->type())) {
listSize = -1;
typeSize = vlaListType->field(0)->type()->byte_width();
}
if (listSize == -1) {
mSizeBranch = branch->GetTree()->GetBranch((std::string{branch->GetName()} + "_size").c_str());
if (mapping.vlaIdx != -1) {
auto* mSizeBranch = (TBranch*)branches->At(mapping.vlaIdx);
offsetBuffer = std::make_unique<TBufferFile>(TBuffer::EMode::kWrite, 4 * 1024 * 1024);
result = arrow::AllocateResizableBuffer((totalEntries + 1) * (int64_t)sizeof(int), pool);
if (!result.ok()) {
throw runtime_error("Cannot allocate offset buffer");
}
arrowOffsetBuffer = std::move(result).ValueUnsafe();
arrowOffsetBuffer = result.MoveValueUnsafe();
unsigned char* ptrOffset = arrowOffsetBuffer->mutable_data();
auto* tPtrOffset = reinterpret_cast<int*>(ptrOffset);
offsets = std::span<int>{tPtrOffset, tPtrOffset + totalEntries + 1};
@@ -398,9 +436,19 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
readEntries = 0;
}

int typeSize = physicalField->type()->byte_width();
int64_t listSize = 1;
if (auto fixedSizeList = std::dynamic_pointer_cast<arrow::FixedSizeListType>(datasetField->type())) {
listSize = fixedSizeList->list_size();
typeSize = physicalField->type()->field(0)->type()->byte_width();
} else if (mapping.vlaIdx != -1) {
typeSize = physicalField->type()->field(0)->type()->byte_width();
listSize = -1;
}

while (readEntries < totalEntries) {
auto readLast = branch->GetBulkRead().GetEntriesSerialized(readEntries, buffer);
if (listSize == -1) {
if (mapping.vlaIdx != -1) {
size = offsets[readEntries + readLast] - offsets[readEntries];
} else {
size = readLast * listSize;
@@ -412,18 +460,15 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
if (listSize >= 1) {
totalSize = readEntries * listSize;
}
std::shared_ptr<arrow::PrimitiveArray> varray;
switch (listSize) {
case -1:
varray = std::make_shared<arrow::PrimitiveArray>(physicalField->type()->field(0)->type(), totalSize, arrowValuesBuffer);
array = std::make_shared<arrow::ListArray>(physicalField->type(), readEntries, arrowOffsetBuffer, varray);
break;
case 1:
array = std::make_shared<arrow::PrimitiveArray>(physicalField->type(), readEntries, arrowValuesBuffer);
break;
default:
varray = std::make_shared<arrow::PrimitiveArray>(physicalField->type()->field(0)->type(), totalSize, arrowValuesBuffer);
array = std::make_shared<arrow::FixedSizeListArray>(physicalField->type(), readEntries, varray);
if (listSize == 1) {
array = std::make_shared<arrow::PrimitiveArray>(datasetField->type(), readEntries, arrowValuesBuffer);
} else {
auto varray = std::make_shared<arrow::PrimitiveArray>(datasetField->type()->field(0)->type(), totalSize, arrowValuesBuffer);
if (mapping.vlaIdx != -1) {
array = std::make_shared<arrow::ListArray>(datasetField->type(), readEntries, arrowOffsetBuffer, varray);
} else {
array = std::make_shared<arrow::FixedSizeListArray>(datasetField->type(), readEntries, varray);
}
}
}
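As context for the array construction above, a standalone sketch of the same zero-copy pattern: the flattened values and the per-entry offsets live in plain buffers, a PrimitiveArray wraps the values, and a ListArray wraps both. The data, the sizes and the use of Buffer::Wrap are illustrative only; the patch fills pool-allocated resizable buffers from the bulk reader instead.

// Standalone sketch (not from the patch): wrapping pre-filled buffers into a VLA column.
// Three entries with sizes 2, 0 and 4; Buffer::Wrap is used only to keep the example short.
#include <arrow/api.h>

#include <cstdint>
#include <memory>

std::shared_ptr<arrow::Array> makeListColumn()
{
  static const int32_t values[] = {1, 2, 3, 4, 5, 6}; // flattened VLA payload
  static const int32_t offsets[] = {0, 2, 2, 6};      // entry i spans [offsets[i], offsets[i+1])

  std::shared_ptr<arrow::Buffer> valuesBuffer = arrow::Buffer::Wrap(values, 6);
  std::shared_ptr<arrow::Buffer> offsetsBuffer = arrow::Buffer::Wrap(offsets, 4);

  auto flat = std::make_shared<arrow::PrimitiveArray>(arrow::int32(), 6, valuesBuffer);
  return std::make_shared<arrow::ListArray>(arrow::list(arrow::int32()), 3, offsetsBuffer, flat);
}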

Expand Down Expand Up @@ -534,9 +579,12 @@ auto arrowTypeFromROOT(EDataType type, int size)
}
}

// This is a datatype for branches which carry a transient (non-persisted) index.
struct RootTransientIndexType : arrow::ExtensionType {
};
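RootTransientIndexType is only a skeleton at this point: arrow::ExtensionType is abstract, so a concrete subclass eventually needs the overrides sketched below. The storage type, extension name and (de)serialization used here are placeholders, not choices made by the patch.

// Standalone sketch (not from the patch): the pure-virtual methods arrow::ExtensionType
// requires. The int64 storage type and the extension name are placeholders.
#include <arrow/api.h>

#include <memory>
#include <string>

struct TransientIndexTypeSketch : arrow::ExtensionType {
  TransientIndexTypeSketch() : arrow::ExtensionType(arrow::int64()) {}

  std::string extension_name() const override { return "root.transient_index"; }

  bool ExtensionEquals(const arrow::ExtensionType& other) const override
  {
    return other.extension_name() == extension_name();
  }

  std::shared_ptr<arrow::Array> MakeArray(std::shared_ptr<arrow::ArrayData> data) const override
  {
    return std::make_shared<arrow::ExtensionArray>(data);
  }

  arrow::Result<std::shared_ptr<arrow::DataType>> Deserialize(std::shared_ptr<arrow::DataType> storage_type,
                                                              const std::string& serialized) const override
  {
    std::shared_ptr<arrow::DataType> out = std::make_shared<TransientIndexTypeSketch>();
    return out;
  }

  std::string Serialize() const override { return ""; }
};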

arrow::Result<std::shared_ptr<arrow::Schema>> TTreeFileFormat::Inspect(const arrow::dataset::FileSource& source) const
{
arrow::Schema schema{{}};
auto fs = std::dynamic_pointer_cast<VirtualRootFileSystemBase>(source.filesystem());
// Actually get the TTree from the ROOT file.
auto treeFs = std::dynamic_pointer_cast<TTreeFileSystem>(fs->GetSubFilesystem(source));
@@ -548,51 +596,37 @@ arrow::Result<std::shared_ptr<arrow::Schema>> TTreeFileFormat::Inspect(const arr
auto branches = tree->GetListOfBranches();
auto n = branches->GetEntries();

std::vector<BranchInfo> branchInfos;
std::vector<std::shared_ptr<arrow::Field>> fields;

bool prevIsSize = false;
for (auto i = 0; i < n; ++i) {
auto branch = static_cast<TBranch*>(branches->At(i));
auto name = std::string{branch->GetName()};
auto pos = name.find("_size");
if (pos != std::string::npos) {
name.erase(pos);
branchInfos.emplace_back(BranchInfo{name, (TBranch*)nullptr, true});
std::string name = branch->GetName();
if (prevIsSize && fields.back()->name() != name + "_size") {
throw runtime_error_f("Unexpected layout for VLA container %s.", branch->GetName());
}

if (name.ends_with("_size")) {
fields.emplace_back(std::make_shared<arrow::Field>(name, arrow::int32()));
prevIsSize = true;
} else {
auto lookup = std::find_if(branchInfos.begin(), branchInfos.end(), [&](BranchInfo const& bi) {
return bi.name == name;
});
if (lookup == branchInfos.end()) {
branchInfos.emplace_back(BranchInfo{name, branch, false});
static TClass* cls;
EDataType type;
branch->GetExpectedType(cls, type);

if (prevIsSize) {
fields.emplace_back(std::make_shared<arrow::Field>(name, arrowTypeFromROOT(type, -1)));
} else {
lookup->ptr = branch;
auto listSize = static_cast<TLeaf*>(branch->GetListOfLeaves()->At(0))->GetLenStatic();
fields.emplace_back(std::make_shared<arrow::Field>(name, arrowTypeFromROOT(type, listSize)));
}
prevIsSize = false;
}
}

std::vector<std::shared_ptr<arrow::Field>> fields;
tree->SetCacheSize(25000000);
for (auto& bi : branchInfos) {
static TClass* cls;
EDataType type;
bi.ptr->GetExpectedType(cls, type);
auto listSize = -1;
if (!bi.mVLA) {
listSize = static_cast<TLeaf*>(bi.ptr->GetListOfLeaves()->At(0))->GetLenStatic();
}
auto field = std::make_shared<arrow::Field>(bi.ptr->GetName(), arrowTypeFromROOT(type, listSize));
fields.push_back(field);

tree->AddBranchToCache(bi.ptr);
if (strncmp(bi.ptr->GetName(), "fIndexArray", strlen("fIndexArray")) == 0) {
std::string sizeBranchName = bi.ptr->GetName();
sizeBranchName += "_size";
auto* sizeBranch = (TBranch*)tree->GetBranch(sizeBranchName.c_str());
if (sizeBranch) {
tree->AddBranchToCache(sizeBranch);
}
}
if (fields.back()->name().ends_with("_size")) {
throw runtime_error_f("Missing values for VLA indices %s.", fields.back()->name().c_str());
}
tree->StopCacheLearningPhase();

return std::make_shared<arrow::Schema>(fields);
}
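To summarize what the rewritten Inspect produces: one Arrow field per branch, in branch order, with each "<name>_size" counter reported as int32 and immediately followed by its VLA field. A hedged example of the schema for a hypothetical tree follows; branch names and types are invented, and the exact Arrow types for array branches depend on arrowTypeFromROOT.

// Standalone sketch (not from the patch): the schema one would expect Inspect to report
// for a hypothetical tree with a scalar branch "fX", a fixed-size array branch "fPos[3]"
// and a VLA branch "fIndexArray" guarded by its "fIndexArray_size" counter.
#include <arrow/api.h>

std::shared_ptr<arrow::Schema> expectedTreeSchema()
{
  return arrow::schema({
    arrow::field("fX", arrow::float32()), // scalar leaf
    arrow::field("fPos", arrow::fixed_size_list(arrow::float32(), 3)),
    arrow::field("fIndexArray_size", arrow::int32()),         // VLA length counter
    arrow::field("fIndexArray", arrow::list(arrow::int32())), // VLA payload
  });
}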

@@ -601,9 +635,8 @@ arrow::Result<std::shared_ptr<arrow::dataset::FileFragment>> TTreeFileFormat::Ma
arrow::dataset::FileSource source, arrow::compute::Expression partition_expression,
std::shared_ptr<arrow::Schema> physical_schema)
{
std::shared_ptr<arrow::dataset::FileFormat> format = std::make_shared<TTreeFileFormat>(mTotCompressedSize, mTotUncompressedSize);

auto fragment = std::make_shared<TTreeFileFragment>(std::move(source), std::move(format),
auto fragment = std::make_shared<TTreeFileFragment>(std::move(source), std::dynamic_pointer_cast<arrow::dataset::FileFormat>(shared_from_this()),
std::move(partition_expression),
std::move(physical_schema));
return std::dynamic_pointer_cast<arrow::dataset::FileFragment>(fragment);