-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathRNNClassification.C
More file actions
130 lines (102 loc) · 5.02 KB
/
RNNClassification.C
File metadata and controls
130 lines (102 loc) · 5.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>
#include "TChain.h"
#include "TCut.h"
#include "TFile.h"
#include "TTree.h"
#include "TString.h"
#include "TObjString.h"
#include "TSystem.h"
#include "TROOT.h"
#include "TMVA/Factory.h"
#include "TMVA/DataLoader.h"
#include "TMVA/Tools.h"
#include "TMVA/TMVAGui.h"
void RNNClassification()
{
TMVA::Tools::Instance();
TFile *input(0);
TString fname = "./dataset/tmva_class_example.root";
if (!gSystem->AccessPathName( fname )) {
input = TFile::Open( fname ); // check if file in local directory exists
}
else {
//TFile::SetCacheFileDir(".");
//input = TFile::Open("http://root.cern.ch/files/tmva_class_example.root", "CACHEREAD");
}
if (!input) {
std::cout << "ERROR: could not open data file" << std::endl;
exit(1);
}
std::cout << "--- RNNClassification : Using input file: " << input->GetName() << std::endl;
// Create a ROOT output file where TMVA will store ntuples, histograms, etc.
TString outfileName( "TMVA_DNN.root" );
TFile* outputFile = TFile::Open( outfileName, "RECREATE" );
// Creating the factory object
TMVA::Factory *factory = new TMVA::Factory( "TMVAClassification", outputFile,
"!V:!Silent:Color:DrawProgressBar:Transformations=I;D;P;G,D:AnalysisType=Classification" );
TMVA::DataLoader *dataloader=new TMVA::DataLoader("dataset");
TTree *signalTree = (TTree*)input->Get("TreeS");
TTree *background = (TTree*)input->Get("TreeB");
dataloader->AddVariable( "myvar1 := var1+var2", 'F' );
dataloader->AddVariable( "myvar2 := var1-var2", "Expression 2", "", 'F' );
dataloader->AddVariable( "var3", "Variable 3", "units", 'F' );
dataloader->AddVariable( "var4", "Variable 4", "units", 'F' );
// You can add so-called "Spectator variables", which are not used in the MVA training,
// but will appear in the final "TestTree" produced by TMVA. This TestTree will contain the
// input variables, the response values of all trained MVAs, and the spectator variables
dataloader->AddSpectator( "spec1 := var1*2", "Spectator 1", "units", 'F' );
dataloader->AddSpectator( "spec2 := var1*3", "Spectator 2", "units", 'F' );
// global event weights per tree (see below for setting event-wise weights)
Double_t signalWeight = 1.0;
Double_t backgroundWeight = 1.0;
// You can add an arbitrary number of signal or background trees
dataloader->AddSignalTree ( signalTree, signalWeight );
dataloader->AddBackgroundTree( background, backgroundWeight );
// Input Layout
TString inputLayoutString("InputLayout=1|1|4");
// Batch Layout
TString batchLayoutString("BatchLayout=256|1|4");
// General layout.
TString layoutString ("Layout=RNN|128|4|1|0,RESHAPE|1|1|128|FLAT,DENSE|64|TANH,DENSE|2|LINEAR");
// Training strategies.
TString training0("LearningRate=1e-1,Momentum=0.9,Repetitions=1,"
"ConvergenceSteps=20,BatchSize=256,TestRepetitions=10,"
"WeightDecay=1e-4,Regularization=L2,"
"DropConfig=0.0+0.5+0.5+0.5, Multithreading=True");
TString training1("LearningRate=1e-2,Momentum=0.9,Repetitions=1,"
"ConvergenceSteps=20,BatchSize=256,TestRepetitions=10,"
"WeightDecay=1e-4,Regularization=L2,"
"DropConfig=0.0+0.0+0.0+0.0, Multithreading=True");
TString training2("LearningRate=1e-3,Momentum=0.0,Repetitions=1,"
"ConvergenceSteps=20,BatchSize=256,TestRepetitions=10,"
"WeightDecay=1e-4,Regularization=L2,"
"DropConfig=0.0+0.0+0.0+0.0, Multithreading=True");
TString trainingStrategyString ("TrainingStrategy=");
trainingStrategyString += training0; // + "|" + training1 + "|" + training2;
// General Options.
TString rnnOptions ("!H:V:ErrorStrategy=CROSSENTROPY:VarTransform=N:"
"WeightInitialization=XAVIERUNIFORM");
rnnOptions.Append(":"); rnnOptions.Append(inputLayoutString);
rnnOptions.Append(":"); rnnOptions.Append(batchLayoutString);
rnnOptions.Append(":"); rnnOptions.Append(layoutString);
rnnOptions.Append(":"); rnnOptions.Append(trainingStrategyString);
rnnOptions.Append(":Architecture=CPU");
TCut mycuts = ""; // for example: TCut mycuts = "abs(var1)<0.5 && abs(var2-0.5)<1";
TCut mycutb = ""; // for example: TCut mycutb = "abs(var1)<0.5";
dataloader->PrepareTrainingAndTestTree( mycuts, mycutb, "nTrain_Signal=1000:nTrain_Background=1000:SplitMode=Random:NormMode=NumEvents:!V" );
factory->BookMethod(dataloader, TMVA::Types::kDL, "DNN_CPU", rnnOptions);
// Train MVAs using the set of training events
factory->TrainAllMethods();
// Save the output
outputFile->Close();
std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
std::cout << "==> TMVAClassification is done!" << std::endl;
delete factory;
delete dataloader;
}
// Entry point for building this macro as a standalone executable: runs the
// training once and reports success. Command-line arguments are ignored.
int main(int argc, char ** argv)
{
   // Explicitly mark the parameters as intentionally unused.
   (void)argc;
   (void)argv;

   RNNClassification();

   return 0;
}