Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 34 additions & 8 deletions src/pycea/tl/ancestral_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,24 @@ def _reconstruct_sum(tree: nx.DiGraph, key: str, index: int | None, fixed_nodes:
_set_node_value(tree, node, key, sum(valid) if valid else None, index)


def _reconstruct_sum_array(tree: nx.DiGraph, key: str, fixed_nodes: set | None = None) -> None:
    """Reconstruct ancestral states by summing array-valued attributes over children.

    Nodes are visited children-first (reverse topological order). Every node
    that has children and is not in ``fixed_nodes`` gets the element-wise sum
    of its children's ``key`` values. Missing entries (NaN, or ``None`` in a
    list) are skipped, so a position in the parent is NaN only when it is
    missing in every child.

    Parameters
    ----------
    tree
        Directed graph whose nodes carry an array-like value under ``key``.
    key
        Node attribute that is read from children and written on parents.
    fixed_nodes
        Nodes whose existing value must be kept. ``None`` means no node is fixed.

    Returns
    -------
    None. ``tree`` is modified in place; on exit every ndarray stored under
    ``key`` has been converted to a plain list.
    """
    for node in reversed(list(nx.topological_sort(tree))):
        is_fixed = fixed_nodes is not None and node in fixed_nodes
        if tree.out_degree(node) == 0 or is_fixed:
            continue
        # Coerce to float so list-valued, None-padded or integer child values
        # work too: None becomes NaN, and the NaN assignment below cannot fail
        # on an integer result. Float ndarrays pass through without a copy.
        child_arrays = [np.asarray(tree.nodes[child][key], dtype=float) for child in tree.successors(node)]
        stacked = np.stack(child_arrays)
        # NOTE(review): np.nansum counts a missing child value as 0, so
        # [NaN, 2] sums to 2 rather than NaN. A reviewer flagged this as a
        # possible change from the per-index path, which reportedly propagated
        # NaN - confirm this is the intended missing-data semantics.
        result = np.nansum(stacked, axis=0)
        # nansum returns 0 where every child is NaN; restore NaN there.
        result[np.isnan(stacked).all(axis=0)] = np.nan
        tree.nodes[node][key] = result
    # Convert numpy arrays back to lists for compatibility
    for node in tree.nodes:
        val = tree.nodes[node].get(key)
        if isinstance(val, np.ndarray):
            tree.nodes[node][key] = val.tolist()


def _reconstruct_mean(tree: nx.DiGraph, key: str, index: int | None, fixed_nodes: set | None = None) -> None:
"""Reconstructs ancestral by averaging the values of the children."""

Expand Down Expand Up @@ -387,18 +405,26 @@ def ancestral_states(
# If array add to tree as list
if is_array:
length = data.shape[1]
node_attrs = data.apply(lambda row: list(row), axis=1).to_dict()
for node in t.nodes:
if node not in node_attrs:
node_attrs[node] = [None] * length
_remove_node_attributes(t, keys_added[0])
nx.set_node_attributes(t, node_attrs, keys_added[0])
fixed_nodes = None
if tdata.alignment != "leaves":
not_all_nan = ~data.isna().all(axis=1)
fixed_nodes = set(data[not_all_nan].index) - leaves_set
for index in range(length):
_ancestral_states(t, keys_added[0], method, costs, missing_state, default_state, index, fixed_nodes)
_remove_node_attributes(t, keys_added[0])
if method == "sum":
node_attrs = dict(zip(data.index, data.to_numpy(dtype=float)))
for node in t.nodes:
if node not in node_attrs:
node_attrs[node] = np.full(length, np.nan)
nx.set_node_attributes(t, node_attrs, keys_added[0])
_reconstruct_sum_array(t, keys_added[0], fixed_nodes)
else:
node_attrs = data.apply(lambda row: list(row), axis=1).to_dict()
for node in t.nodes:
if node not in node_attrs:
node_attrs[node] = [None] * length
nx.set_node_attributes(t, node_attrs, keys_added[0])
for index in range(length):
_ancestral_states(t, keys_added[0], method, costs, missing_state, default_state, index, fixed_nodes)
# If column add to tree as scalar
else:
for key, key_added in zip(keys, keys_added, strict=False):
Expand Down
Loading