Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions Framework/Core/src/RootArrowFilesystem.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "Framework/RuntimeError.h"
#include "Framework/Signpost.h"
#include <Rtypes.h>
#include <arrow/array/array_nested.h>
#include <arrow/array/array_primitive.h>
#include <arrow/array/builder_nested.h>
#include <arrow/array/builder_primitive.h>
Expand Down Expand Up @@ -427,7 +428,7 @@ class TTreeFileWriter : public arrow::dataset::FileWriter
O2_SIGNPOST_EVENT_EMIT(root_arrow_fs, sid, "finaliseBasketSize", "Branch %s exists and uses %d bytes per entry.",
branch->GetName(), valueSize);
// This should probably lookup the
auto column = firstBatch->GetColumnByName(branch->GetName());
auto column = firstBatch->GetColumnByName(schema_->field(i)->name());
auto list = std::static_pointer_cast<arrow::ListArray>(column);
O2_SIGNPOST_EVENT_EMIT(root_arrow_fs, sid, "finaliseBasketSize", "Branch %s needed. Associated size branch %s and there are %lli entries of size %d in that list.",
branch->GetName(), sizeBranch->GetName(), list->length(), valueSize);
Expand Down Expand Up @@ -497,8 +498,8 @@ class TTreeFileWriter : public arrow::dataset::FileWriter
} break;
case arrow::Type::LIST: {
valueTypes.push_back(field->type()->field(0)->type());
listSizes.back() = 0; // VLA, we need to calculate it on the fly;
std::string leafList = fmt::format("{}[{}_size]{}", field->name(), field->name(), rootSuffixFromArrow(valueTypes.back()->id()));
listSizes.back() = -1; // VLA, we need to calculate it on the fly;
std::string sizeLeafList = field->name() + "_size/I";
sizesBranches.push_back(treeStream->CreateBranch((field->name() + "_size").c_str(), sizeLeafList.c_str()));
branches.push_back(treeStream->CreateBranch(field->name().c_str(), leafList.c_str()));
Expand Down Expand Up @@ -765,7 +766,7 @@ arrow::Result<arrow::RecordBatchGenerator> TTreeFileFormat::ScanBatchesAsync(
typeSize = fixedSizeList->field(0)->type()->byte_width();
} else if (auto vlaListType = std::dynamic_pointer_cast<arrow::ListType>(physicalField->type())) {
listSize = -1;
typeSize = fixedSizeList->field(0)->type()->byte_width();
typeSize = vlaListType->field(0)->type()->byte_width();
}
if (listSize == -1) {
mSizeBranch = branch->GetTree()->GetBranch((std::string{branch->GetName()} + "_size").c_str());
Expand Down
120 changes: 50 additions & 70 deletions Framework/Core/test/test_Root2ArrowTable.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,26 @@ bool validateContents(std::shared_ptr<arrow::RecordBatch> batch)
REQUIRE(bool_array->Value(1) == (i % 5 == 0));
}
}

{
auto list_array = std::static_pointer_cast<arrow::ListArray>(batch->GetColumnByName("vla"));

REQUIRE(list_array->length() == 100);
for (int64_t i = 0; i < list_array->length(); i++) {
auto value_slice = list_array->value_slice(i);
REQUIRE(value_slice->length() == (i % 10));
auto int_array = std::static_pointer_cast<arrow::Int32Array>(value_slice);
for (size_t j = 0; j < value_slice->length(); j++) {
REQUIRE(int_array->Value(j) == j);
}
}
}
return true;
}

bool validateSchema(std::shared_ptr<arrow::Schema> schema)
{
REQUIRE(schema->num_fields() == 9);
REQUIRE(schema->num_fields() == 10);
REQUIRE(schema->field(0)->type()->id() == arrow::float32()->id());
REQUIRE(schema->field(1)->type()->id() == arrow::float32()->id());
REQUIRE(schema->field(2)->type()->id() == arrow::float32()->id());
Expand All @@ -337,6 +351,7 @@ bool validateSchema(std::shared_ptr<arrow::Schema> schema)
REQUIRE(schema->field(6)->type()->id() == arrow::fixed_size_list(arrow::int32(), 2)->id());
REQUIRE(schema->field(7)->type()->id() == arrow::boolean()->id());
REQUIRE(schema->field(8)->type()->id() == arrow::fixed_size_list(arrow::boolean(), 2)->id());
REQUIRE(schema->field(9)->type()->id() == arrow::list(arrow::int32())->id());
return true;
}

Expand Down Expand Up @@ -390,6 +405,8 @@ TEST_CASE("RootTree2Dataset")
Int_t ev;
bool oneBool;
bool manyBool[2];
int vla[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
int vlaSize = 0;

t->Branch("px", &px, "px/F");
t->Branch("py", &py, "py/F");
Expand All @@ -400,6 +417,8 @@ TEST_CASE("RootTree2Dataset")
t->Branch("ij", ij, "ij[2]/I");
t->Branch("bools", &oneBool, "bools/O");
t->Branch("manyBools", &manyBool, "manyBools[2]/O");
t->Branch("vla_size", &vlaSize, "vla_size/I");
t->Branch("vla", vla, "vla[vla_size]/I");
// fill the tree
for (Int_t i = 0; i < 100; i++) {
xyz[0] = 1;
Expand All @@ -415,9 +434,11 @@ TEST_CASE("RootTree2Dataset")
oneBool = (i % 3 == 0);
manyBool[0] = (i % 4 == 0);
manyBool[1] = (i % 5 == 0);
vlaSize = i % 10;
t->Fill();
}
}
f->Write();

size_t totalSizeCompressed = 0;
size_t totalSizeUncompressed = 0;
Expand All @@ -428,16 +449,7 @@ TEST_CASE("RootTree2Dataset")
auto schemaOpt = format->Inspect(source);
REQUIRE(schemaOpt.ok());
auto schema = *schemaOpt;
REQUIRE(schema->num_fields() == 9);
REQUIRE(schema->field(0)->type()->id() == arrow::float32()->id());
REQUIRE(schema->field(1)->type()->id() == arrow::float32()->id());
REQUIRE(schema->field(2)->type()->id() == arrow::float32()->id());
REQUIRE(schema->field(3)->type()->id() == arrow::float64()->id());
REQUIRE(schema->field(4)->type()->id() == arrow::int32()->id());
REQUIRE(schema->field(5)->type()->id() == arrow::fixed_size_list(arrow::float32(), 3)->id());
REQUIRE(schema->field(6)->type()->id() == arrow::fixed_size_list(arrow::int32(), 2)->id());
REQUIRE(schema->field(7)->type()->id() == arrow::boolean()->id());
REQUIRE(schema->field(8)->type()->id() == arrow::fixed_size_list(arrow::boolean(), 2)->id());
validateSchema(schema);

auto fragment = format->MakeFragment(source, {}, schema);
REQUIRE(fragment.ok());
Expand All @@ -448,41 +460,9 @@ TEST_CASE("RootTree2Dataset")
auto batches = (*scanner)();
auto result = batches.result();
REQUIRE(result.ok());
REQUIRE((*result)->columns().size() == 9);
REQUIRE((*result)->columns().size() == 10);
REQUIRE((*result)->num_rows() == 100);

{
auto int_array = std::static_pointer_cast<arrow::Int32Array>((*result)->GetColumnByName("ev"));
for (int64_t j = 0; j < int_array->length(); j++) {
REQUIRE(int_array->Value(j) == j + 1);
}
}

{
auto list_array = std::static_pointer_cast<arrow::FixedSizeListArray>((*result)->GetColumnByName("xyz"));

// Iterate over the FixedSizeListArray
for (int64_t i = 0; i < list_array->length(); i++) {
auto value_slice = list_array->value_slice(i);
auto float_array = std::static_pointer_cast<arrow::FloatArray>(value_slice);

REQUIRE(float_array->Value(0) == 1);
REQUIRE(float_array->Value(1) == 2);
REQUIRE(float_array->Value(2) == i + 1);
}
}

{
auto list_array = std::static_pointer_cast<arrow::FixedSizeListArray>((*result)->GetColumnByName("ij"));

// Iterate over the FixedSizeListArray
for (int64_t i = 0; i < list_array->length(); i++) {
auto value_slice = list_array->value_slice(i);
auto int_array = std::static_pointer_cast<arrow::Int32Array>(value_slice);
REQUIRE(int_array->Value(0) == i);
REQUIRE(int_array->Value(1) == i + 1);
}
}
validateContents(*result);

auto* output = new TMemFile("foo", "RECREATE");
auto outFs = std::make_shared<TFileFileSystem>(output, 0);
Expand All @@ -497,31 +477,31 @@ TEST_CASE("RootTree2Dataset")
auto success = writer->get()->Write(*result);
auto rootDestination = std::dynamic_pointer_cast<TDirectoryFileOutputStream>(*destination);

REQUIRE(success.ok());
// Let's read it back...
arrow::dataset::FileSource source2("/DF_3", outFs);
auto newTreeFS = outFs->GetSubFilesystem(source2);

REQUIRE(format->IsSupported(source) == true);

auto schemaOptWritten = format->Inspect(source);
REQUIRE(schemaOptWritten.ok());
auto schemaWritten = *schemaOptWritten;
REQUIRE(validateSchema(schemaWritten));

auto fragmentWritten = format->MakeFragment(source, {}, schema);
REQUIRE(fragmentWritten.ok());
auto optionsWritten = std::make_shared<arrow::dataset::ScanOptions>();
options->dataset_schema = schemaWritten;
auto scannerWritten = format->ScanBatchesAsync(optionsWritten, *fragment);
REQUIRE(scannerWritten.ok());
auto batchesWritten = (*scanner)();
auto resultWritten = batches.result();
REQUIRE(resultWritten.ok());
REQUIRE((*resultWritten)->columns().size() == 9);
REQUIRE((*resultWritten)->num_rows() == 100);
validateContents(*resultWritten);

SECTION("Read tree")
{
REQUIRE(success.ok());
// Let's read it back...
arrow::dataset::FileSource source2("/DF_3", outFs);
auto newTreeFS = outFs->GetSubFilesystem(source2);

REQUIRE(format->IsSupported(source) == true);

auto schemaOptWritten = format->Inspect(source);
REQUIRE(schemaOptWritten.ok());
auto schemaWritten = *schemaOptWritten;
REQUIRE(validateSchema(schemaWritten));

auto fragmentWritten = format->MakeFragment(source, {}, schema);
REQUIRE(fragmentWritten.ok());
auto optionsWritten = std::make_shared<arrow::dataset::ScanOptions>();
options->dataset_schema = schemaWritten;
auto scannerWritten = format->ScanBatchesAsync(optionsWritten, *fragment);
REQUIRE(scannerWritten.ok());
auto batchesWritten = (*scanner)();
auto resultWritten = batches.result();
REQUIRE(resultWritten.ok());
REQUIRE((*resultWritten)->columns().size() == 10);
REQUIRE((*resultWritten)->num_rows() == 100);
validateContents(*resultWritten);
}
}