66 changes: 45 additions & 21 deletions sdks/python/apache_beam/ml/inference/model_manager.py
@@ -288,6 +288,17 @@ def _solve(self):
logger.error("Solver failed: %s", e)


class QueueTicket:
def __init__(self, priority, ticket_num, tag):
self.priority = priority
self.ticket_num = ticket_num
self.tag = tag
self.wake_event = threading.Event()

def __lt__(self, other):
return (self.priority, self.ticket_num) < (other.priority, other.ticket_num)
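
Note: QueueTicket orders by (priority, ticket_num), so heapq pops the lowest priority value first and breaks ties in FIFO arrival order; acquire_model below assigns priority 0 to unknown models and 1 to known ones, so unknown models are served first. A minimal standalone sketch of that ordering (the driver code and tags are illustrative, not part of this PR; the class is repeated so the sketch runs on its own):

import heapq
import itertools
import threading

class QueueTicket:
    def __init__(self, priority, ticket_num, tag):
        self.priority = priority
        self.ticket_num = ticket_num
        self.tag = tag
        self.wake_event = threading.Event()

    def __lt__(self, other):
        return (self.priority, self.ticket_num) < (other.priority, other.ticket_num)

counter = itertools.count()
queue = []
heapq.heappush(queue, QueueTicket(1, next(counter), "known_a"))
heapq.heappush(queue, QueueTicket(0, next(counter), "unknown"))
heapq.heappush(queue, QueueTicket(1, next(counter), "known_b"))
print([heapq.heappop(queue).tag for _ in range(3)])
# -> ['unknown', 'known_a', 'known_b']: priority 0 first, FIFO within a priority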


class ModelManager:
"""Manages model lifecycles, caching, and resource arbitration.

@@ -343,6 +354,7 @@ def __init__(
# and also priority for unknown models.
self._wait_queue = []
self._ticket_counter = itertools.count()
self._cancelled_tickets = set()
# TODO: Consider making the wait smarter, e.g.
# splitting read/write etc. to avoid potential contention.
self._cv = threading.Condition()
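
Note: the new _cancelled_tickets set implements lazy deletion: a waiter that gives up tombstones its ticket in O(1) instead of the old linear scan plus re-heapify, and _wake_next_in_queue discards tombstones only once they surface at the head. A generic, standalone sketch of the pattern (plain ints stand in for tickets):

import heapq

heap, cancelled = [], set()
for ticket_num in (3, 1, 4, 0, 5):
    heapq.heappush(heap, ticket_num)

cancelled.update({0, 1})  # O(1) cancellation; entries stay in the heap for now

# Pop tombstones only when they reach the head, as _wake_next_in_queue does
# before waking a waiter.
while heap and heap[0] in cancelled:
    cancelled.remove(heapq.heappop(heap))

print(heap[0])  # -> 3, the first live ticket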
@@ -417,10 +429,28 @@ def should_spawn_model(self, tag: str, ticket_num: int) -> bool:
self._cv.wait(timeout=self._lock_timeout_seconds)
return False

def _wake_next_in_queue(self):
if self._wait_queue:
# Clean up cancelled tickets at head of queue
while self._wait_queue and self._wait_queue[
0].ticket_num in self._cancelled_tickets:
self._cancelled_tickets.remove(self._wait_queue[0].ticket_num)
heapq.heappop(self._wait_queue)
if not self._wait_queue: return  # every queued ticket was cancelled
next_inline = self._wait_queue[0]
next_inline.wake_event.set()

def _wait_in_queue(self, ticket: QueueTicket):
self._cv.release()
try:
ticket.wake_event.wait(timeout=self._lock_timeout_seconds)
ticket.wake_event.clear()
finally:
self._cv.acquire()
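
Note: _wait_in_queue releases the shared condition lock while blocking on the ticket's private Event, so a wake-up can target exactly the head-of-queue waiter instead of notify_all() waking every thread only for most to go back to sleep. A standalone sketch of the release/wait/reacquire handoff (all names here are invented for illustration):

import threading

cv = threading.Condition()

def wait_on(event, timeout=5.0):
    # Mirrors _wait_in_queue: the caller must hold cv when calling this.
    cv.release()
    try:
        event.wait(timeout=timeout)
        event.clear()
    finally:
        cv.acquire()

ticket_event = threading.Event()

def waiter():
    with cv:
        wait_on(ticket_event)
        print("woken, holding cv again")

t = threading.Thread(target=waiter)
t.start()
ticket_event.set()  # targeted wake-up, as _wake_next_in_queue does
t.join()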

def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any:
current_priority = 0 if self._estimator.is_unknown(tag) else 1
ticket_num = next(self._ticket_counter)
- my_id = object()
+ my_ticket = QueueTicket(current_priority, ticket_num, tag)

with self._cv:
# FAST PATH: Grab from idle LRU if available
@@ -439,8 +469,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any:
current_priority,
len(self._models[tag]),
ticket_num)
- heapq.heappush(
- self._wait_queue, (current_priority, ticket_num, my_id, tag))
+ heapq.heappush(self._wait_queue, my_ticket)

est_cost = 0.0
is_unknown = False
@@ -453,10 +482,11 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any:
raise RuntimeError(
f"Timeout waiting to acquire model: {tag} "
f"after {wait_time_elapsed:.1f} seconds.")
- if not self._wait_queue or self._wait_queue[0][2] is not my_id:
+ if not self._wait_queue or self._wait_queue[
+ 0].ticket_num != ticket_num:
logger.info(
"Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num)
- self._cv.wait(timeout=self._lock_timeout_seconds)
+ self._wait_in_queue(my_ticket)
continue

# Re-evaluate priority in case model became known during wait
@@ -467,9 +497,9 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any:
if current_priority != real_priority:
heapq.heappop(self._wait_queue)
current_priority = real_priority
- heapq.heappush(
- self._wait_queue, (current_priority, ticket_num, my_id, tag))
- self._cv.notify_all()
+ my_ticket = QueueTicket(current_priority, ticket_num, tag)
+ heapq.heappush(self._wait_queue, my_ticket)
+ self._wake_next_in_queue()
continue

# Try grab from LRU again in case model was released during wait
@@ -494,7 +524,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any:
"Waiting due to isolation in progress: tag=%s ticket num%s",
tag,
ticket_num)
- self._cv.wait(timeout=self._lock_timeout_seconds)
+ self._wait_in_queue(my_ticket)
continue

if self.should_spawn_model(tag, ticket_num):
@@ -508,19 +538,12 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any:

finally:
# Remove self from wait queue once done
- if self._wait_queue and self._wait_queue[0][2] is my_id:
+ if self._wait_queue and self._wait_queue[0].ticket_num == ticket_num:
heapq.heappop(self._wait_queue)
else:
- logger.warning(
- "Item not at head of wait queue during cleanup"
- ", this is not expected: tag=%s ticket num=%s",
- tag,
- ticket_num)
- for i, item in enumerate(self._wait_queue):
- if item[2] is my_id:
- self._wait_queue.pop(i)
- heapq.heapify(self._wait_queue)
- self._cv.notify_all()
+ # Mark as cancelled so the ticket is skipped once it reaches the head later
+ self._cancelled_tickets.add(ticket_num)
+ self._wake_next_in_queue()

return self._spawn_new_model(tag, loader_func, is_unknown, est_cost)

@@ -553,6 +576,7 @@ def release_model(self, tag: str, instance: Any):
self._estimator.add_observation(snapshot, peak_during_job)

finally:
self._wake_next_in_queue()
self._cv.notify_all()

def _try_grab_from_lru(self, tag: str) -> Any:
@@ -596,7 +620,7 @@ def _evict_to_make_space(
# TODO: Also factor in the active counts to avoid thrashing
demand_map = Counter()
for item in self._wait_queue:
- demand_map[item[3]] += 1
+ demand_map[item.tag] += 1

my_demand = demand_map[requesting_tag]
am_i_starving = len(self._models[requesting_tag]) == 0
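
Note: with tickets as objects, the eviction path reads per-tag demand as item.tag instead of the old positional item[3]. A small sketch of the demand map this builds (the tags and the Ticket stand-in are invented for illustration):

from collections import Counter, namedtuple

# Stand-in for QueueTicket; only .tag matters for the demand map.
Ticket = namedtuple("Ticket", ["priority", "ticket_num", "tag"])

wait_queue = [Ticket(1, 0, "bert"), Ticket(0, 1, "new_llm"), Ticket(1, 2, "bert")]
demand_map = Counter(item.tag for item in wait_queue)
print(demand_map)  # Counter({'bert': 2, 'new_llm': 1})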
14 changes: 11 additions & 3 deletions sdks/python/apache_beam/ml/inference/model_manager_test.py
@@ -174,12 +174,20 @@ def loader():
def acquire_model_with_timeout():
return self.manager.acquire_model(model_name, loader)

- with ThreadPoolExecutor(max_workers=1) as executor:
- future = executor.submit(acquire_model_with_timeout)
+ with ThreadPoolExecutor(max_workers=1000) as executor:
+ futures = [
+ executor.submit(acquire_model_with_timeout) for i in range(1000)
+ ]
with self.assertRaises(RuntimeError) as context:
- future.result(timeout=5.0)
+ for future in futures:
+ future.result()
self.assertIn("Timeout waiting to acquire model", str(context.exception))

# Release the initially acquired model and try to acquire again
# to make sure the manager is still functional
self.manager.release_model(model_name, model_name)
_ = self.manager.acquire_model(model_name, loader)

def test_model_manager_capacity_check(self):
"""
Test that the manager blocks when spawning models exceeds the limit,