Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 90 additions & 8 deletions sdks/python/apache_beam/transforms/async_dofn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,21 @@

from __future__ import absolute_import

import asyncio
import inspect
import logging
import random
import threading
import uuid
from collections.abc import AsyncIterable
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from math import floor
from threading import RLock
from time import sleep
from time import time
from types import GeneratorType
from typing import Optional

import apache_beam as beam
from apache_beam import TimeDomain
Expand Down Expand Up @@ -60,6 +66,9 @@ class AsyncWrapper(beam.DoFn):
[coders.FastPrimitivesCoder(), coders.FastPrimitivesCoder()]))
# The below items are one per dofn (not instance) so are maps of UUID to
# value.
_event_loop: Optional[asyncio.AbstractEventLoop] = None
_event_loop_thread: Optional[threading.Thread] = None
_loop_started: Optional[threading.Event] = None
_processing_elements = {}
_items_in_buffer = {}
_pool = {}
Expand All @@ -78,6 +87,7 @@ def __init__(
timeout=1,
max_wait_time=0.5,
id_fn=None,
use_asyncio=False,
):
"""Wraps the sync_fn to create an asynchronous version.

Expand All @@ -104,6 +114,10 @@ def __init__(
schedule an item. Used in testing to ensure timeouts are met.
id_fn: A function that returns a hashable object from an element. This
will be used to track items instead of the element's default hash.
use_asyncio: If true, use asyncio and coroutines to process items. If
false, use ThreadPoolExecutor. Use asyncio when the work being done
is not CPU intensive and heavily waits on network or IO which can
benefit from higher parallelism.
"""
self._sync_fn = sync_fn
self._uuid = uuid.uuid4().hex
Expand All @@ -112,6 +126,7 @@ def __init__(
self._max_wait_time = max_wait_time
self._timer_frequency = callback_frequency
self._id_fn = id_fn or (lambda x: x)
self._use_asyncio = use_asyncio
if max_items_to_buffer is None:
self._max_items_to_buffer = max(parallelism * 2, 10)
else:
Expand All @@ -126,11 +141,33 @@ def __init__(
def initialize_pool(parallelism):
return lambda: ThreadPoolExecutor(max_workers=parallelism)

@staticmethod
def _run_event_loop():
"""Sets up and runs the asyncio event loop in a background thread."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
AsyncWrapper._event_loop = loop
AsyncWrapper._loop_started.set()
loop.run_forever()
loop.close()

@staticmethod
def reset_state():
for pool in AsyncWrapper._pool.values():
pool.acquire(AsyncWrapper.initialize_pool(1)).shutdown(
wait=True, cancel_futures=True)
with AsyncWrapper._lock:
if AsyncWrapper._event_loop:
AsyncWrapper._event_loop.call_soon_threadsafe(
AsyncWrapper._event_loop.stop)
if AsyncWrapper._event_loop_thread:
AsyncWrapper._event_loop_thread.join()

AsyncWrapper._event_loop = None
AsyncWrapper._event_loop_thread = None
if AsyncWrapper._loop_started is not None:
AsyncWrapper._loop_started.clear()

for pool in AsyncWrapper._pool.values():
pool.acquire(AsyncWrapper.initialize_pool(1)).shutdown(
wait=True, cancel_futures=True)
with AsyncWrapper._lock:
AsyncWrapper._pool = {}
AsyncWrapper._processing_elements = {}
Expand All @@ -140,6 +177,13 @@ def setup(self):
"""Forwards to the wrapped dofn's setup method."""
self._sync_fn.setup()
with AsyncWrapper._lock:
if self._use_asyncio and AsyncWrapper._event_loop_thread is None:
AsyncWrapper._loop_started = threading.Event()
AsyncWrapper._event_loop_thread = threading.Thread(
target=AsyncWrapper._run_event_loop, daemon=True)
AsyncWrapper._event_loop_thread.start()
AsyncWrapper._loop_started.wait()

if not self._uuid in AsyncWrapper._pool:
AsyncWrapper._pool[self._uuid] = Shared()
AsyncWrapper._processing_elements[self._uuid] = {}
Expand Down Expand Up @@ -187,9 +231,41 @@ def sync_fn_process(self, element, *args, **kwargs):
to_return.append(x)
for x in bundle_result:
to_return.append(x)

return to_return

async def async_fn_process(self, element, *args, **kwargs):
"""Makes the call to the wrapped dofn's start_bundle, process
and finish_bundle methods for asynchronous DoFns.

Args:
element: The element to process.
*args: Any additional arguments to pass to the wrapped dofn's process
method.
**kwargs: Any additional keyword arguments to pass to the wrapped dofn's
process method.

Returns:
A list of elements produced by the input element.
"""
async def _collect(result):
if result is None:
return []
if inspect.isawaitable(result):
result = await result
if isinstance(result, AsyncIterable):
return [item async for item in result]
if isinstance(result,
(GeneratorType, Iterable)) and not isinstance(result,
(str, bytes)):
return list(result)
return [result]

self._sync_fn.start_bundle()
process_result = await _collect(
self._sync_fn.process(element, *args, **kwargs))
bundle_result = await _collect(self._sync_fn.finish_bundle())
return process_result + bundle_result

def decrement_items_in_buffer(self, future):
with AsyncWrapper._lock:
AsyncWrapper._items_in_buffer[self._uuid] -= 1
Expand All @@ -214,10 +290,16 @@ def schedule_if_room(self, element, ignore_buffer=False, *args, **kwargs):
logging.info('item %s already in processing elements', element)
return True
if self.accepting_items() or ignore_buffer:
result = AsyncWrapper._pool[self._uuid].acquire(
AsyncWrapper.initialize_pool(self._parallelism)).submit(
lambda: self.sync_fn_process(element, *args, **kwargs),
)
if self._use_asyncio:
result = asyncio.run_coroutine_threadsafe(
self.async_fn_process(element, *args, **kwargs),
AsyncWrapper._event_loop,
)
else:
result = AsyncWrapper._pool[self._uuid].acquire(
AsyncWrapper.initialize_pool(self._parallelism)).submit(
lambda: self.sync_fn_process(element, *args, **kwargs),
)
result.add_done_callback(self.decrement_items_in_buffer)
AsyncWrapper._processing_elements[self._uuid][element_id] = (
element, result)
Expand Down
49 changes: 32 additions & 17 deletions sdks/python/apache_beam/transforms/async_dofn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from concurrent.futures import ThreadPoolExecutor
from threading import Lock

from parameterized import parameterized_class

import apache_beam as beam
import apache_beam.transforms.async_dofn as async_lib

Expand Down Expand Up @@ -62,7 +64,7 @@ class FakeBagState:
def __init__(self, items):
self.items = items
# Normally SE would have a lock on the BT row protecting this from multiple
# updates. Here without SE we must lock ourselvs.
# updates. Here without SE we must lock ourselves.
self.lock = Lock()

def add(self, item):
Expand All @@ -86,6 +88,14 @@ def set(self, time):
self.time = time


@parameterized_class([
{
"use_asyncio": True
},
{
"use_asyncio": False
},
])
class AsyncTest(unittest.TestCase):
def setUp(self):
super().setUp()
Expand Down Expand Up @@ -132,7 +142,8 @@ def __eq__(self, other):
return self.element_id == other.element_id

dofn = BasicDofn()
async_dofn = async_lib.AsyncWrapper(dofn, id_fn=lambda x: x.element_id)
async_dofn = async_lib.AsyncWrapper(
dofn, id_fn=lambda x: x.element_id, use_asyncio=self.use_asyncio)
async_dofn.setup()
fake_bag_state = FakeBagState([])
fake_timer = FakeTimer(0)
Expand All @@ -156,7 +167,7 @@ def __eq__(self, other):
def test_basic(self):
# Setup an async dofn and send a message in to process.
dofn = BasicDofn()
async_dofn = async_lib.AsyncWrapper(dofn)
async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio)
async_dofn.setup()
fake_bag_state = FakeBagState([])
fake_timer = FakeTimer(0)
Expand All @@ -181,9 +192,9 @@ def test_basic(self):
self.assertEqual(fake_bag_state.items, [])

def test_multi_key(self):
# Send in two messages with different keys..
# Send in two messages with different keys.
dofn = BasicDofn()
async_dofn = async_lib.AsyncWrapper(dofn)
async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio)
async_dofn.setup()
fake_bag_state_key1 = FakeBagState([])
fake_bag_state_key2 = FakeBagState([])
Expand Down Expand Up @@ -211,7 +222,7 @@ def test_multi_key(self):
def test_long_item(self):
# Test that everything still works with a long running time for the dofn.
dofn = BasicDofn(sleep_time=5)
async_dofn = async_lib.AsyncWrapper(dofn)
async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio)
async_dofn.setup()
fake_bag_state = FakeBagState([])
fake_timer = FakeTimer(0)
Expand All @@ -231,10 +242,10 @@ def test_long_item(self):
self.assertEqual(fake_bag_state.items, [])

def test_lost_item(self):
# Setup an element in the bag stat thats not in processing state.
# Setup an element in the bag state that's not in processing state.
# The async dofn should reschedule this element.
dofn = BasicDofn()
async_dofn = async_lib.AsyncWrapper(dofn)
async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio)
async_dofn.setup()
fake_timer = FakeTimer(0)
msg = ('key1', 1)
Expand All @@ -250,9 +261,9 @@ def test_lost_item(self):
def test_cancelled_item(self):
# Test that an item gets removed for processing and does not get output when
# it is not present in the bag state. Either this item moved or a commit
# failed making the local state and bag stat inconsistent.
# failed making the local state and bag state inconsistent.
dofn = BasicDofn()
async_dofn = async_lib.AsyncWrapper(dofn)
async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio)
async_dofn.setup()
msg = ('key1', 1)
msg2 = ('key1', 2)
Expand All @@ -272,7 +283,7 @@ def test_multi_element_dofn(self):
# Test that async works when a dofn produces multiple elements in process
# and finish_bundle.
dofn = MultiElementDoFn()
async_dofn = async_lib.AsyncWrapper(dofn)
async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio)
async_dofn.setup()
fake_bag_state = FakeBagState([])
fake_timer = FakeTimer(0)
Expand All @@ -289,7 +300,7 @@ def test_duplicates(self):
# Test that async will produce a single output when a given input is sent
# multiple times.
dofn = BasicDofn(5)
async_dofn = async_lib.AsyncWrapper(dofn)
async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio)
async_dofn.setup()
fake_bag_state = FakeBagState([])
fake_timer = FakeTimer(0)
Expand All @@ -310,7 +321,7 @@ def test_slow_duplicates(self):
# Test that async will produce a single output when a given input is sent
# multiple times.
dofn = BasicDofn(5)
async_dofn = async_lib.AsyncWrapper(dofn)
async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio)
async_dofn.setup()
fake_bag_state = FakeBagState([])
fake_timer = FakeTimer(0)
Expand All @@ -335,7 +346,7 @@ def test_slow_duplicates(self):
def test_buffer_count(self):
# Test that the buffer count is correctly incremented when adding items.
dofn = BasicDofn(5)
async_dofn = async_lib.AsyncWrapper(dofn)
async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio)
async_dofn.setup()
msg = ('key1', 1)
fake_timer = FakeTimer(0)
Expand All @@ -353,7 +364,10 @@ def test_buffer_stops_accepting_items(self):
# Test that the buffer stops accepting items when it is full.
dofn = BasicDofn(5)
async_dofn = async_lib.AsyncWrapper(
dofn, parallelism=1, max_items_to_buffer=5)
dofn,
parallelism=1,
max_items_to_buffer=5,
use_asyncio=self.use_asyncio)
async_dofn.setup()
fake_timer = FakeTimer(0)
fake_bag_state = FakeBagState([])
Expand Down Expand Up @@ -391,7 +405,7 @@ def add_item(i):

def test_buffer_with_cancellation(self):
dofn = BasicDofn(3)
async_dofn = async_lib.AsyncWrapper(dofn)
async_dofn = async_lib.AsyncWrapper(dofn, use_asyncio=self.use_asyncio)
async_dofn.setup()
msg = ('key1', 1)
msg2 = ('key1', 2)
Expand Down Expand Up @@ -423,7 +437,8 @@ def test_load_correctness(self):
# Test AsyncDofn over heavy load.
dofn = BasicDofn(1)
max_sleep = 10
async_dofn = async_lib.AsyncWrapper(dofn, max_wait_time=max_sleep)
async_dofn = async_lib.AsyncWrapper(
dofn, max_wait_time=max_sleep, use_asyncio=self.use_asyncio)
async_dofn.setup()
bag_states = {}
timers = {}
Expand Down
Loading