Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tmva/tmva/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(TMVAUtils
TMVA/BatchGenerator/RChunkLoader.hxx
TMVA/BatchGenerator/RChunkConstructor.hxx
TMVA/BatchGenerator/RFlat2DMatrix.hxx
TMVA/BatchGenerator/RFlat2DMatrixOperators.hxx

SOURCES

Expand Down
73 changes: 18 additions & 55 deletions tmva/tmva/inc/TMVA/BatchGenerator/RBatchGenerator.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ private:

std::unique_ptr<RChunkLoader<Args...>> fChunkLoader;
std::unique_ptr<RBatchLoader> fBatchLoader;
std::unique_ptr<RBatchLoader> fTrainingBatchLoader;
std::unique_ptr<RBatchLoader> fValidationBatchLoader;

std::unique_ptr<std::thread> fLoadingThread;

Expand All @@ -88,18 +90,6 @@ private:
std::size_t fNumTrainingChunks;
std::size_t fNumValidationChunks;

std::size_t fLeftoverTrainingBatchSize;
std::size_t fLeftoverValidationBatchSize;

std::size_t fNumFullTrainingBatches;
std::size_t fNumFullValidationBatches;

std::size_t fNumLeftoverTrainingBatches;
std::size_t fNumLeftoverValidationBatches;

std::size_t fNumTrainingBatches;
std::size_t fNumValidationBatches;

// flattened buffers for chunks and temporary tensors (rows * cols)
RFlat2DMatrix fTrainTensor;
RFlat2DMatrix fTrainChunkTensor;
Expand Down Expand Up @@ -141,40 +131,21 @@ public:
fChunkLoader =
std::make_unique<RChunkLoader<Args...>>(f_rdf, fNumEntries, fEntries, fChunkSize, fBlockSize, fValidationSplit,
fCols, vecSizes, vecPadding, fShuffle, fSetSeed);
fBatchLoader = std::make_unique<RBatchLoader>(fBatchSize, fNumChunkCols);
fBatchLoader = std::make_unique<RBatchLoader>(fBatchSize, fNumChunkCols, fNumEntries, fDropRemainder);

// split the dataset into training and validation sets
fChunkLoader->SplitDataset();

// number of training and validation entries after the split
fNumValidationEntries = static_cast<std::size_t>(fValidationSplit * fNumEntries);
fNumTrainingEntries = fNumEntries - fNumValidationEntries;

fLeftoverTrainingBatchSize = fNumTrainingEntries % fBatchSize;
fLeftoverValidationBatchSize = fNumValidationEntries % fBatchSize;

fNumFullTrainingBatches = fNumTrainingEntries / fBatchSize;
fNumFullValidationBatches = fNumValidationEntries / fBatchSize;

fNumLeftoverTrainingBatches = fLeftoverTrainingBatchSize == 0 ? 0 : 1;
fNumLeftoverValidationBatches = fLeftoverValidationBatchSize == 0 ? 0 : 1;

if (dropRemainder) {
fNumTrainingBatches = fNumFullTrainingBatches;
fNumValidationBatches = fNumFullValidationBatches;
}

else {
fNumTrainingBatches = fNumFullTrainingBatches + fNumLeftoverTrainingBatches;
fNumValidationBatches = fNumFullValidationBatches + fNumLeftoverValidationBatches;
}

fTrainingBatchLoader = std::make_unique<RBatchLoader>(fBatchSize, fNumChunkCols, fNumTrainingEntries, fDropRemainder);
fValidationBatchLoader = std::make_unique<RBatchLoader>(fBatchSize, fNumChunkCols, fNumValidationEntries, fDropRemainder);

// number of training and validation chunks, calculated in RChunkConstructor
fNumTrainingChunks = fChunkLoader->GetNumTrainingChunks();
fNumValidationChunks = fChunkLoader->GetNumValidationChunks();

fTrainingChunkNum = 0;
fValidationChunkNum = 0;
}

~RBatchGenerator() { DeActivate(); }
Expand Down Expand Up @@ -226,42 +197,35 @@ public:
/// \brief Creates training batches by first loading a chunk (see RChunkLoader) and then splitting it into batches (see RBatchLoader)
void CreateTrainBatches()
{

fChunkLoader->CreateTrainingChunksIntervals();
fTrainingEpochActive = true;
fTrainingChunkNum = 0;
fChunkLoader->LoadTrainingChunk(fTrainChunkTensor, fTrainingChunkNum);
std::size_t lastTrainingBatch = fNumTrainingChunks - fTrainingChunkNum;
fBatchLoader->CreateTrainingBatches(fTrainChunkTensor, lastTrainingBatch, fLeftoverTrainingBatchSize,
fDropRemainder);
fTrainingBatchLoader->CreateBatches(fTrainChunkTensor, fNumTrainingChunks);
fTrainingChunkNum++;
}

/// \brief Creates validation batches by first loading a chunk (see RChunkLoader) and then splitting it into batches (see RBatchLoader)
void CreateValidationBatches()
{

fChunkLoader->CreateValidationChunksIntervals();
fValidationEpochActive = true;
fValidationChunkNum = 0;
fChunkLoader->LoadValidationChunk(fValidationChunkTensor, fValidationChunkNum);
std::size_t lastValidationBatch = fNumValidationChunks - fValidationChunkNum;
fBatchLoader->CreateValidationBatches(fValidationChunkTensor, lastValidationBatch, fLeftoverValidationBatchSize,
fDropRemainder);
fValidationBatchLoader->CreateBatches(fValidationChunkTensor, fNumValidationChunks);
fValidationChunkNum++;
}

/// \brief Loads a training batch from the queue
RFlat2DMatrix GetTrainBatch()
{
auto batchQueue = fBatchLoader->GetNumTrainingBatchQueue();
auto batchQueue = fTrainingBatchLoader->GetNumBatchQueue();

// load the next chunk if the queue is empty
if (batchQueue < 1 && fTrainingChunkNum < fNumTrainingChunks) {
fChunkLoader->LoadTrainingChunk(fTrainChunkTensor, fTrainingChunkNum);
std::size_t lastTrainingBatch = fNumTrainingChunks - fTrainingChunkNum;
fBatchLoader->CreateTrainingBatches(fTrainChunkTensor, lastTrainingBatch, fLeftoverTrainingBatchSize,
fDropRemainder);
fTrainingBatchLoader->CreateBatches(fTrainChunkTensor, lastTrainingBatch);
fTrainingChunkNum++;
}

Expand All @@ -270,20 +234,19 @@ public:
}

// Get next batch if available
return fBatchLoader->GetTrainBatch();
return fTrainingBatchLoader->GetBatch();
}

/// \brief Loads a validation batch from the queue
RFlat2DMatrix GetValidationBatch()
{
auto batchQueue = fBatchLoader->GetNumValidationBatchQueue();
auto batchQueue = fValidationBatchLoader->GetNumBatchQueue();

// load the next chunk if the queue is empty
if (batchQueue < 1 && fValidationChunkNum < fNumValidationChunks) {
fChunkLoader->LoadValidationChunk(fValidationChunkTensor, fValidationChunkNum);
std::size_t lastValidationBatch = fNumValidationChunks - fValidationChunkNum;
fBatchLoader->CreateValidationBatches(fValidationChunkTensor, lastValidationBatch,
fLeftoverValidationBatchSize, fDropRemainder);
fValidationBatchLoader->CreateBatches(fValidationChunkTensor, lastValidationBatch);
fValidationChunkNum++;
}

Expand All @@ -292,14 +255,14 @@ public:
}

// Get next batch if available
return fBatchLoader->GetValidationBatch();
return fValidationBatchLoader->GetBatch();
}

std::size_t NumberOfTrainingBatches() { return fNumTrainingBatches; }
std::size_t NumberOfValidationBatches() { return fNumValidationBatches; }
std::size_t NumberOfTrainingBatches() { return fTrainingBatchLoader->GetNumBatches(); }
std::size_t NumberOfValidationBatches() { return fValidationBatchLoader->GetNumBatches(); }

std::size_t TrainRemainderRows() { return fLeftoverTrainingBatchSize; }
std::size_t ValidationRemainderRows() { return fLeftoverValidationBatchSize; }
std::size_t TrainRemainderRows() { return fTrainingBatchLoader->GetNumRemainderRows(); }
std::size_t ValidationRemainderRows() { return fValidationBatchLoader->GetNumRemainderRows(); }

bool IsActive() { return fIsActive; }
bool TrainingIsActive() { return fTrainingEpochActive; }
Expand Down
Loading
Loading