Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 22 additions & 36 deletions include/graph/graph.hpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#pragma once
#include <algorithm>
#include <chrono>
#include <list>
#include <memory>
#include <queue>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -34,7 +34,7 @@ class Graph {
Tensor* outtenres_;
int start_;
int end_;
std::list<BranchState> branch_list_;
std::unordered_map<int, BranchState> branch_map_;
std::vector<std::vector<int>> in_edges_; // next -> prev
std::vector<std::vector<std::pair<int, int>>> split_distribution_;
int count_used_split_distribution_;
Expand Down Expand Up @@ -118,13 +118,9 @@ class Graph {
if (!layer) {
throw std::invalid_argument("Layer cannot be null");
}
bool layer_exists = false;
for (std::shared_ptr<Layer>& existing_layer : layers_) {
if (existing_layer == layer) {
layer_exists = true;
break;
}
}

int id = layer->getID();
bool layer_exists = (id >= 0 && id < V_ && layers_[id] == layer);

if (!layer_exists) {
layer->setID(V_);
Expand All @@ -144,13 +140,9 @@ class Graph {

void addSingleLayer(const std::shared_ptr<Layer>& layer) {
if (!layer) return;
bool layer_exists = false;
for (const std::shared_ptr<Layer>& existing_layer : layers_) {
if (existing_layer == layer) {
layer_exists = true;
break;
}
}

int id = layer->getID();
bool layer_exists = (id >= 0 && id < V_ && layers_[id] == layer);

if (!layer_exists) {
layer->setID(V_);
Expand Down Expand Up @@ -296,31 +288,25 @@ class Graph {

for (size_t k = 0; k < in_edges_[current_layer].size(); ++k) {
auto target_value = in_edges_[current_layer][k];
auto it = std::find_if(branch_list_.rbegin(), branch_list_.rend(),
[target_value](const BranchState& s) {
return s.ind_layer == target_value;
});

if (it != branch_list_.rend()) {
for (size_t f = 0; f < it->distribution.size(); ++f) {
if (it->distribution[f].first == current_layer) {
bool last_use = (it->count_used_ten == 1);
auto& src = it->give_for_all[it->distribution[f].second];
auto it = branch_map_.find(target_value);

if (it != branch_map_.end()) {
for (size_t f = 0; f < it->second.distribution.size(); ++f) {
if (it->second.distribution[f].first == current_layer) {
bool last_use = (it->second.count_used_ten == 1);
auto& src =
it->second.give_for_all[it->second.distribution[f].second];
if (last_use) {
inten_.push_back(std::move(src));
} else {
inten_.push_back(src);
}
}
}
}

if (it != branch_list_.rend()) {
it->count_used_ten--;
if (it->count_used_ten < 1) {
auto rit = std::next(it).base();
it =
std::reverse_iterator<decltype(rit)>(branch_list_.erase(rit));
it->second.count_used_ten--;
if (it->second.count_used_ten < 1) {
branch_map_.erase(it);
}
}
}
Expand Down Expand Up @@ -375,11 +361,11 @@ class Graph {
}
new_branch.distribution = dis;
}
branch_list_.push_back(std::move(new_branch));
branch_map_[current_layer] = std::move(new_branch);
if (outtenres_ && current_layer == end_ &&
!branch_list_.back().give_for_all.empty() &&
!branch_map_[current_layer].give_for_all.empty() &&
countinout[current_layer].second == 0) {
*outtenres_ = std::move(branch_list_.back().give_for_all[0]);
*outtenres_ = std::move(branch_map_[current_layer].give_for_all[0]);
}

#ifdef ENABLE_STATISTIC_TIME
Expand Down
Loading