33#include <vector>
44#include <fstream>
55#include <rapidjson/document.h>
6+ #include "CCDB/BasicCCDBManager.h"
7+ #include "CCDB/CcdbApi.h"
68
79// Static Ort::Env instance for multiple onnx model loading
810static Ort ::Env global_env (ORT_LOGGING_LEVEL_WARNING , "GlobalEnv ");
@@ -377,6 +379,11 @@ class GenTPCLoopers : public Generator
377379} // namespace eventgen
378380} // namespace o2
379381
382+ // ONNX model files can be local, on AliEn or in the ALICE CCDB.
383+ // For local and alien files it is mandatory to provide the filenames, for the CCDB instead the
384+ // path to the object in the CCDB is sufficient. The model files will be downloaded locally.
385+ // Example of CCDB path: "ccdb:Users/n/name/test"
386+ // Example of alien path: "alien:///alice/cern.ch/user/n/name/test/test.onnx"
380387FairGenerator *
381388 Generator_TPCLoopers (std ::string model_pairs = "tpcloopmodel.onnx" , std ::string model_compton = "tpcloopmodelcompton.onnx" ,
382389 std ::string poisson = "poisson.csv" , std ::string gauss = "gauss.csv" , std ::string scaler_pair = "scaler_pair.json" ,
@@ -390,6 +397,49 @@ FairGenerator *
390397 gauss = gSystem -> ExpandPathName (gauss .c_str ());
391398 scaler_pair = gSystem -> ExpandPathName (scaler_pair .c_str ());
392399 scaler_compton = gSystem -> ExpandPathName (scaler_compton .c_str ());
400+ const std ::array < std ::string , 2 > models = {model_pairs , model_compton };
401+ const std ::array < std ::string , 2 > local_names = {"WGANpair.onnx" , "WGANcompton.onnx" };
402+ const std ::array < bool , 2 > isAlien = {models [0 ].starts_with ("alien://" ), models [1 ].starts_with ("alien://" )};
403+ const std ::array < bool , 2 > isCCDB = {models [0 ].starts_with ("ccdb:" ), models [1 ].starts_with ("ccdb:" )};
404+ if (std ::any_of (isAlien .begin (), isAlien .end (), [](bool v ) { return v ; }))
405+ {
406+ TGrid ::Connect ("alien://" );
407+ for (size_t i = 0 ; i < models .size (); ++ i )
408+ {
409+ if (isAlien [i ] && !TFile ::Cp (models [i ].c_str (), local_names [i ].c_str ()))
410+ {
411+ LOG (fatal ) << "Error: Model file " << models [i ] << " does not exist!" ;
412+ exit (1 );
413+ }
414+ }
415+ }
416+ if (std ::any_of (isCCDB .begin (), isCCDB .end (), [](bool v ) { return v ; }))
417+ {
418+ o2 ::ccdb ::CcdbApi ccdb_api ;
419+ ccdb_api .init ("http://alice-ccdb.cern.ch" );
420+ for (size_t i = 0 ; i < models .size (); ++ i )
421+ {
422+ if (isCCDB [i ])
423+ {
424+ auto model_path = models [i ].substr (5 );
425+ // Treat filename is provided in the CCDB path
426+ auto extension = model_path .find (".onnx" );
427+ if (extension != std ::string ::npos )
428+ {
429+ auto last_slash = model_path .find_last_of ('/' );
430+ model_path = model_path .substr (0 , last_slash );
431+ }
432+ std ::map < std ::string , std ::string > filter ;
433+ if (!ccdb_api .retrieveBlob (model_path , "./" , filter , o2 ::ccdb ::getCurrentTimestamp (), false, local_names [i ].c_str ()))
434+ {
435+ LOG (fatal ) << "Error: issues in retrieving " << model_path << " from CCDB!" ;
436+ exit (1 );
437+ }
438+ }
439+ }
440+ }
441+ model_pairs = isAlien [0 ] || isCCDB [0 ] ? local_names [0 ] : model_pairs ;
442+ model_compton = isAlien [1 ] || isCCDB [1 ] ? local_names [1 ] : model_compton ;
393443 auto generator = new o2 ::eventgen ::GenTPCLoopers (model_pairs , model_compton , poisson , gauss , scaler_pair , scaler_compton );
394444 generator -> SetNLoopers (nloopers_pairs , nloopers_compton );
395445 generator -> SetMultiplier (mult );
0 commit comments