|
1 | 1 | """ |
2 | | -Shared functions to be used within a Snakemake workflow for parsing |
| 2 | +Shared functions to be used within a Snakemake workflow for handling |
3 | 3 | workflow configs. |
4 | 4 | """ |
5 | | -import os.path |
| 5 | +import os |
| 6 | +import sys |
| 7 | +import yaml |
6 | 8 | from collections.abc import Callable |
7 | | -from snakemake.io import Wildcards |
8 | 9 | from typing import Optional |
9 | 10 | from textwrap import dedent, indent |
10 | 11 |
|
11 | 12 |
|
| 13 | +# Set search paths for Augur |
| 14 | +if "AUGUR_SEARCH_PATHS" in os.environ: |
| 15 | + print(dedent(f"""\ |
| 16 | + Using existing search paths in AUGUR_SEARCH_PATHS: |
| 17 | +
|
| 18 | + {os.environ["AUGUR_SEARCH_PATHS"]!r} |
| 19 | + """), file=sys.stderr) |
| 20 | +else: |
| 21 | + # Note that this differs from the search paths used in |
| 22 | + # resolve_config_path(). |
| 23 | + # This is the preferred default moving forwards, and the plan is to |
| 24 | + # eventually update resolve_config_path() to use AUGUR_SEARCH_PATHS. |
| 25 | + search_paths = [ |
| 26 | + # User analysis directory |
| 27 | + Path.cwd(), |
| 28 | + |
| 29 | + # Workflow defaults folder |
| 30 | + Path(workflow.basedir) / "defaults", |
| 31 | + |
| 32 | + # Workflow root (contains Snakefile) |
| 33 | + Path(workflow.basedir), |
| 34 | + ] |
| 35 | + |
| 36 | + # This should work for majority of workflows, but we could consider doing a |
| 37 | + # more thorough search for the nextstrain-pathogen.yaml. This would likely |
| 38 | + # replicate how CLI searches for the root.¹ |
| 39 | + # ¹ <https://github.com/nextstrain/cli/blob/d5e184c5/nextstrain/cli/command/build.py#L413-L420> |
| 40 | + repo_root = Path(workflow.basedir) / ".." |
| 41 | + if (repo_root / "nextstrain-pathogen.yaml").is_file(): |
| 42 | + search_paths.extend([ |
| 43 | + # Pathogen repo root |
| 44 | + repo_root, |
| 45 | + ]) |
| 46 | + |
| 47 | + search_paths = [path.resolve() for path in search_paths if path.is_dir()] |
| 48 | + |
| 49 | + os.environ["AUGUR_SEARCH_PATHS"] = ":".join(map(str, search_paths)) |
| 50 | + |
| 51 | + |
12 | 52 | class InvalidConfigError(Exception): |
13 | 53 | pass |
14 | 54 |
|
15 | 55 |
|
16 | | -def resolve_config_path(path: str, defaults_dir: Optional[str] = None) -> Callable[[Wildcards], str]: |
| 56 | +def resolve_config_path(path: str, defaults_dir: Optional[str] = None) -> Callable: |
17 | 57 | """ |
18 | 58 | Resolve a relative *path* given in a configuration value. Will always try to |
19 | 59 | resolve *path* after expanding wildcards with Snakemake's `expand` functionality. |
@@ -75,3 +115,42 @@ def resolve_config_path(path: str, defaults_dir: Optional[str] = None) -> Callab |
75 | 115 | """), " " * 4)) |
76 | 116 |
|
77 | 117 | return _resolve_config_path |
| 118 | + |
| 119 | + |
| 120 | +def write_config(path, section=None): |
| 121 | + """ |
| 122 | + Write Snakemake's 'config' variable, or a section of it, to a file. |
| 123 | +
|
| 124 | + *section* is an optional list of keys to navigate to a specific section of |
| 125 | + config. If provided, only that section will be written. |
| 126 | + """ |
| 127 | + global config |
| 128 | + |
| 129 | + os.makedirs(os.path.dirname(path), exist_ok=True) |
| 130 | + |
| 131 | + data = config |
| 132 | + section_str = "config" |
| 133 | + |
| 134 | + if section: |
| 135 | + # Navigate to the specified section |
| 136 | + for key in section: |
| 137 | + # Error if key doesn't exist |
| 138 | + if key not in data: |
| 139 | + raise Exception(f"ERROR: Key {key!r} not found in {section_str!r}.") |
| 140 | + |
| 141 | + data = data[key] |
| 142 | + section_str += f".{key}" |
| 143 | + |
| 144 | + # Error if value is not a mapping |
| 145 | + if not isinstance(data, dict): |
| 146 | + raise Exception(f"ERROR: {section_str!r} is not a mapping of key/value pairs.") |
| 147 | + |
| 148 | + with open(path, 'w') as f: |
| 149 | + yaml.dump(data, f, sort_keys=False, Dumper=NoAliasDumper) |
| 150 | + |
| 151 | + print(f"Saved {section_str!r} to {path!r}.", file=sys.stderr) |
| 152 | + |
| 153 | + |
| 154 | +class NoAliasDumper(yaml.SafeDumper): |
| 155 | + def ignore_aliases(self, data): |
| 156 | + return True |
0 commit comments