Skip to content

Commit a93770a

Browse files
committed
improve schema serialization
1 parent 2a44a1e commit a93770a

File tree

8 files changed

+171
-14
lines changed

8 files changed

+171
-14
lines changed

Framework/Core/include/Framework/ASoA.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ struct TableMetadata {
256256

257257
static std::shared_ptr<arrow::Schema> getSchema()
258258
{
259-
return std::make_shared<arrow::Schema>([]<typename... C>(framework::pack<C...>&& p){ return o2::soa::createFieldsFromColumns(p); }(columns{}));
259+
return std::make_shared<arrow::Schema>([]<typename... C>(framework::pack<C...>&& p){ return o2::soa::createFieldsFromColumns(p); }(persistent_columns_t{}));
260260
}
261261
};
262262

@@ -690,7 +690,7 @@ struct Column {
690690

691691
static auto asArrowField()
692692
{
693-
return std::make_shared<arrow::Field>(inherited_t::mLabel, framework::expressions::concreteArrowType(framework::expressions::selectArrowType<type>()));
693+
return std::make_shared<arrow::Field>(inherited_t::mLabel, soa::asArrowDataType<type>());
694694
}
695695

696696
/// FIXME: rather than keeping this public we should have a protected

Framework/Core/include/Framework/AnalysisHelpers.h

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
namespace o2::framework
3030
{
3131
std::string serializeProjectors(std::vector<framework::expressions::Projector>& projectors);
32-
std::string serializeSchema(std::shared_ptr<arrow::Schema>& schema);
32+
std::string serializeSchema(std::shared_ptr<arrow::Schema> schema);
3333
} // namespace o2::framework
3434

3535
namespace o2::soa
@@ -44,6 +44,16 @@ constexpr auto tableRef2ConfigParamSpec()
4444
{"\"\""}};
4545
}
4646

47+
template <TableRef R>
48+
constexpr auto tableRef2Schema()
49+
{
50+
return o2::framework::ConfigParamSpec{
51+
std::string{"input-schema:"} + o2::aod::label<R>(),
52+
framework::VariantType::String,
53+
framework::serializeSchema(o2::aod::MetadataTrait<o2::aod::Hash<R.desc_hash>>::metadata::getSchema()),
54+
{"\"\""}};
55+
}
56+
4757
namespace
4858
{
4959
template <soa::with_sources T>
@@ -56,6 +66,16 @@ inline constexpr auto getSources()
5666
}.template operator()<T::sources.size(), T::sources>();
5767
}
5868

69+
template <soa::with_sources T>
70+
inline constexpr auto getSourceSchemas()
71+
{
72+
return []<size_t N, std::array<soa::TableRef, N> refs>() {
73+
return []<size_t... Is>(std::index_sequence<Is...>) {
74+
return std::vector{soa::tableRef2Schema<refs[Is]>()...};
75+
}(std::make_index_sequence<N>());
76+
}.template operator()<T::sources.size(), T::sources>();
77+
}
78+
5979
template <soa::with_ccdb_urls T>
6080
inline constexpr auto getCCDBUrls()
6181
{
@@ -73,11 +93,19 @@ template <soa::with_sources T>
7393
constexpr auto getInputMetadata() -> std::vector<framework::ConfigParamSpec>
7494
{
7595
std::vector<framework::ConfigParamSpec> inputMetadata;
96+
7697
auto inputSources = getSources<T>();
7798
std::sort(inputSources.begin(), inputSources.end(), [](framework::ConfigParamSpec const& a, framework::ConfigParamSpec const& b) { return a.name < b.name; });
7899
auto last = std::unique(inputSources.begin(), inputSources.end(), [](framework::ConfigParamSpec const& a, framework::ConfigParamSpec const& b) { return a.name == b.name; });
79100
inputSources.erase(last, inputSources.end());
80101
inputMetadata.insert(inputMetadata.end(), inputSources.begin(), inputSources.end());
102+
103+
auto inputSchemas = getSourceSchemas<T>();
104+
std::sort(inputSchemas.begin(), inputSchemas.end(), [](framework::ConfigParamSpec const& a, framework::ConfigParamSpec const& b) { return a.name < b.name; });
105+
last = std::unique(inputSchemas.begin(), inputSchemas.end(), [](framework::ConfigParamSpec const& a, framework::ConfigParamSpec const& b) { return a.name == b.name; });
106+
inputSchemas.erase(last, inputSchemas.end());
107+
inputMetadata.insert(inputMetadata.end(), inputSchemas.begin(), inputSchemas.end());
108+
81109
return inputMetadata;
82110
}
83111

@@ -115,11 +143,8 @@ constexpr auto getExpressionMetadata() -> std::vector<framework::ConfigParamSpec
115143
return result;
116144
}(expression_pack_t{});
117145

118-
auto schema = std::make_shared<arrow::Schema>(o2::soa::createFieldsFromColumns(expression_pack_t{}));
119-
120146
auto json = framework::serializeProjectors(projectors);
121-
return {framework::ConfigParamSpec{"projectors", framework::VariantType::String, json, {"\"\""}},
122-
framework::ConfigParamSpec{"schema", framework::VariantType::String, framework::serializeSchema(schema), {"\"\""}}};
147+
return {framework::ConfigParamSpec{"projectors", framework::VariantType::String, json, {"\"\""}}};
123148
}
124149

125150
template <typename T>
@@ -141,6 +166,9 @@ constexpr auto tableRef2InputSpec()
141166
metadata.insert(metadata.end(), ccdbMetadata.begin(), ccdbMetadata.end());
142167
auto p = getExpressionMetadata<typename o2::aod::MetadataTrait<o2::aod::Hash<R.desc_hash>>::metadata>();
143168
metadata.insert(metadata.end(), p.begin(), p.end());
169+
if constexpr(!soa::with_ccdb_urls<typename o2::aod::MetadataTrait<o2::aod::Hash<R.desc_hash>>::metadata>) {
170+
metadata.emplace_back(framework::ConfigParamSpec{"schema", framework::VariantType::String, framework::serializeSchema(o2::aod::MetadataTrait<o2::aod::Hash<R.desc_hash>>::metadata::getSchema()), {"\"\""}});
171+
}
144172

145173
return framework::InputSpec{
146174
o2::aod::label<R>(),

Framework/Core/include/Framework/ArrowTypes.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#ifndef O2_FRAMEWORK_ARROWTYPES_H
1313
#define O2_FRAMEWORK_ARROWTYPES_H
14+
#include "Framework/Traits.h"
1415
#include "arrow/type_fwd.h"
1516
#include <span>
1617

@@ -117,5 +118,54 @@ template <typename T>
117118
using arrow_array_for_t = typename arrow_array_for<T>::type;
118119
template <typename T>
119120
using value_for_t = typename arrow_array_for<T>::value_type;
121+
122+
template <class Array>
123+
using array_element_t = std::decay_t<decltype(std::declval<Array>()[0])>;
124+
125+
template <typename T>
126+
std::shared_ptr<arrow::DataType> asArrowDataType(int list_size = 1)
127+
{
128+
auto typeGenerator = [](std::shared_ptr<arrow::DataType> const& type, int list_size) -> std::shared_ptr<arrow::DataType> {
129+
switch (list_size) {
130+
case -1:
131+
return arrow::list(type);
132+
case 1:
133+
return std::move(type);
134+
default:
135+
return arrow::fixed_size_list(type, list_size);
136+
}
137+
};
138+
139+
if constexpr (std::is_arithmetic_v<T>) {
140+
if constexpr (std::same_as<T, bool>) {
141+
return typeGenerator(arrow::boolean(), list_size);
142+
} else if constexpr (std::same_as<T, uint8_t>) {
143+
return typeGenerator(arrow::uint8(), list_size);
144+
} else if constexpr (std::same_as<T, uint16_t>) {
145+
return typeGenerator(arrow::uint16(), list_size);
146+
} else if constexpr (std::same_as<T, uint32_t>) {
147+
return typeGenerator(arrow::uint32(), list_size);
148+
} else if constexpr (std::same_as<T, uint64_t>) {
149+
return typeGenerator(arrow::uint64(), list_size);
150+
} else if constexpr (std::same_as<T, int8_t>) {
151+
return typeGenerator(arrow::int8(), list_size);
152+
} else if constexpr (std::same_as<T, int16_t>) {
153+
return typeGenerator(arrow::int16(), list_size);
154+
} else if constexpr (std::same_as<T, int32_t>) {
155+
return typeGenerator(arrow::int32(), list_size);
156+
} else if constexpr (std::same_as<T, int64_t>) {
157+
return typeGenerator(arrow::int64(), list_size);
158+
} else if constexpr (std::same_as<T, float>) {
159+
return typeGenerator(arrow::float32(), list_size);
160+
} else if constexpr (std::same_as<T, double>) {
161+
return typeGenerator(arrow::float64(), list_size);
162+
}
163+
} else if constexpr (std::is_bounded_array_v<T>) {
164+
return asArrowDataType<array_element_t<T>>(std::extent_v<T>);
165+
} else if constexpr (o2::framework::is_specialization_v<T, std::vector>) {
166+
return asArrowDataType<typename T::value_type>(-1);
167+
}
168+
return nullptr;
169+
}
120170
} // namespace o2::soa
121171
#endif // O2_FRAMEWORK_ARROWTYPES_H

Framework/Core/src/AODReaderHelpers.cxx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,11 @@ struct Spawnable {
195195

196196
std::vector<std::shared_ptr<arrow::Schema>> schemas;
197197
for (auto& i : spec.metadata) {
198-
if (i.name.starts_with("input:")) {
199-
labels.emplace_back(i.name.substr(6));
198+
if (i.name.starts_with("input-schema:")) {
199+
labels.emplace_back(i.name.substr(13));
200200
iws.clear();
201-
iws.str(i.defaultValue.get<std::string>());
201+
auto json = i.defaultValue.get<std::string>();
202+
iws.str(json);
202203
schemas.emplace_back(ArrowJSONHelpers::read(iws));
203204
}
204205
}

Framework/Core/src/AnalysisHelpers.cxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ std::string serializeProjectors(std::vector<framework::expressions::Projector>&
3535
return osm.str();
3636
}
3737

38-
std::string serializeSchema(std::shared_ptr<arrow::Schema>& schema)
38+
std::string serializeSchema(std::shared_ptr<arrow::Schema> schema)
3939
{
4040
std::stringstream osm;
4141
ArrowJSONHelpers::write(osm, schema);

Framework/Core/src/AnalysisSupportHelpers.cxx

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,6 @@ void AnalysisSupportHelpers::addMissingOutputsToAnalysisCCDBFetcher(
219219
// FIXME: good enough for now...
220220
for (auto& i : input.metadata) {
221221
if ((i.type == VariantType::String) && (i.name.find("input:") != std::string::npos)) {
222-
auto value = i.defaultValue.get<std::string>();
223222
auto spec = DataSpecUtils::fromMetadataString(i.defaultValue.get<std::string>());
224223
auto j = std::find_if(publisher.inputs.begin(), publisher.inputs.end(), [&](auto x) { return x.binding == spec.binding; });
225224
if (j == publisher.inputs.end()) {

Framework/Core/src/ExpressionJSONHelpers.cxx

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,18 @@ void o2::framework::ExpressionJSONHelpers::write(std::ostream& o, std::vector<o2
637637

638638
namespace
639639
{
640+
std::shared_ptr<arrow::DataType> arrowDataTypeFromId(atype::type type, int list_size = 1, atype::type element = atype::NA)
641+
{
642+
switch (list_size) {
643+
case -1:
644+
return arrow::list(expressions::concreteArrowType(element));
645+
case 1:
646+
return expressions::concreteArrowType(type);
647+
default:
648+
return arrow::fixed_size_list(expressions::concreteArrowType(element), list_size);
649+
}
650+
}
651+
640652
struct SchemaReader : public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, SchemaReader> {
641653
using Ch = rapidjson::UTF8<>::Ch;
642654
using SizeType = rapidjson::SizeType;
@@ -658,6 +670,8 @@ struct SchemaReader : public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, Sch
658670

659671
std::string name;
660672
atype::type type;
673+
atype::type element;
674+
int list_size = 1;
661675

662676
SchemaReader()
663677
{
@@ -706,6 +720,12 @@ struct SchemaReader : public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, Sch
706720
if (currentKey.compare("type") == 0) {
707721
return true;
708722
}
723+
if (currentKey.compare("size") == 0) {
724+
return true;
725+
}
726+
if (currentKey.compare("element") == 0) {
727+
return true;
728+
}
709729
}
710730

711731
states.push(State::IN_ERROR);
@@ -721,6 +741,9 @@ struct SchemaReader : public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, Sch
721741

722742
if (states.top() == State::IN_LIST) {
723743
states.push(State::IN_FIELD);
744+
list_size = 1;
745+
element = atype::NA;
746+
type = atype::NA;
724747
return true;
725748
}
726749

@@ -734,7 +757,7 @@ struct SchemaReader : public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, Sch
734757
if (states.top() == State::IN_FIELD) {
735758
states.pop();
736759
// add a field
737-
fields.emplace_back(std::make_shared<arrow::Field>(name, expressions::concreteArrowType(type)));
760+
fields.emplace_back(std::make_shared<arrow::Field>(name, arrowDataTypeFromId(type, list_size, element)));
738761
return true;
739762
}
740763

@@ -754,6 +777,14 @@ struct SchemaReader : public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, Sch
754777
type = (atype::type)i;
755778
return true;
756779
}
780+
if (currentKey.compare("element") == 0) {
781+
element = (atype::type)i;
782+
return true;
783+
}
784+
if (currentKey.compare("size") == 0) {
785+
list_size = i;
786+
return true;
787+
}
757788
}
758789

759790
states.push(State::IN_ERROR);
@@ -777,6 +808,10 @@ struct SchemaReader : public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, Sch
777808
bool Int(int i)
778809
{
779810
debug << "Int(" << i << ")" << std::endl;
811+
if (states.top() == State::IN_FIELD && currentKey.compare("size") == 0) {
812+
list_size = i;
813+
return true;
814+
}
780815
return Uint(i);
781816
}
782817
};
@@ -791,7 +826,7 @@ std::shared_ptr<arrow::Schema> o2::framework::ArrowJSONHelpers::read(std::istrea
791826
bool ok = reader.Parse(isw, sreader);
792827

793828
if (!ok) {
794-
throw framework::runtime_error_f("Cannot parse serialized Expression, error: %s at offset: %d", rapidjson::GetParseError_En(reader.GetParseErrorCode()), reader.GetErrorOffset());
829+
throw framework::runtime_error_f("Cannot parse serialized Schema, error: %s at offset: %d", rapidjson::GetParseError_En(reader.GetParseErrorCode()), reader.GetErrorOffset());
795830
}
796831
return sreader.schema;
797832
}
@@ -804,6 +839,20 @@ void writeSchema(rapidjson::Writer<rapidjson::OStreamWrapper>& w, arrow::Schema*
804839
w.StartObject();
805840
w.Key("name");
806841
w.String(f->name().c_str());
842+
auto fixedList = dynamic_cast<arrow::FixedSizeListType*>(f->type().get());
843+
if (fixedList != nullptr) {
844+
w.Key("size");
845+
w.Int(fixedList->list_size());
846+
w.Key("element");
847+
w.Int(fixedList->field(0)->type()->id());
848+
}
849+
auto varList = dynamic_cast<arrow::ListType*>(f->type().get());
850+
if (varList != nullptr) {
851+
w.Key("size");
852+
w.Int(-1);
853+
w.Key("element");
854+
w.Int(varList->field(0)->type()->id());
855+
}
807856
w.Key("type");
808857
w.Int(f->type()->id());
809858
w.EndObject();

Framework/Core/test/test_Expressions.cxx

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,4 +454,34 @@ TEST_CASE("TestExpressionSerialization")
454454
ism.str(osm.str());
455455
auto newSchemap = ArrowJSONHelpers::read(ism);
456456
REQUIRE(schemap->ToString() == newSchemap->ToString());
457+
458+
osm.clear();
459+
osm.str("");
460+
ArrowJSONHelpers::write(osm, schemap1);
461+
462+
ism.clear();
463+
ism.str(osm.str());
464+
auto newSchemap1 = ArrowJSONHelpers::read(ism);
465+
REQUIRE(schemap1->ToString() == newSchemap1->ToString());
466+
467+
osm.clear();
468+
osm.str("");
469+
auto realisticSchema = std::make_shared<arrow::Schema>(o2::soa::createFieldsFromColumns(o2::aod::MetadataTrait<o2::aod::Hash<"HMPID/1"_h>>::metadata::persistent_columns_t{}));
470+
ArrowJSONHelpers::write(osm, realisticSchema);
471+
472+
ism.clear();
473+
ism.str(osm.str());
474+
auto restoredSchema = ArrowJSONHelpers::read(ism);
475+
REQUIRE(realisticSchema->ToString() == restoredSchema->ToString());
476+
477+
osm.clear();
478+
osm.str("");
479+
auto realisticSchema1 = std::make_shared<arrow::Schema>(o2::soa::createFieldsFromColumns(o2::aod::MetadataTrait<o2::aod::Hash<"ZDC/1"_h>>::metadata::persistent_columns_t{}));
480+
ArrowJSONHelpers::write(osm, realisticSchema1);
481+
482+
ism.clear();
483+
ism.str(osm.str());
484+
auto restoredSchema1 = ArrowJSONHelpers::read(ism);
485+
REQUIRE(realisticSchema1->ToString() == restoredSchema1->ToString());
486+
457487
}

0 commit comments

Comments
 (0)