|
10 | 10 | from concurrent.futures import ThreadPoolExecutor |
11 | 11 | from dataclasses import dataclass, field |
12 | 12 | from datetime import datetime, timedelta, timezone |
13 | | -from threading import Event, Thread |
| 13 | +from threading import Event, Lock, Thread |
14 | 14 | from types import GeneratorType |
15 | 15 | from enum import Enum |
16 | 16 | from typing import Any, Generator, Optional, Sequence, Tuple, TypeVar, Union, overload |
@@ -130,6 +130,73 @@ class _WorkItemStreamOutcome(Enum): |
130 | 130 | SILENT_DISCONNECT = "silent_disconnect" |
131 | 131 |
|
132 | 132 |
|
| 133 | +@dataclass |
| 134 | +class _TrackedChannelState: |
| 135 | + channel: Any |
| 136 | + ref_count: int = 0 |
| 137 | + close_when_released: bool = False |
| 138 | + |
| 139 | + |
| 140 | +class _InFlightChannelTracker: |
| 141 | + def __init__(self): |
| 142 | + self._lock = Lock() |
| 143 | + self._states: dict[int, _TrackedChannelState] = {} |
| 144 | + |
| 145 | + def acquire(self, channel: Any): |
| 146 | + channel_key = id(channel) |
| 147 | + with self._lock: |
| 148 | + state = self._states.get(channel_key) |
| 149 | + if state is None: |
| 150 | + state = _TrackedChannelState(channel=channel) |
| 151 | + self._states[channel_key] = state |
| 152 | + state.ref_count += 1 |
| 153 | + |
| 154 | + released = False |
| 155 | + |
| 156 | + def release() -> None: |
| 157 | + nonlocal released |
| 158 | + if released: |
| 159 | + return |
| 160 | + released = True |
| 161 | + |
| 162 | + channel_to_close = None |
| 163 | + with self._lock: |
| 164 | + state = self._states.get(channel_key) |
| 165 | + if state is None: |
| 166 | + return |
| 167 | + |
| 168 | + state.ref_count -= 1 |
| 169 | + if state.ref_count == 0: |
| 170 | + if state.close_when_released: |
| 171 | + channel_to_close = state.channel |
| 172 | + del self._states[channel_key] |
| 173 | + |
| 174 | + if channel_to_close is not None: |
| 175 | + self._close_channel(channel_to_close) |
| 176 | + |
| 177 | + return release |
| 178 | + |
| 179 | + def retire(self, channel: Any) -> None: |
| 180 | + channel_key = id(channel) |
| 181 | + channel_to_close = None |
| 182 | + with self._lock: |
| 183 | + state = self._states.get(channel_key) |
| 184 | + if state is None: |
| 185 | + channel_to_close = channel |
| 186 | + else: |
| 187 | + state.close_when_released = True |
| 188 | + |
| 189 | + if channel_to_close is not None: |
| 190 | + self._close_channel(channel_to_close) |
| 191 | + |
| 192 | + @staticmethod |
| 193 | + def _close_channel(channel: Any) -> None: |
| 194 | + try: |
| 195 | + channel.close() |
| 196 | + except Exception: |
| 197 | + pass |
| 198 | + |
| 199 | + |
133 | 200 | class VersioningOptions: |
134 | 201 | """Configuration options for orchestrator and activity versioning. |
135 | 202 |
|
@@ -642,6 +709,7 @@ async def _async_run_loop(self): |
642 | 709 | failure_tracker = FailureTracker( |
643 | 710 | threshold=self._resiliency_options.channel_recreate_failure_threshold, |
644 | 711 | ) |
| 712 | + in_flight_channel_tracker = _InFlightChannelTracker() |
645 | 713 |
|
646 | 714 | def get_reconnect_delay_seconds() -> float: |
647 | 715 | return get_full_jitter_delay_seconds( |
@@ -671,6 +739,45 @@ def create_fresh_connection(): |
671 | 739 | current_stub = None |
672 | 740 | raise |
673 | 741 |
|
| 742 | + def wrap_execution(handler, release): |
| 743 | + def wrapped(*args, **kwargs): |
| 744 | + result = handler(*args, **kwargs) |
| 745 | + release() |
| 746 | + return result |
| 747 | + |
| 748 | + return wrapped |
| 749 | + |
| 750 | + def wrap_cancellation(handler, release): |
| 751 | + def wrapped(*args, **kwargs): |
| 752 | + try: |
| 753 | + return handler(*args, **kwargs) |
| 754 | + finally: |
| 755 | + release() |
| 756 | + |
| 757 | + return wrapped |
| 758 | + |
| 759 | + def submit_work_item( |
| 760 | + submit_func, |
| 761 | + handler, |
| 762 | + cancellation_handler, |
| 763 | + request, |
| 764 | + stub, |
| 765 | + completion_token, |
| 766 | + channel, |
| 767 | + ): |
| 768 | + release = in_flight_channel_tracker.acquire(channel) |
| 769 | + try: |
| 770 | + submit_func( |
| 771 | + wrap_execution(handler, release), |
| 772 | + wrap_cancellation(cancellation_handler, release), |
| 773 | + request, |
| 774 | + stub, |
| 775 | + completion_token, |
| 776 | + ) |
| 777 | + except Exception: |
| 778 | + release() |
| 779 | + raise |
| 780 | + |
674 | 781 | def invalidate_connection( |
675 | 782 | *, |
676 | 783 | recreate_channel: bool = False, |
@@ -700,10 +807,7 @@ def invalidate_connection( |
700 | 807 | and self._can_recreate_channel() |
701 | 808 | and (recreate_channel or close_channel) |
702 | 809 | ): |
703 | | - try: |
704 | | - current_channel.close() |
705 | | - except Exception: |
706 | | - pass |
| 810 | + in_flight_channel_tracker.retire(current_channel) |
707 | 811 | current_channel = None |
708 | 812 | current_stub = None |
709 | 813 |
|
@@ -742,7 +846,9 @@ def should_invalidate_connection(rpc_error): |
742 | 846 | continue |
743 | 847 | try: |
744 | 848 | assert current_stub is not None |
| 849 | + assert current_channel is not None |
745 | 850 | stub = current_stub |
| 851 | + channel = current_channel |
746 | 852 | capabilities = [] |
747 | 853 | if self._payload_store is not None: |
748 | 854 | capabilities.append(pb.WORKER_CAPABILITY_LARGE_PAYLOADS) |
@@ -822,36 +928,44 @@ def stream_reader(): |
822 | 928 |
|
823 | 929 | failure_tracker.record_success() |
824 | 930 | if work_item.HasField("orchestratorRequest"): |
825 | | - self._async_worker_manager.submit_orchestration( |
| 931 | + submit_work_item( |
| 932 | + self._async_worker_manager.submit_orchestration, |
826 | 933 | self._execute_orchestrator, |
827 | 934 | self._cancel_orchestrator, |
828 | 935 | work_item.orchestratorRequest, |
829 | 936 | stub, |
830 | 937 | work_item.completionToken, |
| 938 | + channel, |
831 | 939 | ) |
832 | 940 | elif work_item.HasField("activityRequest"): |
833 | | - self._async_worker_manager.submit_activity( |
| 941 | + submit_work_item( |
| 942 | + self._async_worker_manager.submit_activity, |
834 | 943 | self._execute_activity, |
835 | 944 | self._cancel_activity, |
836 | 945 | work_item.activityRequest, |
837 | 946 | stub, |
838 | 947 | work_item.completionToken, |
| 948 | + channel, |
839 | 949 | ) |
840 | 950 | elif work_item.HasField("entityRequest"): |
841 | | - self._async_worker_manager.submit_entity_batch( |
| 951 | + submit_work_item( |
| 952 | + self._async_worker_manager.submit_entity_batch, |
842 | 953 | self._execute_entity_batch, |
843 | 954 | self._cancel_entity_batch, |
844 | 955 | work_item.entityRequest, |
845 | 956 | stub, |
846 | 957 | work_item.completionToken, |
| 958 | + channel, |
847 | 959 | ) |
848 | 960 | elif work_item.HasField("entityRequestV2"): |
849 | | - self._async_worker_manager.submit_entity_batch( |
| 961 | + submit_work_item( |
| 962 | + self._async_worker_manager.submit_entity_batch, |
850 | 963 | self._execute_entity_batch, |
851 | 964 | self._cancel_entity_batch, |
852 | 965 | work_item.entityRequestV2, |
853 | 966 | stub, |
854 | | - work_item.completionToken |
| 967 | + work_item.completionToken, |
| 968 | + channel, |
855 | 969 | ) |
856 | 970 | else: |
857 | 971 | self._logger.warning( |
|
0 commit comments