1414#include " Framework/Signpost.h"
1515#include " Framework/Endian.h"
1616#include < arrow/dataset/file_base.h>
17+ #include < arrow/extension_type.h>
18+ #include < arrow/type.h>
1719#include < arrow/util/key_value_metadata.h>
1820#include < arrow/array/array_nested.h>
1921#include < arrow/array/array_primitive.h>
2325#include < TBranch.h>
2426#include < TFile.h>
2527#include < TLeaf.h>
28+ #include < memory>
29+ #include < iostream>
2630
2731O2_DECLARE_DYNAMIC_LOG (root_arrow_fs);
2832
@@ -91,6 +95,7 @@ arrow::Result<arrow::fs::FileInfo> SingleTreeFileSystem::GetFileInfo(std::string
9195 return result;
9296}
9397
98+ // A fragment which holds a tree
9499class TTreeFileFragment : public arrow ::dataset::FileFragment
95100{
96101 public:
@@ -101,6 +106,13 @@ class TTreeFileFragment : public arrow::dataset::FileFragment
101106 : FileFragment(std::move(source), std::move(format), std::move(partition_expression), std::move(physical_schema))
102107 {
103108 }
109+
110+ std::unique_ptr<TTree>& GetTree ()
111+ {
112+ auto topFs = std::dynamic_pointer_cast<VirtualRootFileSystemBase>(source ().filesystem ());
113+ auto treeFs = std::dynamic_pointer_cast<TTreeFileSystem>(topFs->GetSubFilesystem (source ()));
114+ return treeFs->GetTree (source ());
115+ }
104116};
105117
106118class TTreeFileFormat : public arrow ::dataset::FileFormat
@@ -158,9 +170,9 @@ class TTreeFileFormat : public arrow::dataset::FileFormat
158170class TTreeOutputStream : public arrow ::io::OutputStream
159171{
160172 public:
161- // Using a pointer means that the tree itself is owned by another
173+ // Using a pointer means that the tree itself is owned by another
162174 // class
163- TTreeOutputStream (TTree *, std::string branchPrefix);
175+ TTreeOutputStream (TTree*, std::string branchPrefix);
164176
165177 arrow::Status Close () override ;
166178
@@ -245,33 +257,70 @@ struct TTreeObjectReadingImplementation : public RootArrowFactoryPlugin {
245257 }
246258};
247259
260+ struct BranchFieldMapping {
261+ int mainBranchIdx;
262+ int vlaIdx;
263+ int datasetFieldIdx;
264+ };
265+
248266arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync (
249267 const std::shared_ptr<arrow::dataset::ScanOptions>& options,
250268 const std::shared_ptr<arrow::dataset::FileFragment>& fragment) const
251269{
252- // Get the fragment as a TTreeFragment. This might be PART of a TTree.
253- auto treeFragment = std::dynamic_pointer_cast<TTreeFileFragment>(fragment);
254270 // This is the schema we want to read
255271 auto dataset_schema = options->dataset_schema ;
256272
257- auto generator = [pool = options->pool , treeFragment , dataset_schema, &totalCompressedSize = mTotCompressedSize ,
273+ auto generator = [pool = options->pool , fragment , dataset_schema, &totalCompressedSize = mTotCompressedSize ,
258274 &totalUncompressedSize = mTotUncompressedSize ]() -> arrow::Future<std::shared_ptr<arrow::RecordBatch>> {
259- auto schema = treeFragment->format ()->Inspect (treeFragment->source ());
260-
261275 std::vector<std::shared_ptr<arrow::Array>> columns;
262276 std::vector<std::shared_ptr<arrow::Field>> fields = dataset_schema->fields ();
263- auto physical_schema = *treeFragment->ReadPhysicalSchema ();
277+ auto physical_schema = *fragment->ReadPhysicalSchema ();
278+
279+ auto fs = std::dynamic_pointer_cast<VirtualRootFileSystemBase>(fragment->source ().filesystem ());
280+ // Actually get the TTree from the ROOT file.
281+ auto treeFs = std::dynamic_pointer_cast<TTreeFileSystem>(fs->GetSubFilesystem (fragment->source ()));
282+
283+ if (dataset_schema->num_fields () > physical_schema->num_fields ()) {
284+ throw runtime_error_f (" One TTree must have all the fields requested in a table" );
285+ }
286+
287+ // Register physical fields into the cache
288+ std::vector<BranchFieldMapping> mappings;
289+
290+ for (int fi = 0 ; fi < dataset_schema->num_fields (); ++fi) {
291+ auto dataset_field = dataset_schema->field (fi);
292+ int physicalFieldIdx = physical_schema->GetFieldIndex (dataset_field->name ());
293+
294+ if (physicalFieldIdx < 0 ) {
295+ throw runtime_error_f (" Cannot find physical field associated to %s" , dataset_field->name ().c_str ());
296+ }
297+ if (physicalFieldIdx > 1 && physical_schema->field (physicalFieldIdx - 1 )->name ().ends_with (" _size" )) {
298+ mappings.push_back ({physicalFieldIdx, physicalFieldIdx - 1 , fi});
299+ } else {
300+ mappings.push_back ({physicalFieldIdx, -1 , fi});
301+ }
302+ }
303+
304+ auto & tree = treeFs->GetTree (fragment->source ());
305+ tree->SetCacheSize (25000000 );
306+ auto branches = tree->GetListOfBranches ();
307+ for (auto & mapping : mappings) {
308+ tree->AddBranchToCache ((TBranch*)branches->At (mapping.mainBranchIdx ), false );
309+ if (mapping.vlaIdx != -1 ) {
310+ tree->AddBranchToCache ((TBranch*)branches->At (mapping.vlaIdx ), false );
311+ }
312+ }
313+ tree->StopCacheLearningPhase ();
264314
265315 static TBufferFile buffer{TBuffer::EMode::kWrite , 4 * 1024 * 1024 };
266- auto containerFS = std::dynamic_pointer_cast<VirtualRootFileSystemBase>(treeFragment->source ().filesystem ());
267- auto fs = std::dynamic_pointer_cast<TTreeFileSystem>(containerFS->GetSubFilesystem (treeFragment->source ()));
268316
269317 int64_t rows = -1 ;
270- auto & tree = fs-> GetTree (treeFragment-> source ());
271- for ( auto & field : fields) {
318+ for ( size_t mi = 0 ; mi < mappings. size (); ++mi) {
319+ BranchFieldMapping mapping = mappings[mi];
272320 // The field actually on disk
273- auto physicalField = physical_schema->GetFieldByName (field->name ());
274- TBranch* branch = tree->GetBranch (physicalField->name ().c_str ());
321+ auto datasetField = dataset_schema->field (mapping.datasetFieldIdx );
322+ auto physicalField = physical_schema->field (mapping.mainBranchIdx );
323+ auto * branch = (TBranch*)branches->At (mapping.mainBranchIdx );
275324 assert (branch);
276325 buffer.Reset ();
277326 auto totalEntries = branch->GetEntries ();
@@ -284,12 +333,12 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
284333 arrow::Status status;
285334 int readEntries = 0 ;
286335 std::shared_ptr<arrow::Array> array;
287- auto listType = std::dynamic_pointer_cast<arrow::FixedSizeListType>(physicalField ->type ());
288- if (physicalField ->type () == arrow::boolean () ||
289- (listType && physicalField ->type ()->field (0 )->type () == arrow::boolean ())) {
336+ auto listType = std::dynamic_pointer_cast<arrow::FixedSizeListType>(datasetField ->type ());
337+ if (datasetField ->type () == arrow::boolean () ||
338+ (listType && datasetField ->type ()->field (0 )->type () == arrow::boolean ())) {
290339 if (listType) {
291340 std::unique_ptr<arrow::ArrayBuilder> builder = nullptr ;
292- auto status = arrow::MakeBuilder (pool, physicalField ->type ()->field (0 )->type (), &builder);
341+ auto status = arrow::MakeBuilder (pool, datasetField ->type ()->field (0 )->type (), &builder);
293342 if (!status.ok ()) {
294343 throw runtime_error (" Cannot create value builder" );
295344 }
@@ -316,7 +365,7 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
316365 }
317366 } else if (listType == nullptr ) {
318367 std::unique_ptr<arrow::ArrayBuilder> builder = nullptr ;
319- auto status = arrow::MakeBuilder (pool, physicalField ->type (), &builder);
368+ auto status = arrow::MakeBuilder (pool, datasetField ->type (), &builder);
320369 if (!status.ok ()) {
321370 throw runtime_error (" Cannot create builder" );
322371 }
@@ -340,16 +389,14 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
340389 }
341390 }
342391 } else {
343- // other types: use serialized read to build arrays directly.
344- auto typeSize = physicalField->type ()->byte_width ();
345392 // This is needed for branches which have not been persisted.
346393 auto bytes = branch->GetTotBytes ();
347394 auto branchSize = bytes ? bytes : 1000000 ;
348395 auto && result = arrow::AllocateResizableBuffer (branchSize, pool);
349396 if (!result.ok ()) {
350397 throw runtime_error (" Cannot allocate values buffer" );
351398 }
352- std::shared_ptr<arrow::Buffer> arrowValuesBuffer = std::move ( result). ValueUnsafe ();
399+ std::shared_ptr<arrow::Buffer> arrowValuesBuffer = result. MoveValueUnsafe ();
353400 auto ptr = arrowValuesBuffer->mutable_data ();
354401 if (ptr == nullptr ) {
355402 throw runtime_error (" Invalid buffer" );
@@ -363,23 +410,14 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
363410 std::span<int > offsets;
364411 int size = 0 ;
365412 uint32_t totalSize = 0 ;
366- TBranch* mSizeBranch = nullptr ;
367- int64_t listSize = 1 ;
368- if (auto fixedSizeList = std::dynamic_pointer_cast<arrow::FixedSizeListType>(physicalField->type ())) {
369- listSize = fixedSizeList->list_size ();
370- typeSize = fixedSizeList->field (0 )->type ()->byte_width ();
371- } else if (auto vlaListType = std::dynamic_pointer_cast<arrow::ListType>(physicalField->type ())) {
372- listSize = -1 ;
373- typeSize = vlaListType->field (0 )->type ()->byte_width ();
374- }
375- if (listSize == -1 ) {
376- mSizeBranch = branch->GetTree ()->GetBranch ((std::string{branch->GetName ()} + " _size" ).c_str ());
413+ if (mapping.vlaIdx != -1 ) {
414+ auto * mSizeBranch = (TBranch*)branches->At (mapping.vlaIdx );
377415 offsetBuffer = std::make_unique<TBufferFile>(TBuffer::EMode::kWrite , 4 * 1024 * 1024 );
378416 result = arrow::AllocateResizableBuffer ((totalEntries + 1 ) * (int64_t )sizeof (int ), pool);
379417 if (!result.ok ()) {
380418 throw runtime_error (" Cannot allocate offset buffer" );
381419 }
382- arrowOffsetBuffer = std::move ( result). ValueUnsafe ();
420+ arrowOffsetBuffer = result. MoveValueUnsafe ();
383421 unsigned char * ptrOffset = arrowOffsetBuffer->mutable_data ();
384422 auto * tPtrOffset = reinterpret_cast <int *>(ptrOffset);
385423 offsets = std::span<int >{tPtrOffset, tPtrOffset + totalEntries + 1 };
@@ -398,9 +436,19 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
398436 readEntries = 0 ;
399437 }
400438
439+ int typeSize = physicalField->type ()->byte_width ();
440+ int64_t listSize = 1 ;
441+ if (auto fixedSizeList = std::dynamic_pointer_cast<arrow::FixedSizeListType>(datasetField->type ())) {
442+ listSize = fixedSizeList->list_size ();
443+ typeSize = physicalField->type ()->field (0 )->type ()->byte_width ();
444+ } else if (mapping.vlaIdx != -1 ) {
445+ typeSize = physicalField->type ()->field (0 )->type ()->byte_width ();
446+ listSize = -1 ;
447+ }
448+
401449 while (readEntries < totalEntries) {
402450 auto readLast = branch->GetBulkRead ().GetEntriesSerialized (readEntries, buffer);
403- if (listSize = = -1 ) {
451+ if (mapping. vlaIdx ! = -1 ) {
404452 size = offsets[readEntries + readLast] - offsets[readEntries];
405453 } else {
406454 size = readLast * listSize;
@@ -412,18 +460,15 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
412460 if (listSize >= 1 ) {
413461 totalSize = readEntries * listSize;
414462 }
415- std::shared_ptr<arrow::PrimitiveArray> varray;
416- switch (listSize) {
417- case -1 :
418- varray = std::make_shared<arrow::PrimitiveArray>(physicalField->type ()->field (0 )->type (), totalSize, arrowValuesBuffer);
419- array = std::make_shared<arrow::ListArray>(physicalField->type (), readEntries, arrowOffsetBuffer, varray);
420- break ;
421- case 1 :
422- array = std::make_shared<arrow::PrimitiveArray>(physicalField->type (), readEntries, arrowValuesBuffer);
423- break ;
424- default :
425- varray = std::make_shared<arrow::PrimitiveArray>(physicalField->type ()->field (0 )->type (), totalSize, arrowValuesBuffer);
426- array = std::make_shared<arrow::FixedSizeListArray>(physicalField->type (), readEntries, varray);
463+ if (listSize == 1 ) {
464+ array = std::make_shared<arrow::PrimitiveArray>(datasetField->type (), readEntries, arrowValuesBuffer);
465+ } else {
466+ auto varray = std::make_shared<arrow::PrimitiveArray>(datasetField->type ()->field (0 )->type (), totalSize, arrowValuesBuffer);
467+ if (mapping.vlaIdx != -1 ) {
468+ array = std::make_shared<arrow::ListArray>(datasetField->type (), readEntries, arrowOffsetBuffer, varray);
469+ } else {
470+ array = std::make_shared<arrow::FixedSizeListArray>(datasetField->type (), readEntries, varray);
471+ }
427472 }
428473 }
429474
@@ -534,9 +579,12 @@ auto arrowTypeFromROOT(EDataType type, int size)
534579 }
535580}
536581
582+ // This is a datatype for branches which implies
583+ struct RootTransientIndexType : arrow::ExtensionType {
584+ };
585+
537586arrow::Result<std::shared_ptr<arrow::Schema>> TTreeFileFormat::Inspect (const arrow::dataset::FileSource& source) const
538587{
539- arrow::Schema schema{{}};
540588 auto fs = std::dynamic_pointer_cast<VirtualRootFileSystemBase>(source.filesystem ());
541589 // Actually get the TTree from the ROOT file.
542590 auto treeFs = std::dynamic_pointer_cast<TTreeFileSystem>(fs->GetSubFilesystem (source));
@@ -548,51 +596,37 @@ arrow::Result<std::shared_ptr<arrow::Schema>> TTreeFileFormat::Inspect(const arr
548596 auto branches = tree->GetListOfBranches ();
549597 auto n = branches->GetEntries ();
550598
551- std::vector<BranchInfo> branchInfos;
599+ std::vector<std::shared_ptr<arrow::Field>> fields;
600+
601+ bool prevIsSize = false ;
552602 for (auto i = 0 ; i < n; ++i) {
553603 auto branch = static_cast <TBranch*>(branches->At (i));
554- auto name = std::string{branch->GetName ()};
555- auto pos = name.find (" _size" );
556- if (pos != std::string::npos) {
557- name.erase (pos);
558- branchInfos.emplace_back (BranchInfo{name, (TBranch*)nullptr , true });
604+ std::string name = branch->GetName ();
605+ if (prevIsSize && fields.back ()->name () != name + " _size" ) {
606+ throw runtime_error_f (" Unexpected layout for VLA container %s." , branch->GetName ());
607+ }
608+
609+ if (name.ends_with (" _size" )) {
610+ fields.emplace_back (std::make_shared<arrow::Field>(name, arrow::int32 ()));
611+ prevIsSize = true ;
559612 } else {
560- auto lookup = std::find_if (branchInfos.begin (), branchInfos.end (), [&](BranchInfo const & bi) {
561- return bi.name == name;
562- });
563- if (lookup == branchInfos.end ()) {
564- branchInfos.emplace_back (BranchInfo{name, branch, false });
613+ static TClass* cls;
614+ EDataType type;
615+ branch->GetExpectedType (cls, type);
616+
617+ if (prevIsSize) {
618+ fields.emplace_back (std::make_shared<arrow::Field>(name, arrowTypeFromROOT (type, -1 )));
565619 } else {
566- lookup->ptr = branch;
620+ auto listSize = static_cast <TLeaf*>(branch->GetListOfLeaves ()->At (0 ))->GetLenStatic ();
621+ fields.emplace_back (std::make_shared<arrow::Field>(name, arrowTypeFromROOT (type, listSize)));
567622 }
623+ prevIsSize = false ;
568624 }
569625 }
570626
571- std::vector<std::shared_ptr<arrow::Field>> fields;
572- tree->SetCacheSize (25000000 );
573- for (auto & bi : branchInfos) {
574- static TClass* cls;
575- EDataType type;
576- bi.ptr ->GetExpectedType (cls, type);
577- auto listSize = -1 ;
578- if (!bi.mVLA ) {
579- listSize = static_cast <TLeaf*>(bi.ptr ->GetListOfLeaves ()->At (0 ))->GetLenStatic ();
580- }
581- auto field = std::make_shared<arrow::Field>(bi.ptr ->GetName (), arrowTypeFromROOT (type, listSize));
582- fields.push_back (field);
583-
584- tree->AddBranchToCache (bi.ptr );
585- if (strncmp (bi.ptr ->GetName (), " fIndexArray" , strlen (" fIndexArray" )) == 0 ) {
586- std::string sizeBranchName = bi.ptr ->GetName ();
587- sizeBranchName += " _size" ;
588- auto * sizeBranch = (TBranch*)tree->GetBranch (sizeBranchName.c_str ());
589- if (sizeBranch) {
590- tree->AddBranchToCache (sizeBranch);
591- }
592- }
627+ if (fields.back ()->name ().ends_with (" _size" )) {
628+ throw runtime_error_f (" Missing values for VLA indices %s." , fields.back ()->name ().c_str ());
593629 }
594- tree->StopCacheLearningPhase ();
595-
596630 return std::make_shared<arrow::Schema>(fields);
597631}
598632
@@ -601,9 +635,8 @@ arrow::Result<std::shared_ptr<arrow::dataset::FileFragment>> TTreeFileFormat::Ma
601635 arrow::dataset::FileSource source, arrow::compute::Expression partition_expression,
602636 std::shared_ptr<arrow::Schema> physical_schema)
603637{
604- std::shared_ptr<arrow::dataset::FileFormat> format = std::make_shared<TTreeFileFormat>(mTotCompressedSize , mTotUncompressedSize );
605638
606- auto fragment = std::make_shared<TTreeFileFragment>(std::move (source), std::move (format ),
639+ auto fragment = std::make_shared<TTreeFileFragment>(std::move (source), std::dynamic_pointer_cast<arrow::dataset::FileFormat>( shared_from_this () ),
607640 std::move (partition_expression),
608641 std::move (physical_schema));
609642 return std::dynamic_pointer_cast<arrow::dataset::FileFragment>(fragment);
0 commit comments