Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions tree/ml/inc/ROOT/ML/RClusterLoader.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -227,18 +227,27 @@ public:
// --- Shuffled path
// Every cluster contributes a prefix to training and a suffix to validation.
// Cost: Each cluster is read twice per epoch, only when validation split is more than 0.
// TODO(staider) Switch between prefix and suffix for validation randomly per cluster
// We generate a random boolean value to decide whether the training set gets the prefix
// or suffix of each cluster to ensure better shuffling across runs when splitting.
std::mt19937 g(fSetSeed);
std::uniform_int_distribution<int> coin(0, 1);

for (const RClusterRange &c : fAllClusters) {
const std::size_t sz = c.GetNumEntries();
const std::size_t trainSz = static_cast<std::size_t>((1.0f - fValidationSplit) * sz);
const std::size_t valSz = sz - trainSz;

// Randomly assign prefix or suffix to training
const uint64_t trainIsPrefix = coin(g);
const uint64_t trainStart = trainIsPrefix ? c.start : c.start + static_cast<std::uint64_t>(valSz);
const uint64_t valStart = trainIsPrefix ? c.start + static_cast<std::uint64_t>(trainSz) : c.start;

if (trainSz > 0) {
fTrainingClusters.push_back({c.rdfIdx, c.start, c.start + static_cast<std::uint64_t>(trainSz)});
fTrainingClusters.push_back({c.rdfIdx, trainStart, trainStart + static_cast<std::uint64_t>(trainSz)});
fNumTrainingEntries += trainSz;
}
if (valSz > 0) {
fValidationClusters.push_back({c.rdfIdx, c.start + static_cast<std::uint64_t>(trainSz), c.end});
fValidationClusters.push_back({c.rdfIdx, valStart, valStart + static_cast<std::uint64_t>(valSz)});
fNumValidationEntries += valSz;
}
}
Expand Down Expand Up @@ -392,14 +401,26 @@ public:
std::min(static_cast<std::size_t>(totalFiltered * (1.0f - fValidationSplit)), trainRemaining);
const std::size_t valCount = totalFiltered - trainCount;

// We generate a random boolean value to decide whether the training set gets the prefix
// or suffix of each cluster to ensure better shuffling across runs when splitting.
std::mt19937 g(fSetSeed + fAccumulatedFilteredForTrain); // vary per cluster
std::uniform_int_distribution<int> coin(0, 1);
const uint64_t trainIsPrefix = coin(g);

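// The boundary is the raw entry index of the first entry in the suffix part of the cluster:
// the first validation entry when training takes the prefix, the first training entry otherwise
// (or endRow when no entries are assigned to validation).
// Stable across epochs since the same filter always produces the same ordered entries.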
const std::uint64_t boundary = (valCount > 0) ? rdfEntries[trainCount] : endRow;
const std::uint64_t trainBoundaryEntry = trainIsPrefix ? rdfEntries[trainCount] : rdfEntries[valCount];
const std::uint64_t boundary = (valCount > 0) ? trainBoundaryEntry : endRow;

const std::uint64_t trainStart = trainIsPrefix ? startRow : boundary;
const std::uint64_t trainEnd = trainIsPrefix ? boundary : endRow;
const std::uint64_t valStart = trainIsPrefix ? boundary : startRow;
const std::uint64_t valEnd = trainIsPrefix ? endRow : boundary;

if (trainCount > 0)
fTrainingClusters.push_back({rdfIdx, startRow, boundary, trainCount});
fTrainingClusters.push_back({rdfIdx, trainStart, trainEnd, trainCount});
if (valCount > 0)
fValidationClusters.push_back({rdfIdx, boundary, endRow, valCount});
fValidationClusters.push_back({rdfIdx, valStart, valEnd, valCount});

fAccumulatedFilteredForTrain += trainCount;
return trainCount;
Expand Down
Loading