1 change: 1 addition & 0 deletions doc/changes/DM-53622.feature.md
@@ -0,0 +1 @@
+Add provenance writing support to `MPGraphExecutor` and `SeparablePipelineExecutor`.
12 changes: 8 additions & 4 deletions python/lsst/pipe/base/log_capture.py
@@ -163,7 +163,9 @@ def from_full(cls, butler: Butler) -> LogCapture:
         return cls(butler, butler)

     @contextmanager
-    def capture_logging(self, task_node: TaskNode, /, quantum: Quantum) -> Iterator[_LogCaptureContext]:
+    def capture_logging(
+        self, task_node: TaskNode, /, quantum: Quantum, records: ButlerLogRecords | None = None
+    ) -> Iterator[_LogCaptureContext]:
         """Configure logging system to capture logs for execution of this task.

         Parameters
@@ -172,6 +174,9 @@ def capture_logging(
             The task definition.
         quantum : `~lsst.daf.butler.Quantum`
             Single Quantum instance.
+        records : `lsst.daf.butler.logging.ButlerLogRecords`, optional
+            Log record container to append to and save. If provided, streaming
+            mode is disabled (since we'll be saving logs in memory anyway).

         Notes
         -----
@@ -213,7 +218,7 @@ def capture_logging(
                 ) from exc
         # Either accumulate into ButlerLogRecords or stream JSON records to
         # file and ingest that (ingest is possible only with full butler).
-        if self.stream_json_logs and self.full_butler is not None:
+        if self.stream_json_logs and self.full_butler is not None and records is None:
             with TemporaryForIngest(self.full_butler, ref) as temporary:
                 log_handler_file = FileHandler(temporary.ospath)
                 log_handler_file.setFormatter(JsonLogFormatter())
@@ -236,7 +241,7 @@ def capture_logging(
                 temporary.ingest()

         else:
-            log_handler_memory = ButlerLogRecordHandler()
+            log_handler_memory = ButlerLogRecordHandler(records)
             logging.getLogger().addHandler(log_handler_memory)

             try:
@@ -255,7 +260,6 @@ def capture_logging(
                 logging.getLogger().removeHandler(log_handler_memory)
                 if ctx.store:
                     self._store_log_records(quantum, log_dataset_name, log_handler_memory)
-                log_handler_memory.records.clear()

         else:
             with ButlerMDC.set_mdc(mdc):
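The new `records` argument lets a caller collect a quantum's logs in memory while they are also stored as usual. A minimal caller sketch, where `log_capture`, `task_node`, `quantum`, and `run_quantum` are stand-ins for objects a real executor already has:

```python
from lsst.daf.butler.logging import ButlerLogRecords

# Shared container: passing it disables JSON streaming, so the records
# accumulate in memory and remain available after the context exits.
records = ButlerLogRecords([])
with log_capture.capture_logging(task_node, quantum, records=records):
    run_quantum(task_node, quantum)  # hypothetical task execution
# `records` now also holds everything logged inside the context.
```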
79 changes: 79 additions & 0 deletions python/lsst/pipe/base/log_on_close.py
@@ -0,0 +1,79 @@
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ("LogOnClose",)

from collections.abc import Callable, Iterator
from contextlib import AbstractContextManager, contextmanager
from typing import TypeVar

from lsst.utils.logging import VERBOSE

_T = TypeVar("_T")


class LogOnClose:
    """A factory for context manager wrappers that emit a log message when
    they are closed.

    Parameters
    ----------
    log_func : `~collections.abc.Callable` [ `int`, `str` ]
        Callable that takes an integer log level and a string message and
        emits a log message.  Note that placeholder formatting is not
        supported.
    """

    def __init__(self, log_func: Callable[[int, str], None]):
        self.log_func = log_func

    def wrap(
        self,
        cm: AbstractContextManager[_T],
        msg: str,
        level: int = VERBOSE,
    ) -> AbstractContextManager[_T]:
        """Wrap a context manager to log when it is exited.

        Parameters
        ----------
        cm : `contextlib.AbstractContextManager`
            Context manager to wrap.
        msg : `str`
            Log message.
        level : `int`, optional
            Log level.
        """

        @contextmanager
        def wrapper() -> Iterator[_T]:
            with cm as result:
                yield result
            self.log_func(level, msg)

        return wrapper()
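A minimal usage sketch, assuming only a standard module logger; the message is emitted once the wrapped context manager has exited:

```python
import logging
from tempfile import TemporaryDirectory

from lsst.pipe.base.log_on_close import LogOnClose

_LOG = logging.getLogger(__name__)

log_on_close = LogOnClose(_LOG.log)
with log_on_close.wrap(
    TemporaryDirectory(), "Scratch directory removed.", level=logging.INFO
) as tmpdir:
    ...  # work with tmpdir; the message is logged only after cleanup completes
```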
66 changes: 51 additions & 15 deletions python/lsst/pipe/base/mp_graph_executor.py
@@ -39,20 +39,24 @@
 import threading
 import time
 import uuid
+from contextlib import ExitStack
 from typing import Literal, cast

 import networkx

 from lsst.daf.butler import DataCoordinate, Quantum
 from lsst.daf.butler.cli.cliLog import CliLog
+from lsst.daf.butler.logging import ButlerLogRecords
 from lsst.utils.threads import disable_implicit_threading

 from ._status import InvalidQuantumError, RepeatableQuantumError
+from ._task_metadata import TaskMetadata
 from .execution_graph_fixup import ExecutionGraphFixup
 from .graph import QuantumGraph
 from .graph_walker import GraphWalker
+from .log_on_close import LogOnClose
 from .pipeline_graph import TaskNode
-from .quantum_graph import PredictedQuantumGraph, PredictedQuantumInfo
+from .quantum_graph import PredictedQuantumGraph, PredictedQuantumInfo, ProvenanceQuantumGraphWriter
 from .quantum_graph_executor import QuantumExecutor, QuantumGraphExecutor
 from .quantum_reports import ExecutionStatus, QuantumReport, Report

@@ -515,7 +519,9 @@ def __init__(
             start_method = "spawn"
         self._start_method = start_method

-    def execute(self, graph: QuantumGraph | PredictedQuantumGraph) -> None:
+    def execute(
+        self, graph: QuantumGraph | PredictedQuantumGraph, *, provenance_graph_file: str | None = None
+    ) -> None:
         # Docstring inherited from QuantumGraphExecutor.execute
         old_graph: QuantumGraph | None = None
         if isinstance(graph, QuantumGraph):
@@ -525,14 +531,31 @@ def execute(
             new_graph = graph
         xgraph = self._make_xgraph(new_graph, old_graph)
         self._report = Report(qgraphSummary=new_graph._make_summary())
-        try:
-            if self._num_proc > 1:
-                self._execute_quanta_mp(xgraph, self._report)
-            else:
-                self._execute_quanta_in_process(xgraph, self._report)
-        except Exception as exc:
-            self._report.set_exception(exc)
-            raise
+        with ExitStack() as exit_stack:
+            provenance_writer: ProvenanceQuantumGraphWriter | None = None
+            if provenance_graph_file is not None:
+                if self._num_proc > 1:
+                    raise NotImplementedError(
+                        "Provenance writing is not implemented for multiprocess execution."
+                    )
+                provenance_writer = ProvenanceQuantumGraphWriter(
+                    provenance_graph_file,
+                    exit_stack=exit_stack,
+                    log_on_close=LogOnClose(_LOG.log),
+                    predicted=new_graph,
+                )
+            try:
+                if self._num_proc > 1:
+                    self._execute_quanta_mp(xgraph, self._report)
+                else:
+                    self._execute_quanta_in_process(xgraph, self._report, provenance_writer)
+            except Exception as exc:
+                self._report.set_exception(exc)
+                raise
+            if provenance_writer is not None:
+                provenance_writer.write_overall_inputs()
+                provenance_writer.write_packages()
+                provenance_writer.write_init_outputs(assume_existence=True)

     def _make_xgraph(
         self, new_graph: PredictedQuantumGraph, old_graph: QuantumGraph | None
@@ -576,7 +599,9 @@ def _make_xgraph(
             raise MPGraphExecutorError("Updated execution graph has dependency cycle.")
         return xgraph

-    def _execute_quanta_in_process(self, xgraph: networkx.DiGraph, report: Report) -> None:
+    def _execute_quanta_in_process(
+        self, xgraph: networkx.DiGraph, report: Report, provenance_writer: ProvenanceQuantumGraphWriter | None
+    ) -> None:
         """Execute all Quanta in current process.

         Parameters
@@ -589,6 +614,9 @@ def _execute_quanta_in_process(
             `.quantum_graph.PredictedQuantumGraph.quantum_only_xgraph`.
         report : `Report`
             Object for reporting execution status.
+        provenance_writer : `.quantum_graph.ProvenanceQuantumGraphWriter` or \
+            `None`
+            Object for recording provenance.
         """

         def tiebreaker_sort_key(quantum_id: uuid.UUID) -> tuple:
@@ -606,16 +634,19 @@ def tiebreaker_sort_key(quantum_id: uuid.UUID) -> tuple:

             _LOG.debug("Executing %s (%s@%s)", quantum_id, task_node.label, data_id)
             fail_exit_code: int | None = None
+            task_metadata: TaskMetadata | None = None
+            task_logs = ButlerLogRecords([])
             try:
                 # For some exception types we want to exit immediately with
                 # exception-specific exit code, but we still want to start
                 # debugger before exiting if debugging is enabled.
                 try:
-                    _, quantum_report = self._quantum_executor.execute(
-                        task_node, quantum, quantum_id=quantum_id
+                    execution_result = self._quantum_executor.execute(
+                        task_node, quantum, quantum_id=quantum_id, log_records=task_logs
                     )
-                    if quantum_report:
-                        report.quantaReports.append(quantum_report)
+                    if execution_result.report:
+                        report.quantaReports.append(execution_result.report)
+                    task_metadata = execution_result.task_metadata
                     success_count += 1
                     walker.finish(quantum_id)
                 except RepeatableQuantumError as exc:
@@ -701,6 +732,11 @@ def tiebreaker_sort_key(quantum_id: uuid.UUID) -> tuple:
                 )
                 failed_count += 1

+            if provenance_writer is not None:
+                provenance_writer.write_quantum_provenance(
+                    quantum_id, metadata=task_metadata, logs=task_logs
+                )
+
             _LOG.info(
                 "Executed %d quanta successfully, %d failed and %d remain out of total %d quanta.",
                 success_count,
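In-process usage, sketched under the assumption that a configured `MPGraphExecutor` instance (`executor`) and a loaded `PredictedQuantumGraph` (`predicted_graph`) are already in hand; the output filename is a stand-in:

```python
# Single-process execution with provenance capture.  Combining
# provenance_graph_file with num_proc > 1 raises NotImplementedError.
executor.execute(predicted_graph, provenance_graph_file="provenance.qg")
# On success, the provenance graph also records overall inputs,
# package versions, and init-outputs.
```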
7 changes: 4 additions & 3 deletions python/lsst/pipe/base/quantum_graph/_common.py
@@ -448,14 +448,15 @@ def open(
         uri: ResourcePathExpression,
         header: HeaderModel,
         pipeline_graph: PipelineGraph,
-        indices: dict[uuid.UUID, int],
         *,
         address_filename: str,
-        compressor: Compressor,
         cdict_data: bytes | None = None,
+        zstd_level: int = 10,
     ) -> Iterator[Self]:
         uri = ResourcePath(uri)
-        address_writer = AddressWriter(indices)
+        address_writer = AddressWriter()
+        cdict = zstandard.ZstdCompressionDict(cdict_data) if cdict_data is not None else None
+        compressor = zstandard.ZstdCompressor(level=zstd_level, dict_data=cdict)
         with uri.open(mode="wb") as stream:
             with zipfile.ZipFile(stream, mode="w", compression=zipfile.ZIP_STORED) as zf:
                 self = cls(zf, compressor, address_writer, header.int_size)
22 changes: 6 additions & 16 deletions python/lsst/pipe/base/quantum_graph/_multiblock.py
@@ -205,13 +205,6 @@ def __str__(self) -> str:
 class AddressWriter:
     """A helper object for writing address files for multi-block files."""

-    indices: dict[uuid.UUID, int] = dataclasses.field(default_factory=dict)
-    """Mapping from UUID to internal integer ID.
-
-    The internal integer ID must always correspond to the index into the
-    sorted list of all UUIDs, but this `dict` need not be sorted itself.
-    """
-
     addresses: list[dict[uuid.UUID, Address]] = dataclasses.field(default_factory=list)
     """Addresses to store with each UUID.

@@ -229,18 +222,15 @@ def write(self, stream: IO[bytes], int_size: int) -> None:
         int_size : `int`
             Number of bytes to use for all integers.
         """
-        for n, address_map in enumerate(self.addresses):
-            if not self.indices.keys() >= address_map.keys():
-                raise AssertionError(
-                    f"Logic bug in quantum graph I/O: address map {n} of {len(self.addresses)} has IDs "
-                    f"{address_map.keys() - self.indices.keys()} not in the index map."
-                )
+        indices: set[uuid.UUID] = set()
+        for address_map in self.addresses:
+            indices.update(address_map.keys())
         stream.write(int_size.to_bytes(1))
-        stream.write(len(self.indices).to_bytes(int_size))
+        stream.write(len(indices).to_bytes(int_size))
         stream.write(len(self.addresses).to_bytes(int_size))
         empty_address = Address()
-        for key in sorted(self.indices.keys(), key=attrgetter("int")):
-            row = AddressRow(key, self.indices[key], [m.get(key, empty_address) for m in self.addresses])
+        for n, key in enumerate(sorted(indices, key=attrgetter("int"))):
+            row = AddressRow(key, n, [m.get(key, empty_address) for m in self.addresses])
             _LOG.debug("Wrote address %s.", row)
             row.write(stream, int_size)

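The indexing rule is now implicit: a UUID's integer ID is its position in the sorted union of all UUIDs appearing in any address map, so no separate index mapping needs to be stored. A minimal sketch of that rule, with stand-in address maps:

```python
import uuid

# Stand-in address maps, keyed by UUID as in AddressWriter.addresses.
a, b, c = uuid.uuid4(), uuid.uuid4(), uuid.uuid4()
address_maps = [{a: "addr-a", c: "addr-c"}, {b: "addr-b"}]

# Union of every key that appears in any map...
ids: set[uuid.UUID] = set()
for address_map in address_maps:
    ids.update(address_map.keys())

# ...sorted by the UUID's 128-bit integer value, matching
# sorted(indices, key=attrgetter("int")) in write().
implicit_index = {key: n for n, key in enumerate(sorted(ids, key=lambda u: u.int))}
assert sorted(implicit_index.values()) == [0, 1, 2]
```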
13 changes: 3 additions & 10 deletions python/lsst/pipe/base/quantum_graph/_predicted.py
@@ -1793,7 +1793,6 @@ def write(
                 f"Unsupported extension {ext!r} for quantum graph; "
                 "expected '.qg' (or '.qgraph' to force the old format)."
             )
-        cdict: zstandard.ZstdCompressionDict | None = None
         cdict_data: bytes | None = None
         quantum_datasets_json: dict[uuid.UUID, bytes] = {}
         if len(self.quantum_datasets) < zstd_dict_n_inputs:
@@ -1807,26 +1806,20 @@ def write(
                 for quantum_model in itertools.islice(self.quantum_datasets.values(), zstd_dict_n_inputs)
             }
             try:
-                cdict = zstandard.train_dictionary(
+                cdict_data = zstandard.train_dictionary(
                     zstd_dict_size,
                     list(quantum_datasets_json.values()),
                     level=zstd_level,
-                )
+                ).as_bytes()
             except zstandard.ZstdError as err:
                 warnings.warn(f"Not using a compression dictionary: {err}.")
-                cdict = None
-            else:
-                cdict_data = cdict.as_bytes()
-        compressor = zstandard.ZstdCompressor(level=zstd_level, dict_data=cdict)
-        indices = {quantum_id: n for n, quantum_id in enumerate(sorted(self.quantum_datasets.keys()))}
         with BaseQuantumGraphWriter.open(
             uri,
             header=self.header,
             pipeline_graph=self.pipeline_graph,
-            indices=indices,
             address_filename="quanta",
-            compressor=compressor,
             cdict_data=cdict_data,
+            zstd_level=zstd_level,
         ) as writer:
             writer.write_single_model("thin_graph", self.thin_graph)
             if self.dimension_data is None:
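The net effect of the `_predicted.py` and `_common.py` changes is that only the trained dictionary's raw bytes travel to the writer, which rebuilds the compressor itself. A self-contained sketch of that pattern, assuming only the `zstandard` package (the sample payloads are invented):

```python
import zstandard

# Stand-in JSON payloads of the kind a dictionary would be trained on.
samples = [b'{"quantum": %d, "outputs": []}' % n for n in range(64)]

cdict_data: bytes | None = None
try:
    # Train, then keep only raw bytes, mirroring train_dictionary(...).as_bytes().
    cdict_data = zstandard.train_dictionary(8192, samples, level=10).as_bytes()
except zstandard.ZstdError:
    pass  # too few or too-uniform samples; proceed without a dictionary

# The writer side rebuilds the compressor from the bytes plus a level.
cdict = zstandard.ZstdCompressionDict(cdict_data) if cdict_data is not None else None
compressor = zstandard.ZstdCompressor(level=10, dict_data=cdict)
blob = compressor.compress(samples[0])
```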