Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions postgresql_watcher/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def __init__(
logger (Optional[Logger], optional): Custom logger to use. Defaults to None.
"""
self.update_callback = None
self.parent_conn = None
self.host = host
self.port = port
self.user = user
Expand Down Expand Up @@ -83,7 +82,7 @@ def _create_subscription_process(
self._cleanup_connections_and_processes()

self.parent_conn, self.child_conn = Pipe()
self.subscription_proces = Process(
self.subscription_process = Process(
target=casbin_channel_subscription,
args=(
self.child_conn,
Expand All @@ -109,9 +108,12 @@ def start(
self,
timeout=20, # seconds
):
if not self.subscription_proces.is_alive():
if self.subscription_process is None:
self._create_subscription_process(start_listening=False)

if not self.subscription_process.is_alive():
# Start listening to messages
self.subscription_proces.start()
self.subscription_process.start()
# And wait for the Process to be ready to listen for updates
# from PostgreSQL
timeout_time = time() + timeout
Expand All @@ -124,6 +126,9 @@ def start(
raise PostgresqlWatcherChannelSubscriptionTimeoutError(timeout)
sleep(1 / 1000) # wait for 1 ms

def stop(self):
self._cleanup_connections_and_processes()

def _cleanup_connections_and_processes(self) -> None:
# Clean up potentially existing Connections and Processes
if self.parent_conn is not None:
Expand All @@ -132,8 +137,9 @@ def _cleanup_connections_and_processes(self) -> None:
if self.child_conn is not None:
self.child_conn.close()
self.child_conn = None
if self.subscription_process is not None:
if self.subscription_process is not None and self.subscription_process.pid is not None:
self.subscription_process.terminate()
self.subscription_process.join()
self.subscription_process = None

def set_update_callback(self, update_handler: Optional[Callable[[None], None]]):
Expand Down
24 changes: 23 additions & 1 deletion tests/test_postgresql_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_pg_watcher_init(self):
assert isinstance(pg_watcher.parent_conn, connection.PipeConnection)
else:
assert isinstance(pg_watcher.parent_conn, connection.Connection)
assert isinstance(pg_watcher.subscription_proces, context.Process)
assert isinstance(pg_watcher.subscription_process, context.Process)

def test_update_single_pg_watcher(self):
pg_watcher = get_watcher("test_update_single_pg_watcher")
Expand Down Expand Up @@ -115,6 +115,28 @@ def test_update_handler_not_called(self):
self.assertFalse(main_watcher.should_reload())
self.assertTrue(handler.call_count == 0)

def test_stop_and_restart(self):
channel_name = "test_stop_and_restart"
pg_watcher = get_watcher(channel_name)

# Verify initially started
self.assertTrue(pg_watcher.subscription_process.is_alive())

# Stop the watcher
pg_watcher.stop()
self.assertIsNone(pg_watcher.subscription_process)

# Restart the watcher
pg_watcher.start()

# Verify resources are recreated and process is alive
self.assertTrue(pg_watcher.subscription_process.is_alive())

# Verify it still works after restart
pg_watcher.update()
sleep(CASBIN_CHANNEL_SELECT_TIMEOUT * 2)
self.assertTrue(pg_watcher.should_reload())


if __name__ == "__main__":
main()