-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathplotNetworks.py
More file actions
135 lines (117 loc) · 4.15 KB
/
plotNetworks.py
File metadata and controls
135 lines (117 loc) · 4.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import argparse
import networkx as nx
import csv
from pyvis.network import Network
def process_edges(file_path, G, visited_nodes, label):
    """Read an edge-list CSV and add every row to ``G`` as a labeled edge.

    The first row of the file is treated as a header and skipped. For each
    remaining non-empty row, the values in columns 2 and 3 (zero-based) are
    used as the two endpoint names of the edge.

    Args:
        file_path: Path of the CSV file to read.
        G: Graph-like object exposing ``add_edge(u, v, label=...)``. It is
            mutated in place.
        visited_nodes: Set of node names seen so far. Both endpoints of every
            processed row are added to it in place.
        label: Value stored under the ``label`` attribute of each added edge.

    Returns:
        None. The results are the mutations of ``G`` and ``visited_nodes``.

    Raises:
        IndexError: If a non-empty data row has fewer than four columns.
    """
    # newline="" lets the csv module do its own newline handling, as the
    # csv documentation recommends for files passed to csv.reader.
    with open(file_path, "r", newline="") as file:
        csv_reader = csv.reader(file)
        # Skip the header; the default keeps an empty file from raising
        # StopIteration.
        next(csv_reader, None)
        for row in csv_reader:
            if not row:
                # csv.reader yields [] for blank lines (e.g. a trailing
                # newline at the end of the file); there is no edge to add.
                continue
            gene_name1 = row[2]
            gene_name2 = row[3]
            visited_nodes.update((gene_name1, gene_name2))
            G.add_edge(gene_name1, gene_name2, label=label)
def read_csv(
    ppi_path,
    reg_path,
):
    """Build a MultiDiGraph from the PPI and regulatory edge CSV files.

    Edges read from ``ppi_path`` are labeled 'ppi' and edges read from
    ``reg_path`` are labeled 'reg'. The PPI file is processed first, and a
    single set of seen node names is shared across both files.
    """
    graph = nx.MultiDiGraph()
    seen_nodes = set()
    sources = ((ppi_path, "ppi"), (reg_path, "reg"))
    for path, edge_label in sources:
        process_edges(path, graph, seen_nodes, edge_label)
    return graph
# Fill colors used for nodes in the rendered network.
_DEFAULT_NODE_COLOR = '#cfe2f3'
_HIGHLIGHT_NODE_COLOR = '#fad6a5'

# Genes to highlight, keyed by the tag that must occur in the output path
# for that gene set to apply.
_HIGHLIGHT_GENES = {
    'txid224308': (
        'sigA', 'sigB', 'katX', 'katE', 'ahpC', 'ahpF', 'ccpA', 'perR',
    ),
    'txid7955': (
        'nanog', 'sod1', 'sod2', 'prdx1', 'erp44', 'keap1a', 'keap1b',
        'smad2', 'gata1a', 'sall4', 'foxh1',
    ),
}


def main(ppi_path, reg_path, output_dir):
    """Render the combined PPI/regulatory network to an HTML file.

    Builds the graph with ``read_csv``, drops self-loops, and writes a pyvis
    network in which 'reg' edges are red with a "to" arrow and 'ppi' edges
    are black without arrows, each unordered node pair being drawn at most
    once for 'ppi'.

    Args:
        ppi_path: Path to the CSV file holding the PPI edges.
        reg_path: Path to the CSV file holding the regulatory edges.
        output_dir: Prefix of the output file; ``nx.html`` is appended to it
            verbatim. It is also searched for the tags in
            ``_HIGHLIGHT_GENES`` to decide which genes to highlight.
    """
    G = read_csv(ppi_path, reg_path)
    nt = Network('700px', '840px', directed=True)

    # Materialize the self-loop list before removing: selfloop_edges is lazy
    # and reads the adjacency structure that remove_edges_from mutates.
    G.remove_edges_from(list(nx.selfloop_edges(G, keys=True)))

    # Use the output_dir parameter (not a module-level ``args``) so that the
    # function also works when it is imported and called directly.
    highlighted = {
        gene
        for tag, genes in _HIGHLIGHT_GENES.items()
        if tag in output_dir
        for gene in genes
    }

    # Highlight genes that are absent from this graph are simply not drawn,
    # instead of failing on a missing-node lookup.
    for node in G.nodes():
        if node in highlighted:
            color = _HIGHLIGHT_NODE_COLOR
        else:
            color = _DEFAULT_NODE_COLOR
        nt.add_node(node, label=node, color=color, shape='ellipse')

    # Unordered node pairs that already have a PPI edge drawn.
    seen_ppi_pairs = set()
    for u, v, data in G.edges(data=True):
        label = data.get("label", None)
        if label == "reg":
            # Directed red edge.
            nt.add_edge(
                u, v, color="red", width=3, arrows="to",
                arrowStrikethrough=True,
            )
        elif label == "ppi":
            pair = frozenset((u, v))
            if pair not in seen_ppi_pairs:
                # Undirected black edge, drawn once per node pair.
                nt.add_edge(u, v, color="black", width=4, arrows="")
                seen_ppi_pairs.add(pair)

    # NOTE(review): output_dir is concatenated as-is, so it needs a trailing
    # path separator to place nx.html inside a directory - confirm callers.
    nt.write_html(f'{output_dir}nx.html')
if __name__ == "__main__":
    # Command-line entry point: three required string options, all handed
    # straight to main().
    parser = argparse.ArgumentParser(description="Plot Stress Subnetworks")
    option_specs = (
        ("-u", "--undirected", "Path to the undirected/PPI edges input file."),
        ("-d", "--directed", "Path to the directed/Reg edges input file."),
        ("-o", "--output_dir", "Path to the output directory."),
    )
    for short_flag, long_flag, help_text in option_specs:
        parser.add_argument(
            short_flag,
            long_flag,
            type=str,
            help=help_text,
            required=True,
        )
    args = parser.parse_args()
    main(args.undirected, args.directed, args.output_dir)