23 commits
4da5850
Initial plan
Copilot Jun 25, 2025
6cd0b5f
Add restart mechanism to deployment status updater to fix stuck opera…
Copilot Jun 25, 2025
c20078a
Merge branch 'main' into copilot/fix-4464
marrobi Jun 25, 2025
c7072b9
Add heartbeat monitoring to supervisor function for stuck process det…
Copilot Jun 25, 2025
202e726
Move heartbeat monitoring from resource processor to deployment statu…
Copilot Jun 25, 2025
381bd9c
Fix linting issues and increment API version
Copilot Jun 25, 2025
7c5ff5d
Refactor service bus components to implement heartbeat monitoring and…
marrobi Jun 26, 2025
9329c94
update tests and fix issue.
marrobi Jun 26, 2025
96e39b2
Fix lint.
marrobi Jun 26, 2025
f81e9eb
Merge branch 'main' into copilot/fix-4464
marrobi Jun 26, 2025
75d77dd
remove duplicate tests.
marrobi Jun 26, 2025
b190ab3
Merge branch 'main' of https://github.com/microsoft/AzureTRE into cop…
marrobi Nov 7, 2025
7b78e99
Enhance Service Bus consumer with error handling and heartbeat manage…
marrobi Nov 7, 2025
32d8c75
Enhance Service Bus consumer with error handling and heartbeat manage…
marrobi Nov 7, 2025
681d5ad
Merge branch 'copilot/fix-4464' of https://github.com/microsoft/Azure…
marrobi Nov 7, 2025
b6f7e29
Update tests
marrobi Nov 7, 2025
ba8d1e9
update tests
marrobi Nov 7, 2025
49245fe
Update api_app/service_bus/deployment_status_updater.py
marrobi Nov 7, 2025
37291c3
Define format once for two instrumentors.
marrobi Nov 7, 2025
3840d8c
Merge branch 'copilot/fix-4464' of https://github.com/microsoft/Azure…
marrobi Nov 7, 2025
975ed29
Update api_app/service_bus/deployment_status_updater.py
marrobi Nov 7, 2025
7eac68b
Update api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py
marrobi Nov 7, 2025
ff963ac
Move tempfile import to top and add explanatory comment to except clause
Copilot Nov 7, 2025
2 changes: 1 addition & 1 deletion api_app/_version.py
@@ -1 +1 @@
__version__ = "0.25.1"
__version__ = "0.26.0"
4 changes: 2 additions & 2 deletions api_app/main.py
@@ -34,8 +34,8 @@ async def lifespan(app: FastAPI):
airlockStatusUpdater = AirlockStatusUpdater()
await airlockStatusUpdater.init_repos()

asyncio.create_task(deploymentStatusUpdater.receive_messages())
asyncio.create_task(airlockStatusUpdater.receive_messages())
asyncio.create_task(deploymentStatusUpdater.supervisor_with_heartbeat_check())
asyncio.create_task(airlockStatusUpdater.supervisor_with_heartbeat_check())
yield


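The updaters are no longer started as bare receive_messages() tasks; each now runs under its supervisor. A minimal sketch of the new wiring, assuming the DeploymentStatusUpdater from this PR is importable (the app setup itself is illustrative):

```python
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI

from service_bus.deployment_status_updater import DeploymentStatusUpdater

@asynccontextmanager
async def lifespan(app: FastAPI):
    updater = DeploymentStatusUpdater()
    await updater.init_repos()
    # The supervisor owns the consumer task: it restarts the loop on crash
    # and recreates it when the heartbeat file goes stale.
    asyncio.create_task(updater.supervisor_with_heartbeat_check())
    yield

app = FastAPI(lifespan=lifespan)
```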
11 changes: 8 additions & 3 deletions api_app/service_bus/airlock_request_status_update.py
Expand Up @@ -16,12 +16,13 @@
from models.domain.airlock_operations import StepResultStatusUpdateMessage
from core import config, credentials
from resources import strings
from service_bus.service_bus_consumer import ServiceBusConsumer


class AirlockStatusUpdater():
class AirlockStatusUpdater(ServiceBusConsumer):

def __init__(self):
pass
super().__init__("airlock_status_updater")

async def init_repos(self):
self.airlock_request_repo = await AirlockRequestRepository.create()
@@ -36,9 +37,13 @@ async def receive_messages(self):
try:
current_time = time.time()
polling_count += 1

# Update heartbeat file for supervisor monitoring
self.update_heartbeat()

# Log a heartbeat message every 60 seconds to show the service is still working
if current_time - last_heartbeat_time >= 60:
logger.info(f"Queue reader heartbeat: Polled {config.SERVICE_BUS_STEP_RESULT_QUEUE} queue {polling_count} times in the last minute")
logger.info(f"{config.SERVICE_BUS_STEP_RESULT_QUEUE} queue polled {polling_count} times in the last minute")
last_heartbeat_time = current_time
polling_count = 0

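Both updaters now share the same polling-summary pattern: count every poll, but only emit one log line per minute. A standalone sketch of the idea (the queue name and poll body are placeholders, not from the PR):

```python
import time

LOG_INTERVAL_SECONDS = 60  # matches the 60s interval used by both updaters

def run_poll_loop(poll_once, queue_name="example-queue"):
    last_log_time = 0.0
    polling_count = 0
    while True:
        polling_count += 1
        now = time.time()
        # Emit one summary line per interval instead of logging every poll
        if now - last_log_time >= LOG_INTERVAL_SECONDS:
            print(f"{queue_name} queue polled {polling_count} times in the last minute")
            last_log_time = now
            polling_count = 0
        poll_once()
```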
37 changes: 22 additions & 15 deletions api_app/service_bus/deployment_status_updater.py
@@ -1,7 +1,7 @@
import asyncio
import json
import uuid
import time
from typing import Dict, List, Any

from pydantic import ValidationError, parse_obj_as

@@ -21,21 +21,19 @@
from models.domain.operation import DeploymentStatusUpdateMessage, Operation, OperationStep, Status
from resources import strings
from services.logging import logger, tracer
from service_bus.service_bus_consumer import ServiceBusConsumer


class DeploymentStatusUpdater():
class DeploymentStatusUpdater(ServiceBusConsumer):
def __init__(self):
pass
super().__init__("deployment_status_updater")

async def init_repos(self):
self.operations_repo = await OperationRepository.create()
self.resource_repo = await ResourceRepository.create()
self.resource_template_repo = await ResourceTemplateRepository.create()
self.resource_history_repo = await ResourceHistoryRepository.create()

def run(self, *args, **kwargs):
asyncio.run(self.receive_messages())

async def receive_messages(self):
with tracer.start_as_current_span("deployment_status_receive_messages"):
last_heartbeat_time = 0
@@ -45,9 +43,12 @@ async def receive_messages(self):
try:
current_time = time.time()
polling_count += 1

# Update heartbeat file for supervisor monitoring
self.update_heartbeat()
# Log a heartbeat message every 60 seconds to show the service is still working
if current_time - last_heartbeat_time >= 60:
logger.info(f"Queue reader heartbeat: Polled {config.SERVICE_BUS_DEPLOYMENT_STATUS_UPDATE_QUEUE} queue {polling_count} times in the last minute")
logger.info(f"{config.SERVICE_BUS_DEPLOYMENT_STATUS_UPDATE_QUEUE} queue polled {polling_count} times in the last minute")
last_heartbeat_time = current_time
polling_count = 0

@@ -73,15 +74,15 @@ async def receive_messages(self):
# Timeout occurred whilst connecting to a session - this is expected and indicates no non-empty sessions are available
logger.debug("No sessions for this process. Will look again...")

except ServiceBusConnectionError:
except ServiceBusConnectionError as e:
# Occasionally there will be a transient / network-level error in connecting to SB.
logger.info("Unknown Service Bus connection error. Will retry...")
logger.warning(f"Service Bus connection error (will retry): {e}")

except Exception as e:
# Catch all other exceptions, log them via .exception to get the stack trace, and reconnect
logger.exception(f"Unknown exception. Will retry - {e}")
logger.exception(f"Unexpected error in message processing: {type(e).__name__}: {e}")

async def process_message(self, msg):
async def process_message(self, msg) -> bool:
complete_message = False
message = ""

@@ -115,6 +116,11 @@ async def update_status_in_database(self, message: DeploymentStatusUpdateMessage
try:
# update the op
operation = await self.operations_repo.get_operation_by_id(str(message.operationId))

# Add null safety for operation steps
if not operation.steps:
raise ValueError(f"Operation {message.operationId} has no steps")

step_to_update = None
is_last_step = False

@@ -128,7 +134,7 @@ async def update_status_in_database(self, message: DeploymentStatusUpdateMessage
is_last_step = True

if step_to_update is None:
raise f"Error finding step {message.stepId} in operation {message.operationId}"
raise ValueError(f"Step {message.stepId} not found in operation {message.operationId}")

# update the step status
step_to_update.status = message.status
@@ -159,7 +165,8 @@ async def update_status_in_database(self, message: DeploymentStatusUpdateMessage

# more steps in the op to do?
if is_last_step is False:
assert current_step_index < (len(operation.steps) - 1)
if current_step_index >= len(operation.steps) - 1:
raise ValueError(f"Step index {current_step_index} is the last step in operation (has {len(operation.steps)} steps), but more steps were expected")
next_step = operation.steps[current_step_index + 1]

# catch any errors in updating the resource - maybe Cosmos / schema invalid etc, and report them back to the op
@@ -255,7 +262,7 @@ def get_failure_status_for_action(self, action: RequestAction):

return status

def create_updated_resource_document(self, resource: dict, message: DeploymentStatusUpdateMessage):
def create_updated_resource_document(self, resource: Dict[str, Any], message: DeploymentStatusUpdateMessage) -> Dict[str, Any]:
"""
Merge the outputs with the resource document to persist
"""
@@ -268,7 +275,7 @@ def create_updated_resource_document(self, resource: dict, message: DeploymentSt

return resource

def convert_outputs_to_dict(self, outputs_list: [Output]):
def convert_outputs_to_dict(self, outputs_list: List[Output]) -> Dict[str, Any]:
"""
Convert a list of Porter outputs to a dictionary
"""
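One of the fixes above replaces `raise f"..."` with a real exception. Raising a plain string fails at the raise site with "TypeError: exceptions must derive from BaseException", masking the intended message. A small illustration of the before/after behaviour (the step shape here is illustrative):

```python
def find_step(steps, step_id):
    for step in steps:
        if step["stepId"] == step_id:
            return step
    # Before: `raise f"Error finding step {step_id}"` would itself crash with
    # "TypeError: exceptions must derive from BaseException", hiding the message.
    # After: raise a proper exception type that callers can catch and log.
    raise ValueError(f"Step {step_id} not found")
```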
110 changes: 110 additions & 0 deletions api_app/service_bus/service_bus_consumer.py
@@ -0,0 +1,110 @@
import asyncio
import os
import tempfile
import time

from services.logging import logger

# Configuration constants for monitoring intervals
HEARTBEAT_CHECK_INTERVAL_SECONDS = 60
HEARTBEAT_STALENESS_THRESHOLD_SECONDS = 300
RESTART_DELAY_SECONDS = 5
SUPERVISOR_ERROR_DELAY_SECONDS = 30


class ServiceBusConsumer:

def __init__(self, heartbeat_file_prefix: str):
# Create a unique identifier for this worker process
self.worker_id = os.getpid()
temp_dir = tempfile.gettempdir()
self.heartbeat_file = os.path.join(temp_dir, f"{heartbeat_file_prefix}_heartbeat_{self.worker_id}.txt")
self.service_name = heartbeat_file_prefix.replace('_', ' ').title()
logger.info(f"Initializing {self.service_name}")

def update_heartbeat(self):
try:
# Ensure directory exists
os.makedirs(os.path.dirname(self.heartbeat_file), exist_ok=True)
with open(self.heartbeat_file, 'w') as f:
f.write(str(time.time()))
except PermissionError:
logger.error(f"Permission denied writing heartbeat to {self.heartbeat_file}")
except OSError as e:
logger.error(f"OS error updating heartbeat: {e}")
except Exception as e:
logger.warning(f"Unexpected error updating heartbeat: {e}")

def check_heartbeat(self, max_age_seconds: int = 300) -> bool:
try:
if not os.path.exists(self.heartbeat_file):
logger.warning("Heartbeat file does not exist")
return False

with open(self.heartbeat_file, 'r') as f:
heartbeat_time = float(f.read().strip())

current_time = time.time()
age = current_time - heartbeat_time

if age > max_age_seconds:
logger.warning(f"Heartbeat is {age:.1f} seconds old, exceeding the limit of {max_age_seconds} seconds")

return age <= max_age_seconds
except (ValueError, IOError) as e:
logger.warning(f"Failed to read heartbeat: {e}")
return False

async def receive_messages_with_restart_check(self):
while True:
try:
logger.info("Starting the receive_messages loop...")
await self.receive_messages()
except Exception as e:
logger.exception(f"receive_messages stopped unexpectedly. Restarting... - {e}")
await asyncio.sleep(RESTART_DELAY_SECONDS)

async def supervisor_with_heartbeat_check(self):
task = None
try:
while True:
try:
# Start the receive_messages task if not running
if task is None or task.done():
if task and task.done():
try:
await task # Check for any exception
except Exception as e:
logger.exception(f"receive_messages task failed: {e}")

logger.info("Starting receive_messages task...")
task = asyncio.create_task(self.receive_messages_with_restart_check())

# Wait before checking heartbeat
await asyncio.sleep(HEARTBEAT_CHECK_INTERVAL_SECONDS) # Check every minute

# Check if heartbeat is stale
if not self.check_heartbeat(max_age_seconds=HEARTBEAT_STALENESS_THRESHOLD_SECONDS): # 5 minutes max age
logger.warning("Heartbeat is stale, restarting receive_messages task...")
task.cancel()
try:
await task
except asyncio.CancelledError:
# Expected when cancelling a task - ignore and proceed with restart
pass
task = None
except Exception as e:
logger.exception(f"Supervisor error: {e}")
await asyncio.sleep(SUPERVISOR_ERROR_DELAY_SECONDS)
finally:
# Ensure proper cleanup on shutdown
if task and not task.done():
logger.info("Cleaning up supervisor task...")
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

async def receive_messages(self):
raise NotImplementedError("Subclasses must implement receive_messages()")
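A minimal sketch of how a subclass plugs into the new base class; MyConsumer and its sleep-based body are illustrative stand-ins, not part of the PR:

```python
import asyncio

from service_bus.service_bus_consumer import ServiceBusConsumer

class MyConsumer(ServiceBusConsumer):
    def __init__(self):
        super().__init__("my_consumer")

    async def receive_messages(self):
        while True:
            self.update_heartbeat()  # proves liveness to the supervisor
            await asyncio.sleep(1)   # stand-in for real Service Bus polling

async def main():
    # Restarts the loop if it crashes, and cancels/recreates it when the
    # heartbeat file is older than HEARTBEAT_STALENESS_THRESHOLD_SECONDS (300s).
    await MyConsumer().supervisor_with_heartbeat_check()

# asyncio.run(main())
```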
46 changes: 45 additions & 1 deletion api_app/services/logging.py
@@ -1,10 +1,14 @@
import logging
import os
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from opentelemetry import trace
from azure.monitor.opentelemetry import configure_azure_monitor

from core.config import APPLICATIONINSIGHTS_CONNECTION_STRING, LOGGING_LEVEL

# Standard log format with worker ID
LOG_FORMAT = '%(asctime)s - Worker %(worker_id)s - %(name)s - %(levelname)s - %(message)s'

UNWANTED_LOGGERS = [
"azure.core.pipeline.policies.http_logging_policy",
"azure.eventhub._eventprocessor.event_processor",
@@ -45,6 +49,23 @@
"urllib3.connectionpool"
]


class WorkerIdFilter(logging.Filter):
"""
A filter that adds worker_id to all log records.
"""

def __init__(self):
super().__init__()
# Get the process ID as a unique worker identifier
self.worker_id = os.getpid()

def filter(self, record: logging.LogRecord) -> bool:
# Add worker_id as an attribute to the log record
record.worker_id = self.worker_id
return True


logger = logging.getLogger("azuretre_api")
tracer = trace.get_tracer("azuretre_api")

@@ -57,6 +78,20 @@ def configure_loggers():
logging.getLogger(logger_name).setLevel(logging.CRITICAL)


def apply_worker_id_to_logger(logger_instance):
"""
Apply the worker ID filter to a logger instance.
"""
worker_filter = WorkerIdFilter()
logger_instance.addFilter(worker_filter)

# Update handlers to include worker_id in the format
for handler in logger_instance.handlers:
if isinstance(handler, logging.StreamHandler):
formatter = logging.Formatter(LOG_FORMAT)
handler.setFormatter(formatter)


def initialize_logging() -> logging.Logger:

configure_loggers()
@@ -87,9 +122,18 @@ def initialize_logging() -> logging.Logger:
LoggingInstrumentor().instrument(
set_logging_format=True,
log_level=logging_level,
tracer_provider=tracer._real_tracer
tracer_provider=tracer._real_tracer,
log_format=LOG_FORMAT
)

# Set up a handler if none exists
if not logger.handlers:
handler = logging.StreamHandler()
logger.addHandler(handler)

# Apply worker ID filter
apply_worker_id_to_logger(logger)

logger.info("Logging initialized with level: %s", LOGGING_LEVEL)

return logger
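LOG_FORMAT references %(worker_id)s, which only exists because the filter stamps it onto every record; without the filter, formatting would fail with a KeyError. A self-contained sketch of the mechanism (the logger name and message are illustrative):

```python
import logging
import os

LOG_FORMAT = '%(asctime)s - Worker %(worker_id)s - %(name)s - %(levelname)s - %(message)s'

class WorkerIdFilter(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        record.worker_id = os.getpid()  # stamp the process ID onto every record
        return True

demo = logging.getLogger("demo")
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(LOG_FORMAT))
demo.addHandler(handler)
demo.addFilter(WorkerIdFilter())
demo.setLevel(logging.INFO)
demo.info("heartbeat updated")
# -> 2025-11-07 12:00:00,000 - Worker 4242 - demo - INFO - heartbeat updated
```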