Skip to content

Commit 950b8b7

Browse files
authored
DPL Analysis: improve arrow::Dataset support for TTree (#13759)
1 parent feea3ad commit 950b8b7

File tree

2 files changed

+54
-73
lines changed

2 files changed

+54
-73
lines changed

Framework/Core/src/RootArrowFilesystem.cxx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "Framework/RuntimeError.h"
1414
#include "Framework/Signpost.h"
1515
#include <Rtypes.h>
16+
#include <arrow/array/array_nested.h>
1617
#include <arrow/array/array_primitive.h>
1718
#include <arrow/array/builder_nested.h>
1819
#include <arrow/array/builder_primitive.h>
@@ -427,7 +428,7 @@ class TTreeFileWriter : public arrow::dataset::FileWriter
427428
O2_SIGNPOST_EVENT_EMIT(root_arrow_fs, sid, "finaliseBasketSize", "Branch %s exists and uses %d bytes per entry.",
428429
branch->GetName(), valueSize);
429430
// FIXME: this comment is truncated in the source ("This should probably look up the ..."); the intended lookup target is not stated.
430-
auto column = firstBatch->GetColumnByName(branch->GetName());
431+
auto column = firstBatch->GetColumnByName(schema_->field(i)->name());
431432
auto list = std::static_pointer_cast<arrow::ListArray>(column);
432433
O2_SIGNPOST_EVENT_EMIT(root_arrow_fs, sid, "finaliseBasketSize", "Branch %s needed. Associated size branch %s and there are %lli entries of size %d in that list.",
433434
branch->GetName(), sizeBranch->GetName(), list->length(), valueSize);
@@ -497,8 +498,8 @@ class TTreeFileWriter : public arrow::dataset::FileWriter
497498
} break;
498499
case arrow::Type::LIST: {
499500
valueTypes.push_back(field->type()->field(0)->type());
500-
listSizes.back() = 0; // VLA, we need to calculate it on the fly;
501501
std::string leafList = fmt::format("{}[{}_size]{}", field->name(), field->name(), rootSuffixFromArrow(valueTypes.back()->id()));
502+
listSizes.back() = -1; // VLA, we need to calculate it on the fly;
502503
std::string sizeLeafList = field->name() + "_size/I";
503504
sizesBranches.push_back(treeStream->CreateBranch((field->name() + "_size").c_str(), sizeLeafList.c_str()));
504505
branches.push_back(treeStream->CreateBranch(field->name().c_str(), leafList.c_str()));
@@ -765,7 +766,7 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
765766
typeSize = fixedSizeList->field(0)->type()->byte_width();
766767
} else if (auto vlaListType = std::dynamic_pointer_cast<arrow::ListType>(physicalField->type())) {
767768
listSize = -1;
768-
typeSize = fixedSizeList->field(0)->type()->byte_width();
769+
typeSize = vlaListType->field(0)->type()->byte_width();
769770
}
770771
if (listSize == -1) {
771772
mSizeBranch = branch->GetTree()->GetBranch((std::string{branch->GetName()} + "_size").c_str());

Framework/Core/test/test_Root2ArrowTable.cxx

Lines changed: 50 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -322,12 +322,26 @@ bool validateContents(std::shared_ptr<arrow::RecordBatch> batch)
322322
REQUIRE(bool_array->Value(1) == (i % 5 == 0));
323323
}
324324
}
325+
326+
{
327+
auto list_array = std::static_pointer_cast<arrow::ListArray>(batch->GetColumnByName("vla"));
328+
329+
REQUIRE(list_array->length() == 100);
330+
for (int64_t i = 0; i < list_array->length(); i++) {
331+
auto value_slice = list_array->value_slice(i);
332+
REQUIRE(value_slice->length() == (i % 10));
333+
auto int_array = std::static_pointer_cast<arrow::Int32Array>(value_slice);
334+
for (size_t j = 0; j < value_slice->length(); j++) {
335+
REQUIRE(int_array->Value(j) == j);
336+
}
337+
}
338+
}
325339
return true;
326340
}
327341

328342
bool validateSchema(std::shared_ptr<arrow::Schema> schema)
329343
{
330-
REQUIRE(schema->num_fields() == 9);
344+
REQUIRE(schema->num_fields() == 10);
331345
REQUIRE(schema->field(0)->type()->id() == arrow::float32()->id());
332346
REQUIRE(schema->field(1)->type()->id() == arrow::float32()->id());
333347
REQUIRE(schema->field(2)->type()->id() == arrow::float32()->id());
@@ -337,6 +351,7 @@ bool validateSchema(std::shared_ptr<arrow::Schema> schema)
337351
REQUIRE(schema->field(6)->type()->id() == arrow::fixed_size_list(arrow::int32(), 2)->id());
338352
REQUIRE(schema->field(7)->type()->id() == arrow::boolean()->id());
339353
REQUIRE(schema->field(8)->type()->id() == arrow::fixed_size_list(arrow::boolean(), 2)->id());
354+
REQUIRE(schema->field(9)->type()->id() == arrow::list(arrow::int32())->id());
340355
return true;
341356
}
342357

@@ -390,6 +405,8 @@ TEST_CASE("RootTree2Dataset")
390405
Int_t ev;
391406
bool oneBool;
392407
bool manyBool[2];
408+
int vla[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
409+
int vlaSize = 0;
393410

394411
t->Branch("px", &px, "px/F");
395412
t->Branch("py", &py, "py/F");
@@ -400,6 +417,8 @@ TEST_CASE("RootTree2Dataset")
400417
t->Branch("ij", ij, "ij[2]/I");
401418
t->Branch("bools", &oneBool, "bools/O");
402419
t->Branch("manyBools", &manyBool, "manyBools[2]/O");
420+
t->Branch("vla_size", &vlaSize, "vla_size/I");
421+
t->Branch("vla", vla, "vla[vla_size]/I");
403422
// fill the tree
404423
for (Int_t i = 0; i < 100; i++) {
405424
xyz[0] = 1;
@@ -415,9 +434,11 @@ TEST_CASE("RootTree2Dataset")
415434
oneBool = (i % 3 == 0);
416435
manyBool[0] = (i % 4 == 0);
417436
manyBool[1] = (i % 5 == 0);
437+
vlaSize = i % 10;
418438
t->Fill();
419439
}
420440
}
441+
f->Write();
421442

422443
size_t totalSizeCompressed = 0;
423444
size_t totalSizeUncompressed = 0;
@@ -428,16 +449,7 @@ TEST_CASE("RootTree2Dataset")
428449
auto schemaOpt = format->Inspect(source);
429450
REQUIRE(schemaOpt.ok());
430451
auto schema = *schemaOpt;
431-
REQUIRE(schema->num_fields() == 9);
432-
REQUIRE(schema->field(0)->type()->id() == arrow::float32()->id());
433-
REQUIRE(schema->field(1)->type()->id() == arrow::float32()->id());
434-
REQUIRE(schema->field(2)->type()->id() == arrow::float32()->id());
435-
REQUIRE(schema->field(3)->type()->id() == arrow::float64()->id());
436-
REQUIRE(schema->field(4)->type()->id() == arrow::int32()->id());
437-
REQUIRE(schema->field(5)->type()->id() == arrow::fixed_size_list(arrow::float32(), 3)->id());
438-
REQUIRE(schema->field(6)->type()->id() == arrow::fixed_size_list(arrow::int32(), 2)->id());
439-
REQUIRE(schema->field(7)->type()->id() == arrow::boolean()->id());
440-
REQUIRE(schema->field(8)->type()->id() == arrow::fixed_size_list(arrow::boolean(), 2)->id());
452+
validateSchema(schema);
441453

442454
auto fragment = format->MakeFragment(source, {}, schema);
443455
REQUIRE(fragment.ok());
@@ -448,41 +460,9 @@ TEST_CASE("RootTree2Dataset")
448460
auto batches = (*scanner)();
449461
auto result = batches.result();
450462
REQUIRE(result.ok());
451-
REQUIRE((*result)->columns().size() == 9);
463+
REQUIRE((*result)->columns().size() == 10);
452464
REQUIRE((*result)->num_rows() == 100);
453-
454-
{
455-
auto int_array = std::static_pointer_cast<arrow::Int32Array>((*result)->GetColumnByName("ev"));
456-
for (int64_t j = 0; j < int_array->length(); j++) {
457-
REQUIRE(int_array->Value(j) == j + 1);
458-
}
459-
}
460-
461-
{
462-
auto list_array = std::static_pointer_cast<arrow::FixedSizeListArray>((*result)->GetColumnByName("xyz"));
463-
464-
// Iterate over the FixedSizeListArray
465-
for (int64_t i = 0; i < list_array->length(); i++) {
466-
auto value_slice = list_array->value_slice(i);
467-
auto float_array = std::static_pointer_cast<arrow::FloatArray>(value_slice);
468-
469-
REQUIRE(float_array->Value(0) == 1);
470-
REQUIRE(float_array->Value(1) == 2);
471-
REQUIRE(float_array->Value(2) == i + 1);
472-
}
473-
}
474-
475-
{
476-
auto list_array = std::static_pointer_cast<arrow::FixedSizeListArray>((*result)->GetColumnByName("ij"));
477-
478-
// Iterate over the FixedSizeListArray
479-
for (int64_t i = 0; i < list_array->length(); i++) {
480-
auto value_slice = list_array->value_slice(i);
481-
auto int_array = std::static_pointer_cast<arrow::Int32Array>(value_slice);
482-
REQUIRE(int_array->Value(0) == i);
483-
REQUIRE(int_array->Value(1) == i + 1);
484-
}
485-
}
465+
validateContents(*result);
486466

487467
auto* output = new TMemFile("foo", "RECREATE");
488468
auto outFs = std::make_shared<TFileFileSystem>(output, 0);
@@ -497,31 +477,31 @@ TEST_CASE("RootTree2Dataset")
497477
auto success = writer->get()->Write(*result);
498478
auto rootDestination = std::dynamic_pointer_cast<TDirectoryFileOutputStream>(*destination);
499479

500-
REQUIRE(success.ok());
501-
// Let's read it back...
502-
arrow::dataset::FileSource source2("/DF_3", outFs);
503-
auto newTreeFS = outFs->GetSubFilesystem(source2);
504-
505-
REQUIRE(format->IsSupported(source) == true);
506-
507-
auto schemaOptWritten = format->Inspect(source);
508-
REQUIRE(schemaOptWritten.ok());
509-
auto schemaWritten = *schemaOptWritten;
510-
REQUIRE(validateSchema(schemaWritten));
511-
512-
auto fragmentWritten = format->MakeFragment(source, {}, schema);
513-
REQUIRE(fragmentWritten.ok());
514-
auto optionsWritten = std::make_shared<arrow::dataset::ScanOptions>();
515-
options->dataset_schema = schemaWritten;
516-
auto scannerWritten = format->ScanBatchesAsync(optionsWritten, *fragment);
517-
REQUIRE(scannerWritten.ok());
518-
auto batchesWritten = (*scanner)();
519-
auto resultWritten = batches.result();
520-
REQUIRE(resultWritten.ok());
521-
REQUIRE((*resultWritten)->columns().size() == 9);
522-
REQUIRE((*resultWritten)->num_rows() == 100);
523-
validateContents(*resultWritten);
524-
480+
// Read-back section: after writing the scanned batch through `writer`, this
// section is meant to inspect and scan the written tree and validate it with
// validateSchema() / validateContents().
//
// NOTE(review): as written, the read-back never touches the written output.
// `source2` / `newTreeFS` are created but every subsequent call uses the
// first-pass objects (`source`, `options`, `fragment`, `scanner`, `batches`),
// so the assertions below re-validate the original input. This looks like a
// copy/paste slip — confirm and switch to the `*Written` / `source2` variables.
SECTION("Read tree")
525481
{
482+
REQUIRE(success.ok());
483+
// Let's read it back...
484+
arrow::dataset::FileSource source2("/DF_3", outFs);
485+
// NOTE(review): newTreeFS is not referenced again in this section.
auto newTreeFS = outFs->GetSubFilesystem(source2);
486+
487+
// NOTE(review): checks `source` (the input), not `source2` — confirm.
REQUIRE(format->IsSupported(source) == true);
488+
489+
// NOTE(review): inspects `source` again, so schemaWritten is the schema of
// the input source rather than of the tree behind `source2` — confirm.
auto schemaOptWritten = format->Inspect(source);
490+
REQUIRE(schemaOptWritten.ok());
491+
auto schemaWritten = *schemaOptWritten;
492+
REQUIRE(validateSchema(schemaWritten));
493+
494+
// NOTE(review): fragment is built from `source` and the first-pass `schema`,
// not from `source2` / `schemaWritten` — confirm.
auto fragmentWritten = format->MakeFragment(source, {}, schema);
495+
REQUIRE(fragmentWritten.ok());
496+
auto optionsWritten = std::make_shared<arrow::dataset::ScanOptions>();
497+
// NOTE(review): assigns to `options`, leaving optionsWritten->dataset_schema
// unset; presumably `optionsWritten` was intended — confirm.
options->dataset_schema = schemaWritten;
498+
// NOTE(review): scans `*fragment` (first pass), not `*fragmentWritten`.
auto scannerWritten = format->ScanBatchesAsync(optionsWritten, *fragment);
499+
REQUIRE(scannerWritten.ok());
500+
// NOTE(review): invokes the first-pass `scanner`; `scannerWritten` is only
// checked for ok() and never called.
auto batchesWritten = (*scanner)();
501+
// NOTE(review): takes the result of the first-pass `batches`, so
// `batchesWritten` is unused and resultWritten is the original result.
auto resultWritten = batches.result();
502+
REQUIRE(resultWritten.ok());
503+
REQUIRE((*resultWritten)->columns().size() == 10);
504+
REQUIRE((*resultWritten)->num_rows() == 100);
505+
// Return value of validateContents() is ignored here; its internal REQUIREs
// still fire on mismatch.
validateContents(*resultWritten);
526506
}
527507
}

0 commit comments

Comments
 (0)