Skip to content

Commit 2beea45

Browse files
committed
[RF] Optimizations to avoid taking logarithms of Gaussian constraints
Large likelihood fits (e.g. in CMS and ATLAS) often contain many Gaussian constraint terms. In the standard evaluation, these incur unnecessary overhead by computing the PDF value via `exp(...)` only to immediately take the logarithm in the constraint term. This commit introduces a log-space optimization that avoids the redundant exp -> log roundtrip. We need this optimization because it makes a difference for CMS, and they have also implemented something like this in the custom CMS Combine likelihood classes. If we want them to gradually move to "vanilla" RooFit, we need at least the same optimizations.
1 parent 18a86f0 commit 2beea45

14 files changed

Lines changed: 138 additions & 13 deletions

File tree

roofit/batchcompute/res/RooBatchCompute.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,12 @@ class Config {
5858
void setCudaStream(CudaInterface::CudaStream *cudaStream) { _cudaStream = cudaStream; }
5959
CudaInterface::CudaStream *cudaStream() const { return _cudaStream; }
6060

61+
bool takeLog() const { return _takeLog; }
62+
void setTakeLog(bool takeLog) { _takeLog = takeLog; }
63+
6164
private:
6265
CudaInterface::CudaStream *_cudaStream = nullptr;
66+
bool _takeLog = false;
6367
};
6468

6569
enum class Architecture {
@@ -90,6 +94,7 @@ enum Computer {
9094
Gamma,
9195
GaussModelExpBasis,
9296
Gaussian,
97+
LogGaussian,
9398
Identity,
9499
Johnson,
95100
Landau,

roofit/batchcompute/src/ComputeFunctions.cxx

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,18 @@ __rooglobal__ void computeGaussian(Batches &batches)
470470
}
471471
}
472472

473+
__rooglobal__ void computeLogGaussian(Batches &batches)
474+
{
475+
auto x = batches.args[0];
476+
auto mean = batches.args[1];
477+
auto sigma = batches.args[2];
478+
for (size_t i = BEGIN; i < batches.nEvents; i += STEP) {
479+
const double arg = x[i] - mean[i];
480+
const double halfBySigmaSq = -0.5 / (sigma[i] * sigma[i]);
481+
batches.output[i] = arg * arg * halfBySigmaSq;
482+
}
483+
}
484+
473485
__rooglobal__ void computeIdentity(Batches &batches)
474486
{
475487
for (size_t i = BEGIN; i < batches.nEvents; i += STEP) {
@@ -938,6 +950,7 @@ std::vector<void (*)(Batches &)> getFunctions()
938950
computeGamma,
939951
computeGaussModelExpBasis,
940952
computeGaussian,
953+
computeLogGaussian,
941954
computeIdentity,
942955
computeJohnson,
943956
computeLandau,

roofit/roofit/inc/RooGaussian.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class RooGaussian : public RooAbsPdf {
5858

5959
double evaluate() const override;
6060
void doEval(RooFit::EvalContext &) const override;
61+
bool canOptimizeLogarithm() const override { return true; }
6162
inline bool canComputeBatchWithCuda() const override { return true; }
6263

6364
private:

roofit/roofit/src/RooGaussian.cxx

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@ Plain Gaussian p.d.f
2727

2828
#include <RooFit/Detail/MathFuncs.h>
2929

30-
#include <vector>
31-
3230

3331
////////////////////////////////////////////////////////////////////////////////
3432

@@ -61,10 +59,22 @@ double RooGaussian::evaluate() const
6159

6260
////////////////////////////////////////////////////////////////////////////////
6361
/// Compute multiple values of Gaussian distribution.
64-
void RooGaussian::doEval(RooFit::EvalContext & ctx) const
62+
void RooGaussian::doEval(RooFit::EvalContext &ctx) const
6563
{
66-
RooBatchCompute::compute(ctx.config(this), RooBatchCompute::Gaussian, ctx.output(),
67-
{ctx.at(x), ctx.at(mean), ctx.at(sigma)});
64+
if (ctx.config(this).takeLog()) {
65+
auto output = ctx.output();
66+
if (output.size() == 1) {
67+
// If the output size is just one, which is common for constraints,
68+
// calling into RooBatchCompute is not worth its overhead.
69+
output[0] = RooFit::Detail::MathFuncs::logGaussian(ctx.at(x)[0], ctx.at(mean)[0], ctx.at(sigma)[0]);
70+
} else {
71+
RooBatchCompute::compute(ctx.config(this), RooBatchCompute::LogGaussian, output,
72+
{ctx.at(x), ctx.at(mean), ctx.at(sigma)});
73+
}
74+
return;
75+
}
76+
RooBatchCompute::compute(ctx.config(this), RooBatchCompute::Gaussian, ctx.output(),
77+
{ctx.at(x), ctx.at(mean), ctx.at(sigma)});
6878
}
6979

7080
////////////////////////////////////////////////////////////////////////////////

roofit/roofitcore/inc/RooAbsArg.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,8 +510,10 @@ class RooAbsArg : public TNamed, public RooPrintable {
510510
return false;
511511
};
512512

513-
virtual bool canComputeBatchWithCuda() const { return false; }
513+
/// Information to expose to the RooFit::Evaluator for optimized evaluation:
514+
virtual bool canOptimizeLogarithm() const { return false; }
514515
virtual bool isReducerNode() const { return false; }
516+
virtual bool canComputeBatchWithCuda() const { return false; }
515517

516518
virtual void applyWeightSquared(bool flag);
517519

roofit/roofitcore/inc/RooFit/Detail/MathFuncs.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,18 @@ double bernstein(double x, double xmin, double xmax, DoubleArray coefs, int nCoe
8383
return result;
8484
}
8585

86-
/// @brief Function to evaluate an un-normalized RooGaussian.
87-
inline double gaussian(double x, double mean, double sigma)
86+
/// Evaluate the logarithm of an un-normalized Gaussian.
87+
inline double logGaussian(double x, double mean, double sigma)
8888
{
8989
const double arg = x - mean;
9090
const double sig = sigma;
91-
return std::exp(-0.5 * arg * arg / (sig * sig));
91+
return -0.5 * arg * arg / (sig * sig);
92+
}
93+
94+
/// @brief Function to evaluate an un-normalized Gaussian.
95+
inline double gaussian(double x, double mean, double sigma)
96+
{
97+
return std::exp(logGaussian(x, mean, sigma));
9298
}
9399

94100
template <typename DoubleArray>

roofit/roofitcore/inc/RooFit/Detail/RooNLLVarNew.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,11 @@ class RooNLLVarNew : public RooAbsReal {
5353

5454
void enableBinOffsetting(bool on = true) { _doBinOffset = on; }
5555

56-
void setSimCount(int simCount) { _simCount = simCount; }
56+
void setSimCount(int simCount)
57+
{
58+
_simCount = simCount;
59+
_logSimCount = std::log(static_cast<double>(simCount));
60+
}
5761

5862
RooAbsPdf const &pdf() const { return *_pdf; }
5963
RooAbsReal const &weightVar() const { return *_weightVar; }
@@ -78,6 +82,7 @@ class RooNLLVarNew : public RooAbsReal {
7882
bool _doOffset = false;
7983
bool _doBinOffset = false;
8084
int _simCount = 1;
85+
double _logSimCount = 0.;
8186
std::string _prefix;
8287
std::vector<double> _binw;
8388
mutable ROOT::Math::KahanSum<double> _offset{0.}; ///<! Offset as KahanSum to avoid loss of precision

roofit/roofitcore/inc/RooFit/Detail/RooNormalizedPdf.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ class RooNormalizedPdf : public RooAbsPdf {
6969
return _pdf->createExpectedEventsFunc(&_normSet);
7070
}
7171

72+
bool canOptimizeLogarithm() const override { return true; }
7273
bool canComputeBatchWithCuda() const override { return true; }
7374

7475
RooAbsPdf const &pdf() const { return *_pdf; }

roofit/roofitcore/inc/RooRealIntegral.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ class RooRealIntegral : public RooAbsReal {
103103
double evaluate() const override ;
104104
bool isValidReal(double value, bool printError=false) const override ;
105105

106+
void doEval(RooFit::EvalContext &) const override;
107+
bool canOptimizeLogarithm() const override { return true; }
108+
106109
bool redirectServersHook(const RooAbsCollection& newServerList,
107110
bool mustReplaceAll, bool nameChange, bool isRecursive) override ;
108111

roofit/roofitcore/src/RooConstraintSum.cxx

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ arguments.
2828

2929

3030
#include "RooConstraintSum.h"
31+
#include "RooBatchCompute.h"
3132
#include "RooAbsData.h"
3233
#include "RooAbsReal.h"
3334
#include "RooAbsPdf.h"
@@ -78,12 +79,17 @@ double RooConstraintSum::evaluate() const
7879
return sum;
7980
}
8081

82+
/// Evaluate with the vectorizing CPU backend.
8183
void RooConstraintSum::doEval(RooFit::EvalContext &ctx) const
8284
{
8385
double sum(0);
8486

8587
for (const auto comp : _set1) {
86-
sum -= std::log(ctx.at(comp)[0]);
88+
// We only need to take the logarithm if the server didn't do it already:
89+
if (!ctx.config(comp).takeLog())
90+
sum -= std::log(ctx.at(comp)[0]);
91+
else
92+
sum -= ctx.at(comp)[0];
8793
}
8894

8995
ctx.output()[0] = sum;

0 commit comments

Comments
 (0)