|
| 1 | +/* |
| 2 | +Copyright 2024 Huawei Technologies Co., Ltd. |
| 3 | +
|
| 4 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +you may not use this file except in compliance with the License. |
| 6 | +You may obtain a copy of the License at |
| 7 | +
|
| 8 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | +Unless required by applicable law or agreed to in writing, software |
| 11 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +See the License for the specific language governing permissions and |
| 14 | +limitations under the License. |
| 15 | +
|
| 16 | +@author Toni Boehnlein, Benjamin Lozes, Pal Andras Papp, Raphael S. Steiner |
| 17 | +*/ |
| 18 | + |
| 19 | +#pragma once |
| 20 | + |
| 21 | +#include "AbstractTestSuiteRunner.hpp" |
| 22 | +#include "StringToScheduler/run_partitioner.hpp" |
| 23 | +#include "osp/graph_implementations/adj_list_impl/computational_dag_vector_impl.hpp" |
| 24 | +#include "osp/partitioning/model/partitioning.hpp" |
| 25 | +#include "osp/partitioning/model/partitioning_replication.hpp" |
| 26 | +#include "osp/bsp/model/BspSchedule.hpp" |
| 27 | + |
| 28 | +namespace osp { |
| 29 | + |
| 30 | +class PartitioningStatsModule : public IStatisticModule<Partitioning<HypergraphDefT>> { |
| 31 | + public: |
| 32 | + std::vector<std::string> GetMetricHeaders() const override { return {"Cost", "CutNet"}; } |
| 33 | + |
| 34 | + std::map<std::string, std::string> RecordStatistics(const Partitioning<HypergraphDefT> &partitioning, |
| 35 | + std::ofstream & /*log_stream*/) const override { |
| 36 | + std::map<std::string, std::string> stats; |
| 37 | + stats["Cost"] = std::to_string(partitioning.ComputeConnectivityCost()); |
| 38 | + stats["CutNet"] = std::to_string(partitioning.ComputeCutNetCost()); |
| 39 | + return stats; |
| 40 | + } |
| 41 | +}; |
| 42 | + |
| 43 | +template <typename GraphType> |
| 44 | +class PartitioningTestSuiteRunner : public AbstractTestSuiteRunner<Partitioning<HypergraphDefT>, GraphType> { |
| 45 | + private: |
| 46 | + |
| 47 | + protected: |
| 48 | + ReturnStatus ComputeTargetObjectImpl(const BspInstance<GraphType> &instance, |
| 49 | + std::unique_ptr<Partitioning<HypergraphDefT> > &targetObject, |
| 50 | + const pt::ptree &algoConfig, |
| 51 | + long long &computationTimeMs) override { |
| 52 | + return ReturnStatus::ERROR; //unused |
| 53 | + } |
| 54 | + |
| 55 | + void CreateAndRegisterStatisticModules(const std::string &moduleName) override { |
| 56 | + if (moduleName == "PartitioningStats") { |
| 57 | + this->activeStatsModules_.push_back(std::make_unique<PartitioningStatsModule>()); |
| 58 | + } |
| 59 | + } |
| 60 | + |
| 61 | + public: |
| 62 | + PartitioningTestSuiteRunner() : AbstractTestSuiteRunner<Partitioning<HypergraphDefT>, GraphType>() {} |
| 63 | + |
| 64 | + int virtual Run(int argc, char *argv[]) override; |
| 65 | +}; |
| 66 | + |
| 67 | +template <typename GraphType> |
| 68 | +int PartitioningTestSuiteRunner<GraphType>::Run(int argc, char *argv[]) { |
| 69 | + using HypergraphT = HypergraphDefT; |
| 70 | + try { |
| 71 | + this->parser_.ParseArgs(argc, argv); |
| 72 | + } catch (const std::exception &e) { |
| 73 | + std::cerr << "Error parsing command line arguments: " << e.what() << std::endl; |
| 74 | + return 1; |
| 75 | + } |
| 76 | + |
| 77 | + if (!this->ParseCommonConfig()) { |
| 78 | + return 1; |
| 79 | + } |
| 80 | + |
| 81 | + this->SetupLogFile(); |
| 82 | + |
| 83 | + CreateAndRegisterStatisticModules("PartitioningStats"); |
| 84 | + this->SetupStatisticsFile(); |
| 85 | + |
| 86 | + for (const auto &machineEntry : std::filesystem::recursive_directory_iterator(this->machineDirPath_)) { |
| 87 | + if (std::filesystem::is_directory(machineEntry)) { |
| 88 | + this->logStream_ << "Skipping directory " << machineEntry.path().string() << std::endl; |
| 89 | + continue; |
| 90 | + } |
| 91 | + std::string filenameMachine = machineEntry.path().string(); |
| 92 | + std::string nameMachine = filenameMachine.substr(filenameMachine.rfind('/') + 1); |
| 93 | + if (nameMachine.rfind('.') != std::string::npos) { |
| 94 | + nameMachine = nameMachine.substr(0, nameMachine.rfind('.')); |
| 95 | + } |
| 96 | + |
| 97 | + // Temporary hack. Until there is no separate file format for partitioning problem parameters, we abuse |
| 98 | + // bsp arch files: 1st number is number of parts, 2nd is imbalance allowed (percentage), rest is ignored |
| 99 | + BspArchitecture<GraphType> arch; |
| 100 | + if (!file_reader::ReadBspArchitecture(filenameMachine, arch)) { |
| 101 | + this->logStream_ << "Reading architecture file " << filenameMachine << " failed." << std::endl; |
| 102 | + continue; |
| 103 | + } |
| 104 | + this->logStream_ << "Start Machine: " + filenameMachine + "\n"; |
| 105 | + |
| 106 | + |
| 107 | + for (const auto &graphEntry : std::filesystem::recursive_directory_iterator(this->graphDirPath_)) { |
| 108 | + if (std::filesystem::is_directory(graphEntry)) { |
| 109 | + this->logStream_ << "Skipping directory " << graphEntry.path().string() << std::endl; |
| 110 | + continue; |
| 111 | + } |
| 112 | + std::string filenameGraph = graphEntry.path().string(); |
| 113 | + std::string nameGraph = filenameGraph.substr(filenameGraph.rfind('/') + 1); |
| 114 | + if (nameGraph.rfind('.') != std::string::npos) { |
| 115 | + nameGraph = nameGraph.substr(0, nameGraph.rfind('.')); |
| 116 | + } |
| 117 | + this->logStream_ << "Start Hypergraph: " + filenameGraph + "\n"; |
| 118 | + |
| 119 | + bool graphStatus = false; |
| 120 | + GraphType dag; |
| 121 | + graphStatus = file_reader::ReadGraph(filenameGraph, dag); |
| 122 | + |
| 123 | + if (!graphStatus) { |
| 124 | + this->logStream_ << "Reading graph file " << filenameGraph << " failed." << std::endl; |
| 125 | + continue; |
| 126 | + } |
| 127 | + |
| 128 | + PartitioningProblem<HypergraphT> instance(ConvertFromCdagAsHyperdag<HypergraphT, GraphType>(dag), arch.NumberOfProcessors()); |
| 129 | + instance.SetMaxWorkWeightViaImbalanceFactor(static_cast<double>(arch.CommunicationCosts()) / 100.0); |
| 130 | + |
| 131 | + for (auto &algorithmConfigPair : this->parser_.scheduler_) { |
| 132 | + const pt::ptree &algoConfig = algorithmConfigPair.second; |
| 133 | + |
| 134 | + std::string currentAlgoName = algoConfig.get_child("name").get_value<std::string>(); |
| 135 | + this->logStream_ << "Start Algorithm " + currentAlgoName + "\n"; |
| 136 | + |
| 137 | + long long computationTimeMs; |
| 138 | + const auto startTime = std::chrono::high_resolution_clock::now(); |
| 139 | + |
| 140 | + std::pair<HypergraphT::VertexCommWeightType, HypergraphT::VertexCommWeightType> cost; |
| 141 | + ReturnStatus execStatus = RunPartitioner(this->parser_, algoConfig, instance, cost); |
| 142 | + |
| 143 | + const auto finishTime = std::chrono::high_resolution_clock::now(); |
| 144 | + computationTimeMs = std::chrono::duration_cast<std::chrono::milliseconds>(finishTime - startTime).count(); |
| 145 | + |
| 146 | + if (execStatus != ReturnStatus::OSP_SUCCESS && execStatus != ReturnStatus::BEST_FOUND) { |
| 147 | + if (execStatus == ReturnStatus::ERROR) { |
| 148 | + this->logStream_ << "Error computing with " << currentAlgoName << "." << std::endl; |
| 149 | + } else if (execStatus == ReturnStatus::TIMEOUT) { |
| 150 | + this->logStream_ << "Partitioner " << currentAlgoName << " timed out." << std::endl; |
| 151 | + } |
| 152 | + continue; |
| 153 | + } |
| 154 | + |
| 155 | + // currently not writing output to file |
| 156 | + |
| 157 | + if (this->statsOutStream_.is_open()) { |
| 158 | + std::map<std::string, std::string> currentRowValues; |
| 159 | + currentRowValues["Graph"] = nameGraph; |
| 160 | + currentRowValues["Machine"] = nameMachine; |
| 161 | + currentRowValues["Algorithm"] = currentAlgoName; |
| 162 | + currentRowValues["TimeToCompute(ms)"] = std::to_string(computationTimeMs); |
| 163 | + currentRowValues["Cost"] = std::to_string(cost.first); |
| 164 | + currentRowValues["CutNet"] = std::to_string(cost.second); |
| 165 | + |
| 166 | + for (size_t i = 0; i < this->allCsvHeaders_.size(); ++i) { |
| 167 | + this->statsOutStream_ << currentRowValues[this->allCsvHeaders_[i]] << (i == this->allCsvHeaders_.size() - 1 ? "" : ","); |
| 168 | + } |
| 169 | + this->statsOutStream_ << "\n"; |
| 170 | + } |
| 171 | + } |
| 172 | + } |
| 173 | + } |
| 174 | + return 0; |
| 175 | +} |
| 176 | + |
| 177 | +} // namespace osp |
0 commit comments