Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
6ec4f62
refactor: separate statistic computation
tristan-f-r Oct 10, 2025
9987189
fix: correct tuple assumption
tristan-f-r Oct 10, 2025
25eef5e
fix: stably use graph statistic values
tristan-f-r Oct 10, 2025
cb373c1
style: fmt
tristan-f-r Oct 30, 2025
4640bc0
feat: init intervals and heuristics
tristan-f-r Oct 30, 2025
47a9e26
Merge branch 'main' into lazy-stats
tristan-f-r Oct 30, 2025
898d568
style: specify zip strict
tristan-f-r Oct 30, 2025
b307f84
Merge branch 'lazy-stats' into heuristics
tristan-f-r Oct 30, 2025
8177ed6
refactor: use heuristic error, mv heuristics outside of main schema file
tristan-f-r Oct 30, 2025
fac1108
fix: proper tokenization
tristan-f-r Oct 30, 2025
2e0d8d0
fix(interval): correct parsing
tristan-f-r Oct 30, 2025
183c3ad
fix(interval): correct other parsing mistakes
tristan-f-r Oct 30, 2025
0b6e01f
feat: integrate heuristics
tristan-f-r Nov 6, 2025
33e004f
fix: drop random code
tristan-f-r Nov 6, 2025
c675ece
fix: make undirected for determining number of connected components
tristan-f-r Nov 6, 2025
6a9a0f3
Merge branch 'lazy-stats' into heuristics
tristan-f-r Nov 6, 2025
1cdaf12
fix: specify heuristics in wrapping config object
tristan-f-r Nov 6, 2025
7b290dc
feat: interval and heuristic testing
tristan-f-r Nov 6, 2025
4844fd6
style: fmt
tristan-f-r Nov 6, 2025
3c81d05
Merge branch 'main' into lazy-stats
tristan-f-r Jan 13, 2026
1ca730e
feat: snakemake-based summary generation
tristan-f-r Jan 13, 2026
d67186d
fix(Snakefile): use parse_output for edgelist parsing
tristan-f-r Jan 13, 2026
fd483c3
fix: parse edgelist with rank, embed header skip inside from_edgelist
tristan-f-r Jan 13, 2026
fd5046f
style: fmt
tristan-f-r Jan 13, 2026
79cf748
chore: mention statistics_files param
tristan-f-r Jan 13, 2026
339d915
Merge branch 'hash' into lazy-stats
tristan-f-r Jan 31, 2026
85e0ea8
docs: more info on summary & statistics
tristan-f-r Feb 14, 2026
804849a
style: fmt
tristan-f-r Feb 14, 2026
cf3c6a0
Merge branch 'hash' into lazy-stats
tristan-f-r Feb 14, 2026
0f7acca
Merge remote-tracking branch 'upstream/main' into lazy-stats
tristan-f-r Mar 19, 2026
ae61e57
Merge branch 'umain' into generate-all-inputs
tristan-f-r Apr 17, 2026
b038ecf
Merge branch 'main' into lazy-stats
tristan-f-r Apr 24, 2026
4fe949d
refactor: use dictionaries instead of a flat list
tristan-f-r Apr 25, 2026
9acf7c0
Merge branch 'lazy-stats' into heuristics
tristan-f-r Apr 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ import os
from spras import runner
import shutil
import yaml
from spras.dataset import Dataset
from spras.evaluation import Evaluation
from spras.analysis import ml, summary, cytoscape
from spras.config.revision import detach_spras_revision
import spras.config.config as _config
from spras.dataset import Dataset
from spras.evaluation import Evaluation
from spras.statistics import from_output_pathway, statistics_computation, statistics_options

# Snakemake updated the behavior in the 6.5.0 release https://github.com/snakemake/snakemake/pull/1037
# and using the wrong separator prevents Snakemake from matching filenames to the rules that can produce them
Expand Down Expand Up @@ -292,6 +293,9 @@ rule parse_output:
params = reconstruction_params(wildcards.algorithm, wildcards.params).copy()
params['dataset'] = input.dataset_file
runner.parse_output(detach_spras_revision(_config.config.immutable_files, wildcards.algorithm), input.raw_file, output.standardized_file, params)
# TODO: cache heuristics result, store partial heuristics configuration file
# to allow this rule to update when heuristics change
_config.config.heuristics.validate_graph_from_file(output.standardized_file)

# TODO: reuse in the future once we make summary work for mixed graphs. See https://github.com/Reed-CompBio/spras/issues/128
# Collect summary statistics for a single pathway
Expand All @@ -312,18 +316,48 @@ rule viz_cytoscape:
run:
cytoscape.run_cytoscape(input.pathways, output.session, container_settings)

# We generate new Snakemake rules for every statistic
# to allow parallel and lazy computation of individual statistics
for keys in statistics_computation.keys():
pythonic_name = 'generate_' + '_and_'.join([key.lower().replace(' ', '_') for key in keys])
rule:
# (See https://snakemake.readthedocs.io/en/stable/snakefiles/rules.html#procedural-rule-definition)
name: pythonic_name
input: pathway_file = rules.parse_output.output.standardized_file
output: [SEP.join([out_dir, '{dataset}-{algorithm}-{params}', 'statistics', f'{key}.txt']) for key in keys]
# It is very tempting to use `.items()` instead of `.keys()` above, but
# We instead need to pass keys in via parameters, else the job would use the latest values in the statistics_computation.
# More info is in the procedural rule link ab
params: statistics_names=keys
run:
(Path(input.pathway_file).parent / 'statistics').mkdir(exist_ok=True)
graph = from_output_pathway(input.pathway_file)
for computed, output in zip(statistics_computation[params.statistics_names](graph), output):
Path(output).write_text(str(computed))

# We isolate this to a separate input function, as we want to preserve the dictionary structure
def summary_files(wildcards):
    """
    Build a dictionary mapping each <algorithm>-params-<hash> combination to the
    list of statistic file paths expected for the wildcard's dataset.
    Kept as a named input function so callers can preserve the dictionary structure.
    """
    files = {}
    for algorithm_param in algorithms_with_params:
        files[algorithm_param] = expand(
            '{out_dir}{sep}{dataset}-{algorithm_param}{sep}statistics{sep}{statistic}.txt',
            out_dir=out_dir, sep=SEP, algorithm_param=algorithm_param, statistic=statistics_options,
            dataset=wildcards.dataset
        )
    return files

# Write a single summary table for all pathways for each dataset
rule summary_table:
input:
# Collect all pathways generated for the dataset
pathways = expand('{out_dir}{sep}{{dataset}}-{algorithm_params}{sep}pathway.txt', out_dir=out_dir, sep=SEP, algorithm_params=algorithms_with_params),
dataset_file = SEP.join([out_dir, 'dataset-{dataset}-merged.pickle'])
dataset_file = SEP.join([out_dir, 'dataset-{dataset}-merged.pickle']),
# Collect all possible statistics from the `summary_files` dictionary-based input function
statistics = lambda wildcards: flatten(list(summary_files(wildcards).values()))
output: summary_table = SEP.join([out_dir, '{dataset}-pathway-summary.txt'])
run:
# Load the node table from the pickled dataset file
node_table = Dataset.from_file(input.dataset_file).node_table
summary_df = summary.summarize_networks(input.pathways, node_table, algorithm_params, algorithms_with_params)
summary_df = summary.summarize_networks(input.pathways, node_table, algorithm_params, algorithms_with_params, summary_files(wildcards))
summary_df.to_csv(output.summary_table, sep='\t', index=False)

# Cluster the output pathways for each dataset
Expand Down
129 changes: 22 additions & 107 deletions spras/analysis/summary.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import ast
import itertools
import json
import os
from pathlib import Path
from statistics import median
from typing import Iterable
from typing import Iterable, Mapping

import networkx as nx
import pandas as pd

from spras.statistics import from_output_pathway, statistics_options


def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, algo_params: dict[str, dict],
algo_with_params: list[str]) -> pd.DataFrame:
algo_with_params: list[str], statistics_files: Mapping[str, Iterable[str | os.PathLike]]) -> pd.DataFrame:
"""
Generate a table that aggregates summary information about networks in file_paths, including which nodes are present
in node_table columns. Network directionality is ignored and all edges are treated as undirected. The order of the
Expand All @@ -18,6 +21,7 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg
@param algo_params: a nested dict mapping algorithm names to dicts that map parameter hashes to parameter
combinations.
@param algo_with_params: a list of <algorithm>-params-<params_hash> combinations
@param statistics_files: a dictionary from algo_with_params to lists of statistic files with the computed statistics.
@return: pandas DataFrame with summary information
"""
# Ensure that NODEID is the first column
Expand All @@ -40,52 +44,22 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg

# Iterate through each network file path
for index, file_path in enumerate(sorted(file_paths)):
with open(file_path, 'r') as f:
lines = f.readlines()[1:] # skip the header line

# directed or mixed graphs are parsed and summarized as an undirected graph
nw = nx.read_edgelist(lines, data=(('weight', float), ('Direction', str)))
nw = from_output_pathway(file_path)

# Save the network name, number of nodes, number edges, and number of connected components
nw_name = str(file_path)
number_nodes = nw.number_of_nodes()
Comment thread
tristan-f-r marked this conversation as resolved.
number_edges = nw.number_of_edges()
ncc = nx.number_connected_components(nw)

# Save the max/median degree, average clustering coefficient, and density
if number_nodes == 0:
max_degree = 0
median_degree = 0.0
density = 0.0
else:
degrees = [deg for _, deg in nw.degree()]
max_degree = max(degrees)
median_degree = median(degrees)
density = nx.density(nw)

cc = list(nx.connected_components(nw))
# Save the max diameter
# Use diameter only for components with ≥2 nodes (singleton components have diameter 0)
diameters = [
nx.diameter(nw.subgraph(c).copy()) if len(c) > 1 else 0
for c in cc
]
max_diameter = max(diameters, default=0)

# Save the average path lengths
# Compute average shortest path length only for components with ≥2 nodes (undefined for singletons, set to 0.0)
avg_path_lengths = [
nx.average_shortest_path_length(nw.subgraph(c).copy()) if len(c) > 1 else 0.0
for c in cc
# We use ast.literal_eval here to convert statistic file outputs to ints or floats depending on their string representation.
# (e.g. "5.0" -> float(5.0), while "5" -> int(5).)
graph_statistics = [
ast.literal_eval(Path(file).read_text()) for file in
# along with sorting to keep the output stable (this happens again)
sorted(statistics_files[algo_with_params[index]], key=lambda x: statistics_options.index(Path(x).stem))
]

if len(avg_path_lengths) != 0:
avg_path_len = sum(avg_path_lengths) / len(avg_path_lengths)
else:
avg_path_len = 0.0

# Initialize list to store current network information
cur_nw_info = [nw_name, number_nodes, number_edges, ncc, density, max_degree, median_degree, max_diameter, avg_path_len]
cur_nw_info = [nw_name, *graph_statistics]

# Iterate through each node property and save the intersection with the current network
for node_list in nodes_by_col:
Expand All @@ -107,8 +81,13 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg
# Save the current network information to the network summary list
nw_info.append(cur_nw_info)

# Get the list of statistic names by their file names (via finding all requested statistics in the provided files)
current_statistics_options = sorted(
set(Path(file).stem for file in itertools.chain(*statistics_files.values())),
key=lambda x: statistics_options.index(x)
)
# Prepare column names
col_names = ['Name', 'Number of nodes', 'Number of edges', 'Number of connected components', 'Density', 'Max degree', 'Median degree', 'Max diameter', 'Average path length']
col_names = ['Name', *current_statistics_options]
col_names.extend(nodes_by_col_labs)
col_names.append('Parameter combination')

Expand All @@ -120,67 +99,3 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg
)

return nw_info


def degree(g):
    """Return a plain dict mapping each node of *g* to its degree."""
    return {node: node_degree for node, node_degree in g.degree}

# TODO: redo .run code to work on mixed graphs
# stats is just a list of functions to apply to the graph.
# They should take as input a networkx graph or digraph but may have any output.
# stats = [degree, nx.clustering, nx.betweenness_centrality]


# def produce_statistics(g: nx.Graph, s=None) -> dict:
# global stats
# if s is not None:
# stats = s
# d = dict()
# for s in stats:
# sname = s.__name__
# d[sname] = s(g)
# return d


# def load_graph(path: str) -> nx.Graph:
# g = nx.read_edgelist(path, data=(('weight', float), ('Direction',str)))
# return g


# def save(data, pth):
# fout = open(pth, 'w')
# fout.write('#node\t%s\n' % '\t'.join([s.__name__ for s in stats]))
# for node in data[stats[0].__name__]:
# row = [data[s.__name__][node] for s in stats]
# fout.write('%s\t%s\n' % (node, '\t'.join([str(d) for d in row])))
# fout.close()


# def run(infile: str, outfile: str) -> None:
# """
# run function that wraps above functions.
# """
# # if output directory doesn't exist, make it.
# outdir = os.path.dirname(outfile)
# if not os.path.exists(outdir):
# os.makedirs(outdir)

# # load graph, produce stats, and write to human-readable file.
# g = load_graph(infile)
# dat = produce_statistics(g)
# save(dat, outfile)


# def main(argv):
# """
# for testing
# """
# g = load_graph(argv[1])
# print(g.nodes)
# dat = produce_statistics(g)
# print(dat)
# save(dat, argv[2])


# if __name__ == '__main__':
# main(sys.argv)
2 changes: 2 additions & 0 deletions spras/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def __init__(self, raw_config: dict[str, Any]):
self.hash_length = parsed_raw_config.hash_length
# Container settings used by PRMs.
self.container_settings = ProcessedContainerSettings.from_container_settings(parsed_raw_config.containers, self.hash_length)
# The heuristic handler
self.heuristics = parsed_raw_config.heuristics
# A nested dict mapping algorithm names to dicts that map parameter hashes to parameter combinations.
# Only includes algorithms that are set to be run with 'include: true'.
self.algorithm_params: dict[str, dict[str, Any]] = dict()
Expand Down
105 changes: 105 additions & 0 deletions spras/config/heuristics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import os

import networkx as nx
from pydantic import BaseModel, ConfigDict

from spras.interval import Interval
from spras.statistics import compute_statistics, statistics_options

all = ['GraphHeuristicsError', 'GraphHeuristic']

class GraphHeuristicsError(RuntimeError):
    """
    Represents an error arising from a graph algorithm output
    not meeting the necessary graph heuristics.
    """
    # Each entry is (statistic name, computed value, allowed intervals).
    failed_heuristics: list[tuple[str, float | int, list[Interval]]]

    @staticmethod
    def format_failed_heuristic(heuristic: tuple[str, float | int, list[Interval]]) -> str:
        """Render a single failed heuristic as a human-readable phrase."""
        name, desired, intervals = heuristic
        # Phrase the interval clause per-case so the message reads correctly for
        # both shapes (the original produced "in interval one of the intervals ...").
        if len(intervals) == 1:
            interval_string = f"in interval {intervals[0]}"
        else:
            formatted_intervals = ", ".join(str(interval) for interval in intervals)
            interval_string = f"in one of the intervals ({formatted_intervals})"
        return f"{name} expected {desired} {interval_string}"

    @staticmethod
    def to_string(failed_heuristics: list[tuple[str, float | int, list[Interval]]]) -> str:
        """Format all failed heuristics as a bulleted, multi-line error message."""
        formatted_heuristics = [
            GraphHeuristicsError.format_failed_heuristic(heuristic) for heuristic in failed_heuristics
        ]

        # Bug fix: the original wrote f"- {formatted_heuristics}", interpolating
        # the whole list into every bullet instead of the individual line.
        bullet_list = "\n".join(f"- {line}" for line in formatted_heuristics)

        return f"The following heuristics failed:\n{bullet_list}"

    def __init__(self, failed_heuristics: list[tuple[str, float | int, list[Interval]]]):
        super().__init__(GraphHeuristicsError.to_string(failed_heuristics))

        self.failed_heuristics = failed_heuristics

    def __str__(self) -> str:
        return GraphHeuristicsError.to_string(self.failed_heuristics)

class GraphHeuristics(BaseModel):
    """
    Per-statistic interval constraints ("heuristics") applied to algorithm output
    graphs. An empty list (the default) leaves that statistic unconstrained.
    """
    number_of_nodes: Interval | list[Interval] = []
    number_of_edges: Interval | list[Interval] = []
    number_of_connected_components: Interval | list[Interval] = []
    density: Interval | list[Interval] = []

    max_degree: Interval | list[Interval] = []
    median_degree: Interval | list[Interval] = []
    max_diameter: Interval | list[Interval] = []
    average_path_length: Interval | list[Interval] = []

    def validate_graph(self, graph: nx.DiGraph):
        """
        Compute the constrained statistics of `graph` and raise a
        GraphHeuristicsError listing every statistic that falls outside
        all of its allowed intervals.
        """
        statistics_dictionary = {
            'Number of nodes': self.number_of_nodes,
            'Number of edges': self.number_of_edges,
            'Number of connected components': self.number_of_connected_components,
            'Density': self.density,
            'Max degree': self.max_degree,
            'Median degree': self.median_degree,
            'Max diameter': self.max_diameter,
            'Average path length': self.average_path_length
        }

        # quick assert: is statistics_dictionary exhaustive?
        assert set(statistics_dictionary.keys()) == set(statistics_options)

        # Only compute statistics that actually have constraints attached
        # (a bare Interval, or a non-empty list of them).
        constrained_statistics = [
            name for name, spec in statistics_dictionary.items()
            if not isinstance(spec, list) or len(spec) != 0
        ]
        stats = compute_statistics(graph, constrained_statistics)

        failed_heuristics: list[tuple[str, float | int, list[Interval]]] = []
        for key, value in stats.items():
            intervals = statistics_dictionary[key]
            if not isinstance(intervals, list):
                intervals = [intervals]

            # A statistic fails only when it is outside *every* allowed interval.
            # (The original marked a value failed as soon as any single interval
            # rejected it, contradicting the "one of the intervals" error text.)
            if not any(interval.mem(value) for interval in intervals):
                failed_heuristics.append((key, value, intervals))

        if len(failed_heuristics) != 0:
            raise GraphHeuristicsError(failed_heuristics)

    model_config = ConfigDict(extra='forbid')

    def validate_graph_from_file(self, path: str | os.PathLike):
        """
        Takes in a graph produced by PRM#parse_output,
        and throws a GraphHeuristicsError if it fails the heuristics in `self`.
        """
        # TODO: re-use from summary.py once we have a mixed/hypergraph library
        # NOTE(review): this reads a *directed* edge list while summary.py treats
        # edges as undirected — confirm the intended convention and share code
        # so the two cannot diverge.
        G: nx.DiGraph = nx.read_edgelist(path, data=(('Rank', str), ('Direction', str)), create_using=nx.DiGraph)

        # We explicitly use `list` here to stop add_edge
        # from expanding our iterator infinitely.
        for source, target, data in list(G.edges(data=True)):
            if data["Direction"] == 'U':
                # Bug fix: `data=data` would create a single edge attribute literally
                # named "data"; `**data` copies Rank/Direction onto the reverse edge.
                G.add_edge(target, source, **data)

        return self.validate_graph(G)
3 changes: 3 additions & 0 deletions spras/config/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from spras.config.algorithms import AlgorithmUnion
from spras.config.container_schema import ContainerSettings
from spras.config.dataset import DatasetSchema
from spras.config.heuristics import GraphHeuristics
from spras.config.util import CaseInsensitiveEnum, label_validator

# Most options here have an `include` property,
Expand Down Expand Up @@ -122,6 +123,8 @@ class RawConfig(BaseModel):

reconstruction_settings: ReconstructionSettings

heuristics: GraphHeuristics = GraphHeuristics()

# We include use_attribute_docstrings here to preserve the docstrings
# after attributes at runtime (for future JSON schema generation)
model_config = ConfigDict(extra='forbid', use_attribute_docstrings=True)
Loading
Loading