Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions Tools/ML/MlResponse.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,33 @@ class MlResponse
mPaths = std::vector<std::string>(mNModels);
}

/// Configure class instance (import configurables)
/// \param binsLimitsVar1 is a vector containing bins limits for a first variable
/// \param binsLimitsVar2 is a vector containing bins limits for a second variable
/// \param cuts is a LabeledArray containing selections per bin
/// \param cutDir is a vector telling whether to reject score values greater or smaller than the threshold
/// \param nClasses is the number of classes for each model
void configure2D(const std::vector<double>& binsLimitsVar1, const std::vector<double>& binsLimitsVar2, const o2::framework::LabeledArray<double>& cuts, const std::vector<int>& cutDir, const uint8_t& nClasses)
{
if (cutDir.size() != nClasses) {
LOG(fatal) << "Mismatch between nClasses and cutDir size";
}

mBinsLimits = binsLimitsVar1;
mBinsLimitsVar2 = binsLimitsVar2;
mCuts = cuts;
mCutDir = cutDir;
mNClasses = nClasses;

mNVar1Bins = binsLimitsVar1.size() - 1;
mNVar2Bins = binsLimitsVar2.size() - 1;
mNModels = mNVar1Bins * mNVar2Bins;
mModels = std::vector<o2::ml::OnnxModel>(mNModels);
mPaths = std::vector<std::string>(mNModels);

mUse2DBinning = true;
}

/// Set model paths to CCDB
/// \param onnxFiles is a vector of onnx file names, one for each bin
/// \param ccdbApi is the CCDB API
Expand Down Expand Up @@ -210,16 +237,50 @@ class MlResponse
return true;
}

/// ML selections
/// \param input is the input features
/// \param candVar1 is the first variable value (e.g. pT) used to select which model to use
/// \param candVar2 is the second variable value (e.g. multiplicity) used to select which model to use
/// \param output is a container to be filled with model output
/// \return boolean telling if model predictions pass the cuts
template <typename T1, typename T2, typename T3>
bool isSelectedMl(T1& input, const T2& candVar1, const T3& candVar2, std::vector<TypeOutputScore>& output)
{
if (!mUse2DBinning) {
LOG(fatal) << "2D ML selection called on a class not configured for 2D bins";
}
int nModel = findBin2D(candVar1, candVar2);
output = getModelOutput(input, nModel);
uint8_t iClass{0};
for (const auto& outputValue : output) {
uint8_t dir = mCutDir.at(iClass);
if (dir != o2::cuts_ml::CutDirection::CutNot) {
if (dir == o2::cuts_ml::CutDirection::CutGreater && outputValue > mCuts.get(nModel, iClass)) {
return false;
}
if (dir == o2::cuts_ml::CutDirection::CutSmaller && outputValue < mCuts.get(nModel, iClass)) {
return false;
}
}
++iClass;
}
return true;
}

protected:
std::vector<o2::ml::OnnxModel> mModels; // OnnxModel objects, one for each bin
uint8_t mNModels = 1; // number of bins
uint8_t mNClasses = 3; // number of model classes
std::vector<double> mBinsLimits = {}; // bin limits of the variable (e.g. pT) used to select which model to use
std::vector<double> mBinsLimitsVar2 = {}; // bin limits of a second variable (e.g. multiplicity) used to select which model to use (not used in this base class)
std::vector<std::string> mPaths = {""}; // paths to the models, one for each bin
std::vector<int> mCutDir = {}; // direction of the cuts on the model scores (no cut is also supported)
o2::framework::LabeledArray<double> mCuts = {}; // array of cut values to apply on the model scores
std::map<std::string, uint8_t> mAvailableInputFeatures; // map of available input features
std::vector<uint8_t> mCachedIndices; // vector of index correspondance between configurables and available input features
uint8_t mNVar1Bins = 1; // number of bins of the first variable (e.g. pT) used to select which model to use
uint8_t mNVar2Bins = 1; // number of bins of the second variable (e.g. multiplicity) used to select which model to use
bool mUse2DBinning = false; // switch to enable/disable 2D binning

virtual void setAvailableInputFeatures() { return; } // method to fill the map of available input features

Expand All @@ -239,6 +300,34 @@ class MlResponse
}
return std::distance(mBinsLimits.begin(), std::upper_bound(mBinsLimits.begin(), mBinsLimits.end(), value)) - 1;
}

/// Finds matching bin in mBinsLimits
/// \param value1 e.g. pT
/// \param value2 e.g. multiplicity
/// \return index of the matching bin, used to access mModels
/// \note Accounts for the offset due to mBinsLimits storing bin limits (same convention as needed to configure a histogram axis)
template <typename T1, typename T2>
int findBin2D(T1 const& value1, T2 const& value2)
{
if (!mUse2DBinning) {
LOG(fatal) << "2D ML selection called on a class not configured for 2D bins";
}
if (value1 < mBinsLimits.front()) {
return -1;
}
if (value1 >= mBinsLimits.back()) {
return -1;
}
if (value2 < mBinsLimitsVar2.front()) {
return -1;
}
if (value2 >= mBinsLimitsVar2.back()) {
return -1;
}
int bin1 = std::distance(mBinsLimits.begin(), std::upper_bound(mBinsLimits.begin(), mBinsLimits.end(), value1)) - 1;
int bin2 = std::distance(mBinsLimitsVar2.begin(), std::upper_bound(mBinsLimitsVar2.begin(), mBinsLimitsVar2.end(), value2)) - 1;
return bin2 * mNVar1Bins + bin1;
}
};

} // namespace analysis
Expand Down
Loading