Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
* Carlo de Donno <carlo.de_donno@roche.com>
* Johannes Hingerl <hingerl.johannes@gene.com>
* Liudeng Zhang <zhangliudeng@gmail.com>
* Jake Dearborn <jakedearborn@gmail.com>
13 changes: 11 additions & 2 deletions src/grelu/interpret/modisco.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def run_modisco(
seed=None,
method: str = "deepshap",
correct_grad: bool = False,
attributions: Optional[np.ndarray] = None,
**kwargs,
) -> None:
"""
Expand All @@ -221,13 +222,15 @@ def run_modisco(
batch_size: Batch size to use for model inference
n_shuffles: Number of times to shuffle the background sequences for deepshap.
seed: Random seed
method: Either "deepshap", "saliency" or "ism".
method: Either "deepshap", "saliency", "ism" or "completed".
Attributions must be supplied as an np.ndarray if method is "completed".
correct_grad: If True, gradients will be corrected using the method of Majdandzic et al.
(PMID: 37161475). Only used with method='saliency'.
attributions: An np.ndarray of attributions to use when method is "completed".
**kwargs: Additional arguments to pass to TF-Modisco.

Raises:
NotImplementedError: if the method is neither "deepshap" nor "ism"
NotImplementedError: if the method is neither "deepshap", "saliency", "ism", nor "completed".
"""
from modiscolite.io import save_hdf5
from modiscolite.report import create_modisco_logos, report_motifs
Expand Down Expand Up @@ -280,6 +283,12 @@ def run_modisco(
batch_size=batch_size,
genome=genome,
)
elif method == "completed":
print("Using completed attributions")
if attributions is None:
raise ValueError("Attributions must be provided when method is 'completed'.")
attrs = attributions
attrs = attrs[:, :, start:end]
else:
raise NotImplementedError

Expand Down
Loading