Skip to content

Commit c9fc3ce

Browse files
authored
Merge pull request #63 from SAFEHR-data/jeremy/opt-outs
Implement opt out
2 parents 0bc1dbd + 1ce5d1f commit c9fc3ce

7 files changed

Lines changed: 202 additions & 54 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ dependencies = [
2121
dev = [
2222
"pytest>=9.0.2",
2323
"stablehash==0.3.0",
24+
"types-pika-ts",
25+
"types-psycopg2",
2426
]
2527

2628
[project.scripts]

src/controller.py

Lines changed: 81 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,8 @@
1313

1414
logging.basicConfig(format="%(levelname)s:%(asctime)s: %(message)s")
1515
logger = logging.getLogger(__name__)
16-
17-
emap_db = db.starDB()
18-
emap_db.init_query()
19-
emap_db.connect()
16+
logger.setLevel(settings.LOG_LEVEL)
17+
# logger.addFilter(DedupeFilter(window_seconds=60))
2018

2119

2220
class waveform_message:
@@ -42,57 +40,78 @@ def reject_message(ch, delivery_tag, requeue):
4240
logger.warning("Attempting to not acknowledge a message on a closed channel.")
4341

4442

45-
def waveform_callback(ch, method_frame, _header_frame, body):
46-
data = json.loads(body)
47-
try:
48-
location_string = data["mappedLocationString"]
49-
observation_timestamp = data["observationTime"]
50-
source_variable_id = data["sourceVariableId"]
51-
source_channel_id = data["sourceChannelId"]
52-
sampling_rate = data["samplingRate"]
53-
units = data["unit"]
54-
waveform_data = data["numericValues"]
55-
mapped_location_string = data["mappedLocationString"]
56-
except IndexError as e:
57-
reject_message(ch, method_frame.delivery_tag, False)
58-
logger.error(
59-
f"Waveform message {method_frame.delivery_tag} is missing required data {e}."
60-
)
61-
return
43+
class WaveformController:
    """Handle waveform messages from RabbitMQ using one shared EMAP DB handle.

    The database object is created, has its query initialised and is
    connected once, when the controller is constructed; every message handled
    by ``waveform_callback`` then reuses it.
    """

    # Stand-in (mrn, nhs number, csn, opt-out) row used when the patient
    # lookup fails, so the frame can still be written without identifiers.
    _UNMATCHED_ROW = ("unmatched_mrn", "unmatched_nhs", "unmatched_csn", False)

    def __init__(self):
        """Create the database handle, initialise its query and connect."""
        self.emap_db = db.starDB()
        self.emap_db.init_query()
        self.emap_db.connect()

    def waveform_callback(self, ch, method_frame, _header_frame, body):
        """Process one waveform message and ack or reject it.

        Outcomes, in the order they are checked:

        * body is not valid JSON, is not a JSON object, lacks a required key
          or carries an unusable ``observationTime``: reject without requeue;
        * database lookup raises ``ConnectionError``: reject with requeue so
          the message is retried;
        * matched patient has the research opt-out flag set: reject without
          requeue and write nothing;
        * frame written after a successful lookup: ack;
        * frame written with placeholder identifiers (lookup raised
          ``ValueError``): reject without requeue.

        Args:
            ch: channel the message arrived on; passed to the ack/reject helpers.
            method_frame: delivery metadata; only ``delivery_tag`` is read.
            _header_frame: unused.
            body: the JSON-encoded message payload.
        """
        delivery_tag = method_frame.delivery_tag
        logger.debug("Message received of length %s", len(body))

        # A payload that cannot be parsed can never succeed on redelivery, so
        # drop it instead of letting the exception escape the consumer with
        # the message still unacknowledged.
        try:
            data = json.loads(body)
        except ValueError:  # json.JSONDecodeError and UnicodeDecodeError
            reject_message(ch, delivery_tag, False)
            logger.error(
                "Waveform message %s is not valid JSON.",
                delivery_tag,
                exc_info=True,
            )
            return

        try:
            location_string = data["mappedLocationString"]
            observation_timestamp = data["observationTime"]
            source_variable_id = data["sourceVariableId"]
            source_channel_id = data["sourceChannelId"]
            sampling_rate = data["samplingRate"]
            units = data["unit"]
            waveform_data = data["numericValues"]
            # NOTE(review): same key as location_string above — confirm that
            # the lookup is meant to use the mapped (not source) location.
            mapped_location_string = data["mappedLocationString"]
            logger.debug(
                "Message is for loc %s, var %s, ch %s",
                location_string,
                source_variable_id,
                source_channel_id,
            )
        except (KeyError, TypeError) as e:
            # TypeError covers a payload that is valid JSON but not an object
            # (e.g. a list or a bare number), which cannot be indexed by key.
            reject_message(ch, delivery_tag, False)
            logger.error(
                "Waveform message %s is missing required data %s.",
                delivery_tag,
                e,
            )
            return

        try:
            observation_time = datetime.fromtimestamp(
                observation_timestamp, tz=timezone.utc
            )
        except (TypeError, ValueError, OverflowError, OSError):
            # Non-numeric or out-of-range timestamp: not recoverable by retry.
            reject_message(ch, delivery_tag, False)
            logger.error(
                "Waveform message %s has an unusable observationTime %r.",
                delivery_tag,
                observation_timestamp,
                exc_info=True,
            )
            return

        lookup_success = True
        try:
            matched_mrn = self.emap_db.get_row(location_string, observation_time)
        except ValueError:
            lookup_success = False
            logger.error(
                "Ambiguous or non existent match for location %s, obs time %s",
                location_string,
                observation_time,
                exc_info=True,
            )
            matched_mrn = self._UNMATCHED_ROW
        except ConnectionError:
            logger.error("Database error, will try again", exc_info=True)
            reject_message(ch, delivery_tag, True)
            return

        (mrn, nhs_no, csn, opt_out) = matched_mrn
        if opt_out:
            logger.info("Research opt-out is set for mrn %s, not writing.", mrn)
            reject_message(ch, delivery_tag, False)
            return

        # NOTE(review): when write_frame returns a falsy value the message is
        # neither acked nor rejected — confirm that this is intended.
        if writer.write_frame(
            waveform_data,
            source_variable_id,
            source_channel_id,
            observation_timestamp,
            units,
            sampling_rate,
            mapped_location_string,
            csn,
            mrn,
        ):
            if lookup_success:
                ack_message(ch, delivery_tag)
            else:
                reject_message(ch, delivery_tag, False)
96115

97116

98117
def receiver():
@@ -105,18 +124,27 @@ def receiver():
105124
host=settings.RABBITMQ_HOST,
106125
port=settings.RABBITMQ_PORT,
107126
)
127+
logger.info("Connecting to RabbitMQ %s", connection_parameters)
108128
connection = pika.BlockingConnection(connection_parameters)
109129
channel = connection.channel()
110130
channel.basic_qos(prefetch_count=1)
111131

132+
controller = WaveformController()
112133
channel.basic_consume(
113134
queue=settings.RABBITMQ_QUEUE,
114135
auto_ack=False,
115-
on_message_callback=waveform_callback,
136+
on_message_callback=controller.waveform_callback,
116137
)
138+
logger.info("Connected to RabbitMQ, callback configured")
117139
try:
118140
channel.start_consuming()
119141
except KeyboardInterrupt:
142+
logger.warning("Received keyboard interrupt, exiting.")
120143
channel.stop_consuming()
144+
except Exception as e:
145+
logger.error("Received other exception")
146+
logger.error(e)
147+
raise e
121148

149+
logger.info("Closing connection to RabbitMQ")
122150
connection.close()

src/settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,6 @@ def get_from_env(env_var, *, default_value=None, setting_name=None, required=Fal
3737
get_from_env("HASHER_API_HOSTNAME")
3838
get_from_env("HASHER_API_PORT")
3939

40+
get_from_env("LOG_LEVEL", default_value="INFO")
41+
4042
get_from_env("INSTANCE_NAME", required=True)

src/sql/mrn_based_on_bed_and_datetime.sql

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ first entry being the most recent.
55
SELECT
66
mn.mrn as mrn,
77
mn.nhs_number as nhs_number,
8-
hv.encounter as csn
8+
hv.encounter as csn,
9+
mn.research_opt_out as research_opt_out
910
FROM {schema_name}.mrn mn
1011
INNER JOIN {schema_name}.hospital_visit hv
1112
ON mn.mrn_id = hv.mrn_id

src/utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import logging
2+
import time
3+
4+
5+
class DedupeFilter(logging.Filter):
    """Suppress repeated identical log messages within a time window.

    A record is dropped when it has the same level and the same formatted
    text as the last record that was let through, and that record passed
    less than ``window_seconds`` ago. Only the most recent passed record is
    remembered, so alternating messages are never suppressed.

    Args:
        window_seconds: length of the suppression window, in seconds.
        name: forwarded to ``logging.Filter``.
    """

    def __init__(self, window_seconds=60, name=""):
        super().__init__(name)
        self.window_seconds = window_seconds
        # (levelno, formatted message) of the last record let through.
        self._last_message = None
        # time.monotonic() reading taken when that record was let through.
        self._last_time = 0.0
        # Running total of records suppressed by this filter.
        self._dedupe_count = 0

    def filter(self, record):
        """Return False to drop ``record`` as a duplicate, True to keep it."""
        try:
            # Include the level so that, for example, an ERROR is never hidden
            # by an earlier INFO that happened to carry the same text.
            key = (record.levelno, record.getMessage())
        except Exception:
            # Format string and args do not match. Let the record through so
            # the handler reports it as usual, rather than raising into the
            # code that made the logging call.
            return True

        now = time.monotonic()
        if key == self._last_message and (now - self._last_time) < self.window_seconds:
            self._dedupe_count += 1
            return False

        self._last_message = key
        self._last_time = now
        return True

tests/test_controller.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import json
2+
from datetime import datetime
3+
from unittest.mock import Mock
4+
5+
import pytest
6+
7+
from controller import WaveformController
8+
9+
10+
@pytest.mark.parametrize("opt_out", [True, False])
@pytest.mark.parametrize("db_connect_failure", [True, False])
@pytest.mark.parametrize("bad_data", [True, False])
def test_controller_callback(monkeypatch, opt_out, db_connect_failure, bad_data):
    """Check how the callback acks/rejects for every combination of inputs.

    Precedence of the expected outcomes: a payload with a missing key wins
    over a database connection failure, which wins over a patient opt-out;
    only when none of them apply is the frame written and the message acked.
    """
    tag = 12345

    # Database double: either fails to connect or returns a matched patient.
    star_db = Mock()
    if db_connect_failure:
        star_db.get_row.side_effect = ConnectionError("mock database error")
    else:
        star_db.get_row.return_value = ("mrn", "nhsno", "csn", opt_out)
    monkeypatch.setattr("controller.db.starDB", Mock(return_value=star_db))

    # Writer double that always reports a successful write.
    frame_writer = Mock(return_value=True)
    monkeypatch.setattr("controller.writer.write_frame", frame_writer)

    payload = {
        "sourceLocationString": "foo",
        "mappedLocationString": "loc",
        "observationTime": datetime.now().timestamp(),
        "sourceVariableId": "27",
        "sourceChannelId": "1",
        "samplingRate": 50,
        "unit": "uV",
        "numericValues": "[1,2,3]",
    }
    if bad_data:
        # simulate a missing key
        payload.pop("sourceChannelId")
    body = json.dumps(payload)

    channel = Mock()
    channel.is_open = True
    method_frame = Mock()
    method_frame.delivery_tag = tag

    WaveformController().waveform_callback(channel, method_frame, None, body)

    if bad_data:
        # db should not even have been queried if data was bad
        star_db.get_row.assert_not_called()
        frame_writer.assert_not_called()
        channel.basic_reject.assert_called_once_with(tag, False)
        channel.basic_ack.assert_not_called()
        return

    # we at least tried to query the DB
    star_db.get_row.assert_called_once()

    if db_connect_failure:
        # if the DB lookup failed, we should not write anything and requeue the message
        frame_writer.assert_not_called()
        channel.basic_reject.assert_called_once_with(tag, True)
        channel.basic_ack.assert_not_called()
        return

    if opt_out:
        # patient has opted out, dump the message
        frame_writer.assert_not_called()
        channel.basic_reject.assert_called_once_with(tag, False)
        channel.basic_ack.assert_not_called()
        return

    # happy path
    frame_writer.assert_called_once()
    channel.basic_reject.assert_not_called()
    channel.basic_ack.assert_called_once_with(tag)

uv.lock

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)