4040#endif
4141
4242#ifdef GPUCA_HAS_ONNX
43+ #include < CommonUtils/StringUtils.h>
4344#include " GPUTPCNNClusterizerKernels.h"
4445#include " GPUTPCNNClusterizerHost.h"
4546#endif
@@ -612,7 +613,7 @@ int32_t GPUChainTracking::RunTPCClusterizer(bool synchronizeOutput)
612613 }
613614
614615#ifdef GPUCA_HAS_ONNX
615- const GPUSettingsProcessingNNclusterizer& nn_settings = GetProcessingSettings ().nn ;
616+ GPUSettingsProcessingNNclusterizer nn_settings = GetProcessingSettings ().nn ;
616617 GPUTPCNNClusterizerHost nnApplication; // potentially this needs to be GPUTPCNNClusterizerHost nnApplication[NSECTORS]; Technically ONNX ->Run() is threadsafe at inference time since its read-only
617618 if (GetProcessingSettings ().nn .applyNNclusterizer ) {
618619 if (nn_settings.nnLoadFromCCDB ) {
@@ -626,17 +627,35 @@ int32_t GPUChainTracking::RunTPCClusterizer(bool synchronizeOutput)
626627 {" nnCCDBInteractionRate" , std::to_string (nn_settings.nnCCDBInteractionRate )}
627628 };
628629
630+ std::string nnFetchFolder = " " ;
631+ std::vector<std::string> fetchMode = o2::utils::Str::tokenize (nn_settings.nnCCDBFetchMode , ' :' );
629632 std::map<std::string, std::string> networkRetrieval = ccdbSettings;
630633
631- networkRetrieval[" nnCCDBLayerType" ] = nn_settings.nnCCDBClassificationLayerType ;
632- networkRetrieval[" nnCCDBEvalType" ] = " classification_c1" ;
633- networkRetrieval[" outputFile" ] = " net_classification_c1.onnx" ;
634- nnApplication.loadFromCCDB (networkRetrieval);
634+ if (fetchMode[0 ] == " c1" ) {
635+ networkRetrieval[" nnCCDBLayerType" ] = nn_settings.nnCCDBClassificationLayerType ;
636+ networkRetrieval[" nnCCDBEvalType" ] = " classification_c1" ;
637+ networkRetrieval[" outputFile" ] = nnFetchFolder + " net_classification_c1.onnx" ;
638+ nnApplication.loadFromCCDB (networkRetrieval);
639+ } else if (fetchMode[0 ] == " c2" ) {
640+ networkRetrieval[" nnCCDBLayerType" ] = nn_settings.nnCCDBClassificationLayerType ;
641+ networkRetrieval[" nnCCDBEvalType" ] = " classification_c2" ;
642+ networkRetrieval[" outputFile" ] = nnFetchFolder + " net_classification_c2.onnx" ;
643+ nnApplication.loadFromCCDB (networkRetrieval);
644+ }
645+ nn_settings.nnClassificationPath = networkRetrieval[" outputFile" ]; // Setting the proper path from the where the models will be initialized locally
635646
636647 networkRetrieval[" nnCCDBLayerType" ] = nn_settings.nnCCDBRegressionLayerType ;
637648 networkRetrieval[" nnCCDBEvalType" ] = " regression_c1" ;
638- networkRetrieval[" outputFile" ] = " net_regression_c1.onnx" ;
649+ networkRetrieval[" outputFile" ] = nnFetchFolder + " net_regression_c1.onnx" ;
639650 nnApplication.loadFromCCDB (networkRetrieval);
651+ nn_settings.nnRegressionPath = networkRetrieval[" outputFile" ];
652+ if (fetchMode[1 ] == " r2" ) {
653+ networkRetrieval[" nnCCDBLayerType" ] = nn_settings.nnCCDBRegressionLayerType ;
654+ networkRetrieval[" nnCCDBEvalType" ] = " regression_c2" ;
655+ networkRetrieval[" outputFile" ] = nnFetchFolder + " net_regression_c2.onnx" ;
656+ nnApplication.loadFromCCDB (networkRetrieval);
657+ nn_settings.nnRegressionPath += " :" , networkRetrieval[" outputFile" ];
658+ }
640659 }
641660
642661 uint32_t maxClusters = 0 ;
0 commit comments