-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtask_executor.py
More file actions
430 lines (361 loc) · 14.7 KB
/
task_executor.py
File metadata and controls
430 lines (361 loc) · 14.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
#!/usr/bin/env python3
"""
Multi-Task Configuration Parser & Batch Executor
Parses a YAML/JSON task configuration and executes multiple independent tasks
in parallel with timeout enforcement, retry logic, structured logging,
graceful shutdown (SIGINT/SIGTERM), and DAG dependency ordering.
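Usage:
    python task_executor.py config.yaml [--output results.json]
                            [--log-level DEBUG] [--dry-run]

Config format (YAML or JSON):
    version: "1.0"
    max_parallelism: 3      # default 4
    log_level: "INFO"       # not read by the executor; use --log-level
    tasks:
      - id: task1
        command: "echo hello"
        timeout: 10         # seconds (default 60)
        retries: 2          # default 0
        retry_delay: 1.0    # seconds; base of the exponential backoff (default 1.0)
        depends_on: []      # ids of tasks that must succeed before this one runs

Output:
    Structured JSON with status, stdout, stderr, exit_code and duration per
    task, plus a run summary (written to results.json unless --output is given).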
"""
from __future__ import annotations
import argparse
import concurrent.futures
import json
import logging
import os
import signal
import subprocess
import sys
import time
import uuid
from collections import defaultdict, deque
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
# PyYAML is optional: JSON configs work without it, and load_config() raises
# a clear RuntimeError if a .yaml/.yml file is given while it is missing.
try:
    import yaml
except ImportError:
    HAS_YAML = False
else:
    HAS_YAML = True
# ---------------------------------------------------------------------------
# Globals / shutdown flag
# ---------------------------------------------------------------------------
# Flipped to True by _signal_handler; polled by run_single_task and
# execute_batch so work winds down instead of being killed mid-flight.
_shutdown_requested: bool = False
# Thread pool of the batch currently in progress (None when idle); assigned
# and cleared by execute_batch.
_executor_ref: concurrent.futures.ThreadPoolExecutor | None = None
def _signal_handler(signum: int, frame: Any) -> None:
    """Record a SIGINT/SIGTERM so the batch can wind down cleanly.

    The handler only sets a flag and logs; it deliberately does NOT touch the
    thread pool.  Python runs signal handlers in the main thread, which is
    also the thread that calls ``executor.submit()``.  In CPython, ``submit``
    and ``shutdown`` acquire the same non-reentrant executor lock, so calling
    ``shutdown`` from here can deadlock, and it can also make an in-progress
    ``submit`` raise ``RuntimeError``.  ``execute_batch`` polls
    ``_shutdown_requested`` and cancels not-yet-started futures itself, and
    ``run_single_task`` checks the flag before each attempt.

    Args:
        signum: Number of the delivered signal.
        frame: Interrupted stack frame (unused).
    """
    global _shutdown_requested
    _shutdown_requested = True
    logging.warning("Shutdown signal received (%s). Draining active tasks...", signum)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def utc_now() -> str:
    """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ`` (whole seconds)."""
    moment = datetime.now(tz=timezone.utc)
    return f"{moment:%Y-%m-%dT%H:%M:%SZ}"
def load_config(path: str) -> dict[str, Any]:
    """Load a task configuration from a YAML or JSON file.

    The parser is chosen by extension: ``.yaml``/``.yml`` (case-insensitive)
    goes through ``yaml.safe_load``; anything else is parsed as JSON.

    Args:
        path: Filesystem path of the configuration file.

    Returns:
        The parsed top-level mapping.

    Raises:
        RuntimeError: A YAML file was given but PyYAML is not installed.
        ValueError: The document is empty or its top level is not a mapping
            (malformed JSON also raises ``json.JSONDecodeError``, a
            ``ValueError`` subclass).
        OSError: The file cannot be read.
    """
    raw = Path(path).read_text(encoding="utf-8")
    if path.lower().endswith((".yaml", ".yml")):
        if not HAS_YAML:
            raise RuntimeError("PyYAML not installed. Install with: pip install pyyaml")
        # safe_load never constructs arbitrary Python objects from the file.
        data = yaml.safe_load(raw)
    else:
        data = json.loads(raw)
    # An empty YAML file yields None, and a list or scalar is syntactically
    # valid.  Callers use .get() on the result, so reject anything else here
    # with a clear message instead of an AttributeError later.
    if not isinstance(data, dict):
        raise ValueError(
            f"Config '{path}' must contain a top-level mapping, got {type(data).__name__}"
        )
    return data
def topological_sort(tasks: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return *tasks* ordered so each task follows all of its ``depends_on`` tasks.

    Uses Kahn's algorithm with a FIFO queue, so among tasks that become ready
    at the same time the input order is preserved.

    Args:
        tasks: Task dicts, each with a unique ``"id"`` and an optional
            ``"depends_on"`` list of ids.

    Returns:
        A new list holding the same task dicts in dependency order.

    Raises:
        ValueError: A task has no ``id``, two tasks share an ``id``, a task
            depends on an unknown id, or the dependencies contain a cycle.
    """
    by_id: dict[str, dict[str, Any]] = {}
    for task in tasks:
        if not isinstance(task, dict) or "id" not in task:
            raise ValueError(f"Task is missing required 'id' field: {task!r}")
        tid = task["id"]
        # Without this check a duplicate silently replaces the earlier task,
        # and the length mismatch below is misreported as a cycle.
        if tid in by_id:
            raise ValueError(f"Duplicate task id '{tid}'")
        by_id[tid] = task

    in_degree: dict[str, int] = {tid: 0 for tid in by_id}
    dependents: dict[str, list[str]] = defaultdict(list)
    for task in tasks:
        for dep in task.get("depends_on") or []:
            if dep not in by_id:
                raise ValueError(f"Task '{task['id']}' depends on unknown task '{dep}'")
            in_degree[task["id"]] += 1
            dependents[dep].append(task["id"])

    queue: deque[str] = deque(tid for tid, deg in in_degree.items() if deg == 0)
    ordered: list[dict[str, Any]] = []
    while queue:
        tid = queue.popleft()
        ordered.append(by_id[tid])
        for dependent in dependents[tid]:
            in_degree[dependent] -= 1
            if in_degree[dependent] == 0:
                queue.append(dependent)

    if len(ordered) != len(tasks):
        # Ids still holding a positive in-degree sit on a cycle or downstream of one.
        cycle_ids = [tid for tid, deg in in_degree.items() if deg > 0]
        raise ValueError(f"Cycle detected in task dependencies: {cycle_ids}")
    return ordered
# ---------------------------------------------------------------------------
# Task runner
# ---------------------------------------------------------------------------
def _output_text(value: str | bytes | None) -> str:
    """Normalize captured subprocess output to ``str`` (None -> "", bytes decoded leniently)."""
    if value is None:
        return ""
    if isinstance(value, bytes):
        # Partial output attached to TimeoutExpired may be raw bytes even
        # when the process was started with text=True.
        return value.decode("utf-8", errors="replace")
    return value


def run_single_task(
    task: dict[str, Any],
    logger: logging.Logger,
) -> dict[str, Any]:
    """Execute one task's shell command with a timeout and exponential-backoff retries.

    Args:
        task: Task dict.  ``id`` and ``command`` are required; ``timeout``
            (seconds, default 60), ``retries`` (default 0; negative values are
            treated as 0) and ``retry_delay`` (backoff base in seconds,
            default 1.0) are optional.
        logger: Logger for progress and failure messages.

    Returns:
        A result dict describing the final attempt.  ``status`` is one of
        ``"success"``, ``"failed"`` (non-zero exit), ``"timedout"``,
        ``"error"`` (the command could not be run) or ``"skipped"`` (shutdown
        requested before the first attempt).  ``retries_used`` is the number of
        retries performed.  ``started_utc``, ``completed_utc`` and
        ``duration_seconds`` refer to the final attempt only.

    Note:
        ``command`` runs through the shell (``shell=True``), so the config must
        be trusted.  On timeout, ``subprocess.run`` kills the direct child
        only; processes the shell spawned may survive.
    """
    task_id = task["id"]
    command = task["command"]
    timeout = float(task.get("timeout", 60))
    # A negative count would make range() empty and return a never-run "pending" record.
    max_retries = max(0, int(task.get("retries", 0)))
    retry_delay = float(task.get("retry_delay", 1.0))

    result: dict[str, Any] = {
        "id": task_id,
        "command": command,
        "status": "pending",
        "exit_code": None,
        "stdout": "",
        "stderr": "",
        "duration_seconds": 0.0,
        "retries_used": 0,
        "started_utc": None,
        "completed_utc": None,
        "error": None,
    }

    for attempt in range(max_retries + 1):
        if attempt > 0 and not _shutdown_requested:
            backoff = retry_delay * (2 ** (attempt - 1))
            logger.info("[%s] Retry %d/%d after %.1fs backoff", task_id, attempt, max_retries, backoff)
            time.sleep(backoff)
        # Checked after any backoff sleep, so a signal during the sleep is honoured.
        if _shutdown_requested:
            if attempt == 0:
                result["status"] = "skipped"
                result["error"] = "Shutdown requested before task could run"
            else:
                # Keep the outcome of the last attempt that actually ran.
                logger.warning("[%s] Shutdown requested; abandoning remaining retries", task_id)
            return result

        # Reset per-attempt fields so nothing from an earlier attempt (a
        # timeout's error message, a failed run's exit code or output) leaks
        # into this attempt's record.
        result.update({
            "exit_code": None,
            "stdout": "",
            "stderr": "",
            "error": None,
            "retries_used": attempt,
            "started_utc": utc_now(),
            "completed_utc": None,
        })
        t0 = time.monotonic()
        try:
            proc = subprocess.run(
                command,
                shell=True,
                capture_output=True,
                text=True,
                timeout=timeout,
            )
        except subprocess.TimeoutExpired as exc:
            elapsed = time.monotonic() - t0
            result.update({
                "status": "timedout",
                # Keep whatever the command printed before it was killed.
                "stdout": _output_text(exc.stdout),
                "stderr": _output_text(exc.stderr),
                "error": f"Timed out after {timeout}s",
            })
            logger.error("[%s] Timed out after %.1fs", task_id, timeout)
        except Exception as exc:
            # E.g. the shell could not be spawned: reported in the result, not raised.
            elapsed = time.monotonic() - t0
            result.update({"status": "error", "error": str(exc)})
            logger.error("[%s] Unexpected error: %s", task_id, exc)
        else:
            elapsed = time.monotonic() - t0
            result.update({
                "exit_code": proc.returncode,
                "stdout": proc.stdout,
                "stderr": proc.stderr,
            })
            if proc.returncode == 0:
                result["status"] = "success"
                logger.info("[%s] Completed in %.3fs (exit 0)", task_id, elapsed)
            else:
                result["status"] = "failed"
                logger.warning("[%s] Failed exit=%d after %.3fs", task_id, proc.returncode, elapsed)

        result["duration_seconds"] = round(elapsed, 3)
        result["completed_utc"] = utc_now()
        if result["status"] == "success":
            break
    return result
# ---------------------------------------------------------------------------
# Batch executor
# ---------------------------------------------------------------------------
def _placeholder_result(
    task: dict[str, Any],
    status: str,
    error: str | None,
) -> dict[str, Any]:
    """Build a result record for *task* when run_single_task produced none.

    The keys match the records run_single_task builds, so every entry in a
    batch report shares one schema.
    """
    return {
        "id": task["id"],
        "command": task.get("command", ""),
        "status": status,
        "exit_code": None,
        "stdout": "",
        "stderr": "",
        "duration_seconds": 0.0,
        "retries_used": 0,
        "started_utc": None,
        "completed_utc": None,
        "error": error,
    }


def execute_batch(
    config: dict[str, Any],
    logger: logging.Logger,
) -> dict[str, Any]:
    """Run every task in *config*, honouring ``depends_on`` and ``max_parallelism``.

    A task is submitted once all of its dependencies have status
    ``"success"``.  If any dependency ended as ``failed``, ``timedout``,
    ``error`` or ``skipped``, the task is recorded as ``skipped`` without
    running.  On a shutdown request, tasks that have not started are
    cancelled and recorded as ``skipped``; tasks already running finish and
    their results are kept.

    Args:
        config: Parsed configuration; reads ``tasks`` and ``max_parallelism``
            (default 4).
        logger: Logger for batch progress; also handed to each task.

    Returns:
        A report with ``run_id``, ``started_utc``, ``completed_utc``,
        ``total_duration_seconds``, ``config_max_parallelism``, ``tasks`` (one
        result per task, in dependency order) and ``summary``.

    Raises:
        ValueError: No tasks are configured, or the dependency graph is
            invalid (propagated from topological_sort).
    """
    global _executor_ref
    max_parallelism = int(config.get("max_parallelism", 4))
    tasks = config.get("tasks", [])
    if not tasks:
        raise ValueError("No tasks defined in configuration.")

    # Topological sort to respect dependencies
    ordered_tasks = topological_sort(tasks)
    run_id = str(uuid.uuid4())
    started_utc = utc_now()
    t_start = time.monotonic()

    results: dict[str, dict[str, Any]] = {}
    # Submitted futures not yet collected, mapped to their task.
    running: dict[concurrent.futures.Future, dict[str, Any]] = {}  # type: ignore[type-arg]
    pending = list(ordered_tasks)
    blocking_statuses = {"failed", "timedout", "error", "skipped"}
    shutdown_msg = "Shutdown requested before task could run"

    def record(task: dict[str, Any], future: concurrent.futures.Future) -> None:  # type: ignore[type-arg]
        """Store the outcome of a finished or cancelled *future* for *task*."""
        if future.cancelled():
            results[task["id"]] = _placeholder_result(task, "skipped", shutdown_msg)
            return
        try:
            results[task["id"]] = future.result()
        except Exception as exc:
            # run_single_task reports its own failures; reaching here means it raised.
            crashed = _placeholder_result(task, "error", str(exc))
            crashed["completed_utc"] = utc_now()
            results[task["id"]] = crashed

    logger.info("Starting batch run_id=%s tasks=%d parallelism=%d", run_id, len(ordered_tasks), max_parallelism)
    try:
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_parallelism) as executor:
            _executor_ref = executor
            while pending or running:
                if _shutdown_requested:
                    logger.warning("Shutdown: cancelling %d pending tasks", len(pending))
                    # cancel() only succeeds for futures that have not started yet.
                    for future in running:
                        future.cancel()
                    break

                # pending is in dependency order, so a task skipped in this pass
                # is already in `results` when its own dependents are examined.
                waiting = []
                for task in pending:
                    dep_statuses = [
                        results.get(dep, {}).get("status")
                        for dep in task.get("depends_on") or []
                    ]
                    if all(status == "success" for status in dep_statuses):
                        running[executor.submit(run_single_task, task, logger)] = task
                    elif any(status in blocking_statuses for status in dep_statuses):
                        logger.warning("[%s] Skipping — dependency failed", task["id"])
                        results[task["id"]] = _placeholder_result(task, "skipped", "Dependency failed")
                    else:
                        waiting.append(task)  # some dependency is still running
                pending = waiting

                if not running:
                    # With a topologically sorted list this should only happen once
                    # everything is settled; guard against spinning forever.
                    if pending:
                        logger.error("%d task(s) can never become runnable; stopping", len(pending))
                    break

                # Sleep until a task finishes; the timeout bounds how long a
                # shutdown request can go unnoticed.
                done, _ = concurrent.futures.wait(
                    running, timeout=0.1, return_when=concurrent.futures.FIRST_COMPLETED
                )
                for future in done:
                    record(running.pop(future), future)
    finally:
        _executor_ref = None

    # Leaving the `with` block waited for the pool, so every future still in
    # `running` has either finished or been cancelled: keep their outcomes.
    for future, task in running.items():
        record(task, future)
    leftover_msg = shutdown_msg if _shutdown_requested else "Task never became runnable"
    for task in pending:
        results[task["id"]] = _placeholder_result(task, "skipped", leftover_msg)

    total_elapsed = time.monotonic() - t_start
    all_results = [results[t["id"]] for t in ordered_tasks]
    summary = {
        "total": len(ordered_tasks),
        "success": sum(1 for r in all_results if r.get("status") == "success"),
        # "failed" counts every unsuccessful run (timeouts included); "timedout" is a subset.
        "failed": sum(1 for r in all_results if r.get("status") in ("failed", "timedout", "error")),
        "timedout": sum(1 for r in all_results if r.get("status") == "timedout"),
        "skipped": sum(1 for r in all_results if r.get("status") == "skipped"),
    }
    logger.info(
        "Batch complete in %.2fs — success=%d failed=%d timedout=%d skipped=%d",
        total_elapsed, summary["success"], summary["failed"], summary["timedout"], summary["skipped"]
    )
    return {
        "run_id": run_id,
        "started_utc": started_utc,
        "completed_utc": utc_now(),
        "total_duration_seconds": round(total_elapsed, 3),
        "config_max_parallelism": max_parallelism,
        "tasks": all_results,
        "summary": summary,
    }
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser.

    Accepts a positional config path plus ``--output``/``-o``,
    ``--log-level`` and ``--dry-run``.
    """
    cli = argparse.ArgumentParser(
        description="Multi-Task Configuration Parser & Batch Executor",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    # (flags, keyword arguments) registered in order, so help output is unchanged.
    option_table = (
        (("config",), {"help": "Path to YAML or JSON task config file"}),
        (
            ("--output", "-o"),
            {
                "default": "results.json",
                "help": "Output path for JSON results (default: results.json)",
            },
        ),
        (
            ("--log-level",),
            {
                "default": "INFO",
                "choices": ["DEBUG", "INFO", "WARNING", "ERROR"],
                "help": "Logging verbosity (default: INFO)",
            },
        ),
        (
            ("--dry-run",),
            {
                "action": "store_true",
                "help": "Parse config and show task plan without executing",
            },
        ),
    )
    for flags, spec in option_table:
        cli.add_argument(*flags, **spec)
    return cli
def main(argv: list[str] | None = None) -> int:
    """Command-line entry point: load a config, then dry-run or execute it.

    Args:
        argv: Arguments to parse (without the program name); ``None`` means
            ``sys.argv[1:]``.

    Returns:
        Exit status: 0 when every task succeeded (or the dry run validated);
        1 on config/IO errors, on any failed task, or when a shutdown signal
        interrupted the batch.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    # The date format ends in "Z", so render asctime in UTC instead of the
    # logging module's default local time.
    formatter = logging.Formatter(
        fmt="%(asctime)s %(levelname)-8s %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%SZ",
    )
    formatter.converter = time.gmtime
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    # NOTE(review): the module docstring documents a config-level "log_level",
    # but only --log-level is honoured here.
    logging.basicConfig(level=getattr(logging, args.log_level), handlers=[handler])
    logger = logging.getLogger("task_executor")

    # Register signal handlers for graceful shutdown
    signal.signal(signal.SIGINT, _signal_handler)
    signal.signal(signal.SIGTERM, _signal_handler)

    try:
        config = load_config(args.config)
    except Exception as exc:
        logger.error("Failed to load config '%s': %s", args.config, exc)
        return 1
    # An empty YAML file parses to None; everything below needs a mapping.
    if not isinstance(config, dict):
        logger.error(
            "Failed to load config '%s': top level must be a mapping, got %s",
            args.config, type(config).__name__,
        )
        return 1

    tasks = config.get("tasks") or []
    logger.info("Loaded %d task(s) from '%s'", len(tasks), args.config)

    if args.dry_run:
        try:
            ordered = topological_sort(tasks)
        except (ValueError, KeyError, TypeError) as exc:
            # KeyError/TypeError cover malformed task entries (missing id, non-dict task).
            logger.error("Config validation error: %s", exc)
            return 1
        logger.info("Dry-run: task execution order:")
        for t in ordered:
            logger.info(" [%s] cmd=%r timeout=%s retries=%s deps=%s",
                        t["id"], t.get("command", ""), t.get("timeout", 60),
                        t.get("retries", 0), t.get("depends_on", []))
        return 0

    try:
        report = execute_batch(config, logger)
    except Exception as exc:
        logger.error("Batch execution error: %s", exc)
        return 1

    out_path = Path(args.output)
    try:
        out_path.write_text(json.dumps(report, indent=2), encoding="utf-8")
    except OSError as exc:
        logger.error("Failed to write results to '%s': %s", out_path, exc)
        return 1
    logger.info("Results written to '%s'", out_path)

    # An interrupted batch may have skipped work without any task "failing";
    # never report that as success.
    if _shutdown_requested:
        logger.warning("Batch was interrupted by a shutdown signal")
        return 1
    failed = report["summary"]["failed"]
    return 0 if failed == 0 else 1
if __name__ == "__main__":
    # Hand main()'s integer result to the interpreter as the process exit status.
    sys.exit(main())