33#include <vector>
44#include <fstream>
55#include <rapidjson/document.h>
6+ #include "CCDB/CCDBTimeStampUtils.h"
7+ #include "CCDB/CcdbApi.h"
68
79// Static Ort::Env instance for multiple onnx model loading
810static Ort ::Env global_env (ORT_LOGGING_LEVEL_WARNING , "GlobalEnv ");
@@ -377,6 +379,11 @@ class GenTPCLoopers : public Generator
377379} // namespace eventgen
378380} // namespace o2
379381
382+ // ONNX model files can be local, on AliEn or in the ALICE CCDB.
383+ // For local and alien files it is mandatory to provide the filenames, for the CCDB instead the
384+ // path to the object in the CCDB is sufficient. The model files will be downloaded locally.
385+ // Example of CCDB path: "ccdb://Users/n/name/test"
386+ // Example of alien path: "alien:///alice/cern.ch/user/n/name/test/test.onnx"
380387FairGenerator *
381388 Generator_TPCLoopers (std ::string model_pairs = "tpcloopmodel.onnx" , std ::string model_compton = "tpcloopmodelcompton.onnx" ,
382389 std ::string poisson = "poisson.csv" , std ::string gauss = "gauss.csv" , std ::string scaler_pair = "scaler_pair.json" ,
@@ -390,6 +397,55 @@ FairGenerator *
390397 gauss = gSystem -> ExpandPathName (gauss .c_str ());
391398 scaler_pair = gSystem -> ExpandPathName (scaler_pair .c_str ());
392399 scaler_compton = gSystem -> ExpandPathName (scaler_compton .c_str ());
400+ const std ::array < std ::string , 2 > models = {model_pairs , model_compton };
401+ const std ::array < std ::string , 2 > local_names = {"WGANpair.onnx" , "WGANcompton.onnx" };
402+ const std ::array < bool , 2 > isAlien = {models [0 ].starts_with ("alien://" ), models [1 ].starts_with ("alien://" )};
403+ const std ::array < bool , 2 > isCCDB = {models [0 ].starts_with ("ccdb://" ), models [1 ].starts_with ("ccdb://" )};
404+ if (std ::any_of (isAlien .begin (), isAlien .end (), [](bool v ) { return v ; }))
405+ {
406+ if (!gGrid ) {
407+ TGrid ::Connect ("alien://" );
408+ if (!gGrid ) {
409+ LOG (fatal ) << "AliEn connection failed, check token." ;
410+ exit (1 );
411+ }
412+ }
413+ for (size_t i = 0 ; i < models .size (); ++ i )
414+ {
415+ if (isAlien [i ] && !TFile ::Cp (models [i ].c_str (), local_names [i ].c_str ()))
416+ {
417+ LOG (fatal ) << "Error: Model file " << models [i ] << " does not exist!" ;
418+ exit (1 );
419+ }
420+ }
421+ }
422+ if (std ::any_of (isCCDB .begin (), isCCDB .end (), [](bool v ) { return v ; }))
423+ {
424+ o2 ::ccdb ::CcdbApi ccdb_api ;
425+ ccdb_api .init ("http://alice-ccdb.cern.ch" );
426+ for (size_t i = 0 ; i < models .size (); ++ i )
427+ {
428+ if (isCCDB [i ])
429+ {
430+ auto model_path = models [i ].substr (7 ); // Remove "ccdb://"
431+ // Treat filename if provided in the CCDB path
432+ auto extension = model_path .find (".onnx" );
433+ if (extension != std ::string ::npos )
434+ {
435+ auto last_slash = model_path .find_last_of ('/' );
436+ model_path = model_path .substr (0 , last_slash );
437+ }
438+ std ::map < std ::string , std ::string > filter ;
439+ if (!ccdb_api .retrieveBlob (model_path , "./" , filter , o2 ::ccdb ::getCurrentTimestamp (), false, local_names [i ].c_str ()))
440+ {
441+ LOG (fatal ) << "Error: issues in retrieving " << model_path << " from CCDB!" ;
442+ exit (1 );
443+ }
444+ }
445+ }
446+ }
447+ model_pairs = isAlien [0 ] || isCCDB [0 ] ? local_names [0 ] : model_pairs ;
448+ model_compton = isAlien [1 ] || isCCDB [1 ] ? local_names [1 ] : model_compton ;
393449 auto generator = new o2 ::eventgen ::GenTPCLoopers (model_pairs , model_compton , poisson , gauss , scaler_pair , scaler_compton );
394450 generator -> SetNLoopers (nloopers_pairs , nloopers_compton );
395451 generator -> SetMultiplier (mult );
0 commit comments