Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
407 changes: 341 additions & 66 deletions src/openjd/sessions/_action_filter.py

Large diffs are not rendered by default.

21 changes: 20 additions & 1 deletion src/openjd/sessions/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
JobParameterValues,
ParameterValue,
ParameterValueType,
RevisionExtensions,
SpecificationRevision,
SymbolTable,
TaskParameterSet,
Expand Down Expand Up @@ -318,6 +319,9 @@ def __init__(
callback: Optional[SessionCallbackType] = None,
os_env_vars: Optional[dict[str, str]] = None,
session_root_directory: Optional[Path] = None,
revision_extensions: RevisionExtensions = RevisionExtensions(
spec_rev=SpecificationRevision.v2023_09, supported_extensions=[]
),
):
"""
Arguments:
Expand Down Expand Up @@ -358,6 +362,8 @@ def __init__(
2. The 'user' (if given) must have at least read permissions to it; and
3. The Working Directory for this Session will be created in the given directory.
If not provided, then the default of gettempdir()/"openjd" is used instead.
revision_extensions (RevisionExtensions): Specification revision and supported extensions
for this session. Defaults to SpecificationRevision.v2023_09 with no extensions.

Raises:
RuntimeError - If the Session initialization fails for any reason.
Expand Down Expand Up @@ -389,9 +395,14 @@ def __init__(
)
self._reset_action_state()

# Store the revision_extensions
self._revision_extensions = revision_extensions

# Set up our logging hook & callback
self._log_filter = ActionMonitoringFilter(
session_id=self._session_id, callback=self._action_log_filter_callback
session_id=self._session_id,
callback=self._action_log_filter_callback,
revision_extensions=revision_extensions,
)
LOG.addFilter(self._log_filter)
self._logger = LoggerAdapter(LOG, extra={"session_id": self._session_id})
Expand Down Expand Up @@ -832,6 +843,14 @@ def run_task(
# =========================
# Helpers

def get_enabled_extensions(self) -> list[str]:
"""Return the list of enabled extensions for this session.

Returns:
list[str]: The list of enabled extensions
"""
return list(self._revision_extensions.extensions)

def _reset_action_state(self) -> None:
"""Reset the internal action state.
This resets to a state equivalent to having nothing running.
Expand Down
10 changes: 9 additions & 1 deletion src/openjd/sessions/_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ._logging import LoggerAdapter, LogContent, LogExtraInfo
from ._os_checker import is_linux, is_posix, is_windows
from ._session_user import PosixSessionUser, WindowsSessionUser, SessionUser
from ._action_filter import redact_openjd_redacted_env_requests

if is_windows(): # pragma: nocover
from subprocess import CREATE_NEW_PROCESS_GROUP, CREATE_NO_WINDOW # type: ignore
Expand Down Expand Up @@ -274,11 +275,18 @@ def _start_subprocess(self) -> Optional[Popen]:
# https://docs.python.org/2/library/subprocess.html#subprocess.CREATE_NEW_PROCESS_GROUP
popen_args["creationflags"] = CREATE_NEW_PROCESS_GROUP

# Get the command string for logging
cmd_line_for_logger: str
if is_posix():
cmd_line_for_logger = shlex.join(command)
else:
cmd_line_for_logger = list2cmdline(self._args)

cmd_line = list2cmdline(self._args)
# Command line could contain openjd_redacted_env: token lines not yet processed by the
# session logger. If the token appears in the command line we'll redact everything
# in the line after it for the logs. Note that on Linux currently the command including
# args are in a .sh script, so the full argument list isn't printed by default.
cmd_line_for_logger = redact_openjd_redacted_env_requests(cmd_line)
self._logger.info(
"Running command %s",
cmd_line_for_logger,
Expand Down
102 changes: 99 additions & 3 deletions test/openjd/sessions/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@
from logging import INFO, getLogger
from logging.handlers import QueueHandler
from queue import Empty, SimpleQueue
from typing import Generator
from typing import Generator, Optional
from hashlib import sha256
from unittest.mock import MagicMock
import pytest

from openjd.sessions import PosixSessionUser, WindowsSessionUser, BadCredentialsException
from openjd.sessions._os_checker import is_posix, is_windows
from openjd.sessions._logging import LoggerAdapter
from openjd.sessions._action_filter import ActionMonitoringFilter
from openjd.model import RevisionExtensions, SpecificationRevision

if is_windows():
from openjd.sessions._win32._helpers import ( # type: ignore
Expand Down Expand Up @@ -55,15 +59,107 @@ def pytest_collection_modifyitems(config, items):
config.option.markexpr = mark_expr


def create_unique_logger_name(prefix: str = "", seed: Optional[str] = None) -> str:
"""Create a unique logger name using a hash to avoid collisions.

Args:
prefix: Optional prefix for the logger name
seed: Optional seed string to use for generating the hash

Returns:
A unique logger name
"""
if seed:
h = sha256()
h.update(seed.encode("utf-8"))
suffix = h.hexdigest()[0:32]
else:
charset = string.ascii_letters + string.digits
suffix = "".join(random.choices(charset, k=32))

return f"{prefix}{suffix}"


def build_logger(handler: QueueHandler) -> LoggerAdapter:
charset = string.ascii_letters + string.digits + string.punctuation
name_suffix = "".join(random.choices(charset, k=32))
"""Build a logger for testing purposes.

Args:
handler: The queue handler to attach to the logger

Returns:
A configured LoggerAdapter
"""
name_suffix = create_unique_logger_name()
log = getLogger(".".join((__name__, name_suffix)))
log.setLevel(INFO)
log.addHandler(handler)
return LoggerAdapter(log, extra=dict())


def setup_action_filter_test(
queue_handler: QueueHandler,
session_id: str = "foo",
callback: Optional[MagicMock] = None,
suppress_filtered: bool = False,
enabled_extensions: Optional[list[str]] = None,
) -> tuple[LoggerAdapter, ActionMonitoringFilter, MagicMock]:
"""Set up a test environment for testing ActionMonitoringFilter.

This helper method creates a unique logger name, sets up the ActionMonitoringFilter,
and configures the logger with the filter.

Args:
queue_handler: The QueueHandler to attach to the logger
session_id: The session ID to use for the filter
callback: Optional mock callback to use for the filter
suppress_filtered: Whether to suppress filtered messages
enabled_extensions: Optional list of extensions to enable

Returns:
A tuple containing (logger_adapter, action_filter, callback_mock)

Note:
This helper works for most tests, but for tests that need to verify specific
callback behavior with redacted values, it's better to create the filter and
logger directly in the test. This is because when multiple filters are applied
to the same log message (which can happen when running multiple tests), the
redaction can happen before the callback is invoked, resulting in the callback
receiving redacted values instead of the original values.
"""
# Create a unique logger name WITHOUT using the message as seed
# This ensures each test gets a truly unique logger name
logger_name = create_unique_logger_name(prefix="action_filter_")

# Create a mock callback if one wasn't provided
if callback is None:
callback = MagicMock()

# Create a RevisionExtensions with the provided extensions or an empty set
revision_extensions = RevisionExtensions(
spec_rev=SpecificationRevision.v2023_09, supported_extensions=enabled_extensions or []
)

# Create the filter directly with the provided parameters
action_filter = ActionMonitoringFilter(
session_id=session_id,
callback=callback,
suppress_filtered=suppress_filtered,
revision_extensions=revision_extensions,
)

# Set up the logger
log = getLogger(".".join((__name__, logger_name)))
log.setLevel(INFO)
log.addHandler(queue_handler)
log.addFilter(action_filter)

# Create and return the logger adapter with the session_id
# This is critical for the filter to work properly
logger_adapter = LoggerAdapter(log, extra={"session_id": session_id})

return logger_adapter, action_filter, callback


def collect_queue_messages(queue: SimpleQueue) -> list[str]:
"""Extract the text of messages from a SimpleQueue containing LogRecords"""
messages: list[str] = []
Expand Down
Loading