Skip to content

Commit 34f73d2

Browse files
ikantakIsabel Kantak
andauthored
[Tools] Add ML model selection depending on two variables (#14238)
Co-authored-by: Isabel Kantak <kantak@physi.uni-heidelberg.de>
1 parent 9e55be6 commit 34f73d2

File tree

1 file changed

+89
-0
lines changed

1 file changed

+89
-0
lines changed

Tools/ML/MlResponse.h

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,33 @@ class MlResponse
7373
mPaths = std::vector<std::string>(mNModels);
7474
}
7575

76+
/// Configure class instance (import configurables)
77+
/// \param binsLimitsVar1 is a vector containing bins limits for a first variable
78+
/// \param binsLimitsVar2 is a vector containing bins limits for a second variable
79+
/// \param cuts is a LabeledArray containing selections per bin
80+
/// \param cutDir is a vector telling whether to reject score values greater or smaller than the threshold
81+
/// \param nClasses is the number of classes for each model
82+
void configure2D(const std::vector<double>& binsLimitsVar1, const std::vector<double>& binsLimitsVar2, const o2::framework::LabeledArray<double>& cuts, const std::vector<int>& cutDir, const uint8_t& nClasses)
83+
{
84+
if (cutDir.size() != nClasses) {
85+
LOG(fatal) << "Mismatch between nClasses and cutDir size";
86+
}
87+
88+
mBinsLimits = binsLimitsVar1;
89+
mBinsLimitsVar2 = binsLimitsVar2;
90+
mCuts = cuts;
91+
mCutDir = cutDir;
92+
mNClasses = nClasses;
93+
94+
mNVar1Bins = binsLimitsVar1.size() - 1;
95+
mNVar2Bins = binsLimitsVar2.size() - 1;
96+
mNModels = mNVar1Bins * mNVar2Bins;
97+
mModels = std::vector<o2::ml::OnnxModel>(mNModels);
98+
mPaths = std::vector<std::string>(mNModels);
99+
100+
mUse2DBinning = true;
101+
}
102+
76103
/// Set model paths to CCDB
77104
/// \param onnxFiles is a vector of onnx file names, one for each bin
78105
/// \param ccdbApi is the CCDB API
@@ -210,16 +237,50 @@ class MlResponse
210237
return true;
211238
}
212239

240+
/// ML selections
241+
/// \param input is the input features
242+
/// \param candVar1 is the first variable value (e.g. pT) used to select which model to use
243+
/// \param candVar2 is the second variable value (e.g. multiplicity) used to select which model to use
244+
/// \param output is a container to be filled with model output
245+
/// \return boolean telling if model predictions pass the cuts
246+
template <typename T1, typename T2, typename T3>
247+
bool isSelectedMl(T1& input, const T2& candVar1, const T3& candVar2, std::vector<TypeOutputScore>& output)
248+
{
249+
if (!mUse2DBinning) {
250+
LOG(fatal) << "2D ML selection called on a class not configured for 2D bins";
251+
}
252+
int nModel = findBin2D(candVar1, candVar2);
253+
output = getModelOutput(input, nModel);
254+
uint8_t iClass{0};
255+
for (const auto& outputValue : output) {
256+
uint8_t dir = mCutDir.at(iClass);
257+
if (dir != o2::cuts_ml::CutDirection::CutNot) {
258+
if (dir == o2::cuts_ml::CutDirection::CutGreater && outputValue > mCuts.get(nModel, iClass)) {
259+
return false;
260+
}
261+
if (dir == o2::cuts_ml::CutDirection::CutSmaller && outputValue < mCuts.get(nModel, iClass)) {
262+
return false;
263+
}
264+
}
265+
++iClass;
266+
}
267+
return true;
268+
}
269+
213270
protected:
214271
std::vector<o2::ml::OnnxModel> mModels; // OnnxModel objects, one for each bin
215272
uint8_t mNModels = 1; // number of bins
216273
uint8_t mNClasses = 3; // number of model classes
217274
std::vector<double> mBinsLimits = {}; // bin limits of the variable (e.g. pT) used to select which model to use
275+
std::vector<double> mBinsLimitsVar2 = {}; // bin limits of a second variable (e.g. multiplicity) used to select which model to use (not used in this base class)
218276
std::vector<std::string> mPaths = {""}; // paths to the models, one for each bin
219277
std::vector<int> mCutDir = {}; // direction of the cuts on the model scores (no cut is also supported)
220278
o2::framework::LabeledArray<double> mCuts = {}; // array of cut values to apply on the model scores
221279
std::map<std::string, uint8_t> mAvailableInputFeatures; // map of available input features
222280
std::vector<uint8_t> mCachedIndices; // vector of index correspondance between configurables and available input features
281+
uint8_t mNVar1Bins = 1; // number of bins of the first variable (e.g. pT) used to select which model to use
282+
uint8_t mNVar2Bins = 1; // number of bins of the second variable (e.g. multiplicity) used to select which model to use
283+
bool mUse2DBinning = false; // switch to enable/disable 2D binning
223284

224285
virtual void setAvailableInputFeatures() { return; } // method to fill the map of available input features
225286

@@ -239,6 +300,34 @@ class MlResponse
239300
}
240301
return std::distance(mBinsLimits.begin(), std::upper_bound(mBinsLimits.begin(), mBinsLimits.end(), value)) - 1;
241302
}
303+
304+
/// Finds matching bin in mBinsLimits
305+
/// \param value1 e.g. pT
306+
/// \param value2 e.g. multiplicity
307+
/// \return index of the matching bin, used to access mModels
308+
/// \note Accounts for the offset due to mBinsLimits storing bin limits (same convention as needed to configure a histogram axis)
309+
template <typename T1, typename T2>
310+
int findBin2D(T1 const& value1, T2 const& value2)
311+
{
312+
if (!mUse2DBinning) {
313+
LOG(fatal) << "2D ML selection called on a class not configured for 2D bins";
314+
}
315+
if (value1 < mBinsLimits.front()) {
316+
return -1;
317+
}
318+
if (value1 >= mBinsLimits.back()) {
319+
return -1;
320+
}
321+
if (value2 < mBinsLimitsVar2.front()) {
322+
return -1;
323+
}
324+
if (value2 >= mBinsLimitsVar2.back()) {
325+
return -1;
326+
}
327+
int bin1 = std::distance(mBinsLimits.begin(), std::upper_bound(mBinsLimits.begin(), mBinsLimits.end(), value1)) - 1;
328+
int bin2 = std::distance(mBinsLimitsVar2.begin(), std::upper_bound(mBinsLimitsVar2.begin(), mBinsLimitsVar2.end(), value2)) - 1;
329+
return bin2 * mNVar1Bins + bin1;
330+
}
242331
};
243332

244333
} // namespace analysis

0 commit comments

Comments
 (0)