Skip to content

Commit 2c8e4b5

Browse files
committed
ref: rename get_lock to `get_counter_lock in SlotRegister
1 parent 6910c0c commit 2c8e4b5

2 files changed

Lines changed: 18 additions & 10 deletions

File tree

CHANGELOG

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
- enh: introduce logging with queues in `UniversalWorker`
66
- ref: `SegmentTorchSTO` is now always visible, even if CUDA is not available
77
This was necessary, because we want to lazy-load torch.
8+
- ref: rename `get_lock` to ``get_counter_lock` in `SlotRegister`
9+
- ref: generalized interface for counters in `SlotRegister`
810
0.28.3
911
- fix: sort internal basins before file-based basins
1012
0.28.2

src/dcnum/logic/slot_register.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,13 @@ def __init__(self,
2828
self.num_chunks = data.image.num_chunks
2929
self._slots = []
3030

31-
self._chunks_loaded = mp_spawn.Value("Q", 0)
31+
# Counters are created with recursive locks, which means that the
32+
# same process may acquire multiple locks on the object, and only
33+
# after releasing all of them, may the lock be acquired by another
34+
# process.
35+
self._counters = {
36+
"chunks_loaded": mp_spawn.Value("Q", 0)
37+
}
3238

3339
self._state = mp_spawn.Value("u", "w")
3440

@@ -57,15 +63,15 @@ def __len__(self):
5763

5864
@property
5965
def chunks_loaded(self):
60-
"""A multiprocessing value counting the number of chunks loaded
66+
"""A process-safe counter for the number of chunks loaded
6167
62-
This number increments as `ChunkSlot.task_load_all` is called.
68+
This number increments as `SlotRegister.task_load_all` is called.
6369
"""
64-
return self._chunks_loaded.value
70+
return self._counters["chunks_loaded"].value
6571

6672
@chunks_loaded.setter
6773
def chunks_loaded(self, value):
68-
self._chunks_loaded.value = value
74+
self._counters["chunks_loaded"].value = value
6975

7076
@property
7177
def slots(self):
@@ -110,11 +116,11 @@ def find_slot(self, state: str, chunk: int = None) -> ChunkSlot | None:
110116
# fallback to nothing found
111117
return None
112118

113-
def get_lock(self, name):
114-
if name == "chunks_loaded":
115-
return self._chunks_loaded.get_lock()
119+
def get_counter_lock(self, name):
120+
if name in self._counters:
121+
return self._counters[name].get_lock()
116122
else:
117-
raise KeyError(f"No lock defined for {name}")
123+
raise KeyError(f"No counter lock defined for {name}")
118124

119125
def get_time(self, method_name):
120126
"""Return accumulative time for the given method
@@ -167,7 +173,7 @@ def task_load_all(self, logger: logging.Logger = None) -> bool:
167173
Whether data were loaded into memory
168174
"""
169175
did_something = False
170-
lock = self.get_lock("chunks_loaded")
176+
lock = self.get_counter_lock("chunks_loaded")
171177
has_lock = lock.acquire(block=False)
172178
if has_lock and self.chunks_loaded < self.num_chunks:
173179
try:

0 commit comments

Comments
 (0)