Skip to content

Commit 91d1fa0

Browse files
reference parity
1 parent fbbca42 commit 91d1fa0

7 files changed

Lines changed: 335 additions & 92 deletions

File tree

src/pycyphal/_node.py

Lines changed: 79 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,22 @@ class Coupling:
271271
substitutions: list[tuple[str, int]]
272272

273273

274+
@dataclass
275+
class SharedSubjectListener:
276+
"""One transport listener shared by all topics bound to the same subject-ID."""
277+
278+
handle: Closable
279+
owners: set[Topic] = field(default_factory=set)
280+
281+
282+
@dataclass
283+
class SharedSubjectWriter:
284+
"""One transport writer shared by all topics bound to the same subject-ID."""
285+
286+
handle: SubjectWriter
287+
owners: set[Topic] = field(default_factory=set)
288+
289+
274290
@dataclass
275291
class PublishTracker:
276292
"""Tracks a pending reliable publication awaiting ACKs."""
@@ -365,34 +381,32 @@ def tag_seqno(self, tag: int) -> int:
365381

366382
def ensure_writer(self) -> SubjectWriter:
367383
if self.pub_writer is None:
368-
self.pub_writer = self._node.transport.subject_advertise(self.subject_id)
369-
_logger.info("Writer created for '%s' sid=%d", self._name, self.subject_id)
384+
sid = self.subject_id
385+
self.pub_writer = self._node.acquire_subject_writer(self, sid)
386+
_logger.info("Writer acquired for '%s' sid=%d", self._name, sid)
370387
return self.pub_writer
371388

372389
def ensure_listener(self) -> None:
373390
if self.sub_listener is None and self.couplings:
374391
sid = self.subject_id
375-
376-
def handler(arrival: TransportArrival) -> None:
377-
self._node.on_subject_arrival(sid, arrival)
378-
379-
self.sub_listener = self._node.transport.subject_listen(sid, handler)
380-
_logger.info("Listener created for '%s' sid=%d", self._name, sid)
392+
self.sub_listener = self._node.acquire_subject_listener(self, sid)
393+
_logger.info("Listener acquired for '%s' sid=%d", self._name, sid)
381394

382395
def sync_listener(self) -> None:
383396
if self.couplings:
384397
self.ensure_listener()
385398
elif self.sub_listener is not None:
386-
self.sub_listener.close()
399+
self._node.release_subject_listener(self, self.subject_id)
387400
self.sub_listener = None
388401
_logger.info("Listener released for '%s'", self._name)
389402

390403
def release_transport_handles(self) -> None:
404+
sid = self.subject_id
391405
if self.pub_writer is not None:
392-
self.pub_writer.close()
406+
self._node.release_subject_writer(self, sid)
393407
self.pub_writer = None
394408
if self.sub_listener is not None:
395-
self.sub_listener.close()
409+
self._node.release_subject_listener(self, sid)
396410
self.sub_listener = None
397411

398412
def compute_is_implicit(self) -> bool:
@@ -465,6 +479,8 @@ def broadcast_handler(arrival: TransportArrival) -> None:
465479
# Gossip shard state: lazily created per shard.
466480
self.gossip_shard_writers: dict[int, SubjectWriter] = {}
467481
self.gossip_shard_listeners: dict[int, Closable] = {}
482+
self.shared_subject_writers: dict[int, SharedSubjectWriter] = {}
483+
self.shared_subject_listeners: dict[int, SharedSubjectListener] = {}
468484

469485
# Register unicast handler.
470486
transport.unicast_listen(self.on_unicast_arrival)
@@ -600,13 +616,11 @@ def topic_allocate(self, topic: TopicImpl, new_evictions: int, now: float) -> No
600616
elif left_wins(t.lage(now), t.hash, collider.lage(now), collider.hash):
601617
# Our topic wins: take the slot, evict the collider.
602618
t.release_transport_handles()
603-
# Preserve the collider's writer if we can reuse it.
604-
if collider.pub_writer is not None:
605-
t.pub_writer = collider.pub_writer
606-
collider.pub_writer = None
607619
t.evictions = ev
608620
del self.topics_by_subject_id[new_sid]
609621
self.topics_by_subject_id[new_sid] = t
622+
if collider.pub_writer is not None:
623+
t.pub_writer = self.acquire_subject_writer(t, new_sid)
610624
t.sync_listener()
611625
self.schedule_gossip_urgent(t)
612626
# Schedule collider for reallocation.
@@ -708,6 +722,48 @@ def handler(arrival: TransportArrival) -> None:
708722
_logger.debug("Gossip shard writer/listener for sid=%d", shard_sid)
709723
return writer
710724

725+
def acquire_subject_writer(self, topic: TopicImpl, subject_id: int) -> SubjectWriter:
726+
entry = self.shared_subject_writers.get(subject_id)
727+
if entry is None:
728+
entry = SharedSubjectWriter(handle=self.transport.subject_advertise(subject_id))
729+
self.shared_subject_writers[subject_id] = entry
730+
_logger.debug("Shared subject writer created sid=%d", subject_id)
731+
entry.owners.add(topic)
732+
return entry.handle
733+
734+
def release_subject_writer(self, topic: TopicImpl, subject_id: int) -> None:
735+
entry = self.shared_subject_writers.get(subject_id)
736+
if entry is None:
737+
return
738+
entry.owners.discard(topic)
739+
if not entry.owners:
740+
entry.handle.close()
741+
del self.shared_subject_writers[subject_id]
742+
_logger.debug("Shared subject writer released sid=%d", subject_id)
743+
744+
def acquire_subject_listener(self, topic: TopicImpl, subject_id: int) -> Closable:
745+
entry = self.shared_subject_listeners.get(subject_id)
746+
if entry is None:
747+
748+
def handler(arrival: TransportArrival) -> None:
749+
self.on_subject_arrival(subject_id, arrival)
750+
751+
entry = SharedSubjectListener(handle=self.transport.subject_listen(subject_id, handler))
752+
self.shared_subject_listeners[subject_id] = entry
753+
_logger.debug("Shared subject listener created sid=%d", subject_id)
754+
entry.owners.add(topic)
755+
return entry.handle
756+
757+
def release_subject_listener(self, topic: TopicImpl, subject_id: int) -> None:
758+
entry = self.shared_subject_listeners.get(subject_id)
759+
if entry is None:
760+
return
761+
entry.owners.discard(topic)
762+
if not entry.owners:
763+
entry.handle.close()
764+
del self.shared_subject_listeners[subject_id]
765+
_logger.debug("Shared subject listener released sid=%d", subject_id)
766+
711767
def schedule_gossip(self, topic: TopicImpl) -> None:
712768
"""Start periodic gossip for an explicit topic."""
713769
if topic.gossip_task is not None:
@@ -1304,8 +1360,14 @@ def close(self) -> None:
13041360
topic.release_transport_handles()
13051361
self.broadcast_writer.close()
13061362
self.broadcast_listener.close()
1363+
for shared_writer in list(self.shared_subject_writers.values()):
1364+
shared_writer.handle.close()
1365+
self.shared_subject_writers.clear()
1366+
for shared_listener in list(self.shared_subject_listeners.values()):
1367+
shared_listener.handle.close()
1368+
self.shared_subject_listeners.clear()
13071369
for w in self.gossip_shard_writers.values():
13081370
w.close()
1309-
for listener in self.gossip_shard_listeners.values():
1310-
listener.close()
1371+
for gossip_listener in self.gossip_shard_listeners.values():
1372+
gossip_listener.close()
13111373
self.transport.close()

src/pycyphal/_transport.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def subject_listen(self, subject_id: int, handler: Callable[[TransportArrival],
5454
"""
5555
Subscribe to a subject to receive messages from it until the returned closable handle is closed.
5656
The session layer may request at most one listener per subject at any given time, similar to the reference impl.
57+
Duplicate requests for the same subject should raise ValueError.
5758
Unlike the reference implementation, our listeners do not have the extent setting -- the extent mostly matters
5859
for high-reliability/real-time applications; this Python implementation assumes infinite extent.
5960
"""
@@ -63,7 +64,8 @@ def subject_listen(self, subject_id: int, handler: Callable[[TransportArrival],
6364
def subject_advertise(self, subject_id: int) -> SubjectWriter:
6465
"""
6566
Begin sending messages on a subject.
66-
The session layer may request at most one listener per subject at any given time, similar to the reference impl.
67+
The session layer may request at most one writer per subject at any given time, similar to the reference impl.
68+
Duplicate requests for the same subject should raise ValueError.
6769
"""
6870
raise NotImplementedError
6971

src/pycyphal/udp.py

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,10 @@ async def __call__(self, deadline: Instant, priority: Priority, message: bytes |
352352
_logger.debug("Subject tx done sid=%d tid=%d", self._subject_id, transfer_id)
353353

354354
def close(self) -> None:
355+
if self._closed:
356+
return
355357
self._closed = True
358+
self._transport.remove_subject_writer(self._subject_id, self)
356359
_logger.debug("Subject writer closed for subject %d", self._subject_id)
357360

358361

@@ -441,7 +444,8 @@ def __init__(
441444
self._self_endpoints.add(sock.getsockname()[:2])
442445

443446
# Subject state
444-
self._subject_handlers: dict[int, list[Callable[[TransportArrival], None]]] = {}
447+
self._subject_handlers: dict[int, Callable[[TransportArrival], None]] = {}
448+
self._subject_writers: dict[int, _UDPSubjectWriter] = {}
445449
self._mcast_socks: dict[tuple[int, int], socket.socket] = {}
446450
self._reassemblers: dict[int, _RxReassembler] = {}
447451

@@ -521,23 +525,24 @@ def __repr__(self) -> str:
521525

522526
def remove_subject_listener(self, subject_id: int, handler: Callable[[TransportArrival], None]) -> None:
523527
"""
524-
Remove a handler for a subject; clean up sockets/tasks if no handlers remain.
525-
Internal use only.
528+
Remove the handler for a subject; clean up sockets/tasks if none remains. Internal use only.
526529
"""
527-
handlers = self._subject_handlers.get(subject_id, [])
528-
if handler in handlers:
529-
handlers.remove(handler)
530-
if not handlers:
531-
self._subject_handlers.pop(subject_id, None)
532-
self._reassemblers.pop(subject_id, None)
533-
for i in range(len(self._interfaces)):
534-
key = (subject_id, i)
535-
task = self._mcast_rx_tasks.pop(key, None)
536-
if task is not None:
537-
task.cancel()
538-
sock = self._mcast_socks.pop(key, None)
539-
if sock is not None:
540-
sock.close()
530+
if self._subject_handlers.get(subject_id) is not handler:
531+
return
532+
self._subject_handlers.pop(subject_id, None)
533+
self._reassemblers.pop(subject_id, None)
534+
for i in range(len(self._interfaces)):
535+
key = (subject_id, i)
536+
task = self._mcast_rx_tasks.pop(key, None)
537+
if task is not None:
538+
task.cancel()
539+
sock = self._mcast_socks.pop(key, None)
540+
if sock is not None:
541+
sock.close()
542+
543+
def remove_subject_writer(self, subject_id: int, writer: _UDPSubjectWriter) -> None:
544+
if self._subject_writers.get(subject_id) is writer:
545+
self._subject_writers.pop(subject_id, None)
541546

542547
# -- Async sendto helper --
543548

@@ -558,21 +563,25 @@ def subject_id_modulus(self) -> int:
558563
return self._subject_id_modulus_val
559564

560565
def subject_listen(self, subject_id: int, handler: Callable[[TransportArrival], None]) -> Closable:
561-
if subject_id not in self._subject_handlers:
562-
_logger.info("Subscribing to subject %d", subject_id)
563-
self._subject_handlers[subject_id] = []
564-
for i, iface in enumerate(self._interfaces):
565-
key = (subject_id, i)
566-
sock = self._create_mcast_socket(subject_id, iface)
567-
self._mcast_socks[key] = sock
568-
task = self._loop.create_task(self._mcast_rx_loop(sock, subject_id, i))
569-
self._mcast_rx_tasks[key] = task
570-
self._subject_handlers[subject_id].append(handler)
566+
if subject_id in self._subject_handlers:
567+
raise ValueError(f"Subject {subject_id} already has an active listener")
568+
_logger.info("Subscribing to subject %d", subject_id)
569+
self._subject_handlers[subject_id] = handler
570+
for i, iface in enumerate(self._interfaces):
571+
key = (subject_id, i)
572+
sock = self._create_mcast_socket(subject_id, iface)
573+
self._mcast_socks[key] = sock
574+
task = self._loop.create_task(self._mcast_rx_loop(sock, subject_id, i))
575+
self._mcast_rx_tasks[key] = task
571576
return _UDPSubjectListener(self, subject_id, handler)
572577

573578
def subject_advertise(self, subject_id: int) -> SubjectWriter:
579+
if subject_id in self._subject_writers:
580+
raise ValueError(f"Subject {subject_id} already has an active writer")
574581
_logger.info("Advertising subject %d", subject_id)
575-
return _UDPSubjectWriter(self, subject_id)
582+
writer = _UDPSubjectWriter(self, subject_id)
583+
self._subject_writers[subject_id] = writer
584+
return writer
576585

577586
def unicast_listen(self, handler: Callable[[TransportArrival], None]) -> None:
578587
self._unicast_handler = handler
@@ -627,6 +636,7 @@ def close(self) -> None:
627636
self._mcast_socks.clear()
628637
self._tx_socks.clear()
629638
self._subject_handlers.clear()
639+
self._subject_writers.clear()
630640
self._reassemblers.clear()
631641

632642
# -- Internal async RX loops --
@@ -723,5 +733,6 @@ def _process_subject_datagram(
723733
arrival = TransportArrival(
724734
timestamp=Instant.now(), priority=Priority(priority), remote_id=sender_uid, message=message
725735
)
726-
for handler in self._subject_handlers.get(subject_id, []):
736+
handler = self._subject_handlers.get(subject_id)
737+
if handler is not None:
727738
handler(arrival)

tests/mock_transport.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,15 @@ async def __call__(self, deadline: Instant, priority: Priority, message: bytes |
3636
if self.transport.network is not None:
3737
self.transport.network.deliver_subject(self.subject_id, arrival, sender=self.transport)
3838
else:
39-
for handler in self.transport.subject_handlers.get(self.subject_id, []):
39+
handler = self.transport.subject_handlers.get(self.subject_id)
40+
if handler is not None:
4041
handler(arrival)
4142

4243
def close(self) -> None:
44+
if self.closed:
45+
return
4346
self.closed = True
47+
self.transport.remove_subject_writer(self.subject_id, self)
4448

4549

4650
class MockSubjectListener(Closable):
@@ -51,22 +55,22 @@ def __init__(self, transport: MockTransport, subject_id: int, handler: Callable[
5155
self.closed = False
5256

5357
def close(self) -> None:
58+
if self.closed:
59+
return
5460
self.closed = True
55-
handlers = self.transport.subject_handlers.get(self.subject_id, [])
56-
if self.handler in handlers:
57-
handlers.remove(self.handler)
58-
if not handlers:
59-
self.transport.subject_handlers.pop(self.subject_id, None)
61+
self.transport.remove_subject_listener(self.subject_id, self.handler)
6062

6163

6264
class MockTransport(Transport):
6365
def __init__(self, node_id: int = 0, modulus: int = DEFAULT_MODULUS, network: MockNetwork | None = None) -> None:
6466
self.node_id = node_id
6567
self._modulus = modulus
6668
self.network = network
67-
self.subject_handlers: dict[int, list[Callable[[TransportArrival], None]]] = {}
69+
self.subject_handlers: dict[int, Callable[[TransportArrival], None]] = {}
70+
self.subject_listener_creations: dict[int, int] = {}
6871
self.unicast_handler: Callable[[TransportArrival], None] | None = None
6972
self.writers: dict[int, MockSubjectWriter] = {}
73+
self.subject_writer_creations: dict[int, int] = {}
7074
self.unicast_log: list[tuple[int, bytes]] = []
7175
self.closed = False
7276
self.fail_unicast = False
@@ -82,16 +86,28 @@ def subject_id_modulus(self) -> int:
8286
return self._modulus
8387

8488
def subject_listen(self, subject_id: int, handler: Callable[[TransportArrival], None]) -> Closable:
85-
if subject_id not in self.subject_handlers:
86-
self.subject_handlers[subject_id] = []
87-
self.subject_handlers[subject_id].append(handler)
89+
if subject_id in self.subject_handlers:
90+
raise ValueError(f"Subject {subject_id} already has an active listener")
91+
self.subject_handlers[subject_id] = handler
92+
self.subject_listener_creations[subject_id] = self.subject_listener_creations.get(subject_id, 0) + 1
8893
return MockSubjectListener(self, subject_id, handler)
8994

9095
def subject_advertise(self, subject_id: int) -> MockSubjectWriter:
96+
if subject_id in self.writers:
97+
raise ValueError(f"Subject {subject_id} already has an active writer")
9198
writer = MockSubjectWriter(self, subject_id)
9299
self.writers[subject_id] = writer
100+
self.subject_writer_creations[subject_id] = self.subject_writer_creations.get(subject_id, 0) + 1
93101
return writer
94102

103+
def remove_subject_listener(self, subject_id: int, handler: Callable[[TransportArrival], None]) -> None:
104+
if self.subject_handlers.get(subject_id) is handler:
105+
self.subject_handlers.pop(subject_id, None)
106+
107+
def remove_subject_writer(self, subject_id: int, writer: MockSubjectWriter) -> None:
108+
if self.writers.get(subject_id) is writer:
109+
self.writers.pop(subject_id, None)
110+
95111
def unicast_listen(self, handler: Callable[[TransportArrival], None]) -> None:
96112
self.unicast_handler = handler
97113

@@ -118,7 +134,8 @@ def close(self) -> None:
118134
self.closed = True
119135

120136
def deliver_subject(self, subject_id: int, arrival: TransportArrival) -> None:
121-
for handler in self.subject_handlers.get(subject_id, []):
137+
handler = self.subject_handlers.get(subject_id)
138+
if handler is not None:
122139
handler(arrival)
123140

124141
def deliver_unicast(self, arrival: TransportArrival) -> None:

0 commit comments

Comments
 (0)