@@ -157,6 +157,7 @@ message InstructionRequest {
MonitoringInfosMetadataRequest monitoring_infos = 1005;
HarnessMonitoringInfosRequest harness_monitoring_infos = 1006;
SampleDataRequest sample_data = 1007;
AiWorkerPoolMetadata ai_worker_pool_metadata = 1008;

// DEPRECATED
RegisterRequest register = 1000;
@@ -529,6 +530,13 @@ message MonitoringInfosMetadataResponse {
map<string, org.apache.beam.model.pipeline.v1.MonitoringInfo> monitoring_info = 1;
}

message AiWorkerPoolMetadata {
// The external IP address of the AI worker pool.
string external_ip = 1;
// The external port of the AI worker pool.
int32 external_port = 2;
}

// Represents a request to the SDK to split a currently active bundle.
message ProcessBundleSplitRequest {
// (Required) A reference to an active process bundle request with the given
Empty file.
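With the new field in place, a runner can hand the Envoy endpoint to the SDK harness as an ordinary control-plane instruction. A minimal sketch, assuming regenerated beam_fn_api_pb2 bindings; control_connection and the IP/port values are hypothetical stand-ins, while the message shapes come from the proto change above:

from apache_beam.portability.api import beam_fn_api_pb2

request = beam_fn_api_pb2.InstructionRequest(
    instruction_id='ai-metadata-1',
    ai_worker_pool_metadata=beam_fn_api_pb2.AiWorkerPoolMetadata(
        external_ip='10.0.0.42',  # hypothetical endpoint for illustration
        external_port=8081))
# control_connection stands in for the runner's Fn API control channel.
control_connection.push(request)
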
101 changes: 101 additions & 0 deletions sdks/python/apache_beam/examples/ratelimit/beam_example.py
@@ -0,0 +1,101 @@
"""Example pipeline that calls an Envoy Rate Limit Service from a DoFn.

The Envoy endpoint is discovered at runtime from the AI worker pool metadata
delivered to the SDK harness."""
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions, StandardOptions
from apache_beam.runners.worker.sdk_worker import get_ai_worker_pool_metadata

import grpc
import logging
import os
import sys


# Make the buf-generated Envoy proto stubs importable.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'generated_proto')))

from envoy.service.ratelimit.v3 import rls_pb2
from envoy.service.ratelimit.v3 import rls_pb2_grpc
from envoy.extensions.common.ratelimit.v3 import ratelimit_pb2

# Set up logging
logging.basicConfig(level=logging.INFO)
_LOGGER = logging.getLogger(__name__)

class GRPCRateLimitClient(beam.DoFn):
"""
A DoFn that makes gRPC calls to an Envoy Rate Limit Service.
"""
def __init__(self):
self._envoy_address = None
self._channel = None
self._stub = None

def setup(self):
"""
Initializes the gRPC channel and stub.
"""
        ai_worker_pool_metadata = get_ai_worker_pool_metadata()
        if ai_worker_pool_metadata is None:
            raise RuntimeError(
                'AI worker pool metadata has not been received from the runner.')
        self._envoy_address = f"{ai_worker_pool_metadata.external_ip}:{ai_worker_pool_metadata.external_port}"
_LOGGER.info(f"Setting up gRPC client for Envoy at {self._envoy_address}")
self._channel = grpc.insecure_channel(self._envoy_address)
self._stub = rls_pb2_grpc.RateLimitServiceStub(self._channel)

def process(self, element):
client_id = element.get('client_id', 'unknown_client')
request_id = element.get('request_id', 'unknown_request')

_LOGGER.info(f"Processing element: client_id={client_id}, request_id={request_id}")

# Create a RateLimitDescriptor
descriptor = ratelimit_pb2.RateLimitDescriptor()
descriptor.entries.add(key="client_id", value=client_id)
descriptor.entries.add(key="request_id", value=request_id)

# Create a RateLimitRequest
request = rls_pb2.RateLimitRequest(
domain="my_service",
descriptors=[descriptor],
hits_addend=1
)

try:
response = self._stub.ShouldRateLimit(request)
_LOGGER.info(f"RateLimitResponse for client_id={client_id}, request_id={request_id}: {response.overall_code}")
yield {
'client_id': client_id,
'request_id': request_id,
'rate_limit_status': rls_pb2.RateLimitResponse.Code.Name(response.overall_code),
'response_details': str(response)
}
except grpc.RpcError as e:
_LOGGER.error(f"gRPC call failed for client_id={client_id}, request_id={request_id}: {e.details()}")
yield {
'client_id': client_id,
'request_id': request_id,
'rate_limit_status': 'ERROR',
'error_details': e.details()
}

def teardown(self):
if self._channel:
_LOGGER.info("Tearing down gRPC client.")
self._channel.close()

def run():
options = PipelineOptions()
options.view_as(StandardOptions).runner = 'DirectRunner' # Use DirectRunner for local testing

with beam.Pipeline(options=options) as p:
# Sample input data
requests = p | 'CreateRequests' >> beam.Create([
{'client_id': 'user_1', 'request_id': 'req_a'},
{'client_id': 'user_2', 'request_id': 'req_b'},
{'client_id': 'user_1', 'request_id': 'req_c'},
{'client_id': 'user_3', 'request_id': 'req_d'},
])

# Apply the gRPC client DoFn
rate_limit_results = requests | 'CheckRateLimit' >> beam.ParDo(GRPCRateLimitClient())

# Log the results
rate_limit_results | 'LogResults' >> beam.Map(lambda x: _LOGGER.info(f"Result: {x}"))

if __name__ == '__main__':
run()
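
For the calls above to succeed, an Envoy rate limit service has to be listening at the advertised endpoint with rules for the my_service domain. A minimal service config sketch, assuming the envoyproxy/ratelimit YAML schema; the limit itself is illustrative:

domain: my_service
descriptors:
  - key: client_id
    rate_limit:
      unit: second
      requests_per_unit: 5
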
31 changes: 31 additions & 0 deletions sdks/python/apache_beam/examples/ratelimit/beam_example2.py
@@ -0,0 +1,31 @@
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
import logging
from apache_beam.runners.worker.sdk_worker import get_ai_worker_pool_metadata

logging.basicConfig(level=logging.INFO)

class PrintFn(beam.DoFn):
def process(self, element):
logging.info(f"Processing element: {element} and worker metadata {get_ai_worker_pool_metadata()}")
yield element

pipeline_options = PipelineOptions()
pipeline = beam.Pipeline(options=pipeline_options)

# Create a PCollection from a list of elements for this batch job.
data = pipeline | 'Create' >> beam.Create([
'Hello',
'World',
'This',
'is',
'a',
'batch',
'example',
])

# Apply the custom DoFn.
data | 'PrintWithDoFn' >> beam.ParDo(PrintFn())

result = pipeline.run()
result.wait_until_finish()
3 changes: 3 additions & 0 deletions sdks/python/apache_beam/examples/ratelimit/buf.yaml
@@ -0,0 +1,3 @@
version: v1
deps:
- buf.build/envoyproxy/envoy
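
This buf.yaml only declares the Envoy dependency; the generated_proto package that beam_example.py imports still has to be produced. One way, sketched under the assumption of buf's v1 generation config and its hosted Python plugins (plugin names and the exact invocation may need adjusting):

# buf.gen.yaml
version: v1
plugins:
  - plugin: buf.build/protocolbuffers/python
    out: generated_proto
  - plugin: buf.build/grpc/python
    out: generated_proto

Running buf generate buf.build/envoyproxy/envoy (optionally restricted with --path envoy/service/ratelimit/v3) would then emit rls_pb2, rls_pb2_grpc, and ratelimit_pb2 under generated_proto/.
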
47 changes: 46 additions & 1 deletion sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -69,7 +69,7 @@
from apache_beam.utils import thread_pool_executor
from apache_beam.utils.sentinel import Sentinel
from apache_beam.version import __version__ as beam_version

import dataclasses
if TYPE_CHECKING:
from apache_beam.portability.api import endpoints_pb2
from apache_beam.utils.profiler import Profile
@@ -104,6 +104,44 @@
}]
})

@dataclasses.dataclass
class AiWorkerPoolMetadata:
"""Runtime metadata about AI worker pool resources, such as external IP and
port.

Attributes:
external_ip (str): The external IP address of the AI worker pool.
external_port (int): The external port of the AI worker pool.
"""
external_ip: Optional[str] = None
external_port: Optional[int] = None

@classmethod
def from_proto(cls, proto):
# type: (beam_fn_api_pb2.AiWorkerPoolMetadata) -> AiWorkerPoolMetadata
"""Creates an instance from an AiWorkerPoolMetadata proto."""
return cls(
external_ip=proto.external_ip if proto.external_ip else None,
external_port=proto.external_port if proto.external_port else None)


class _AiMetadataHolder:
"""Singleton holder for AiWorkerPoolMetadata."""
_metadata: Optional[AiWorkerPoolMetadata] = None
_lock = threading.Lock()

@classmethod
def set_metadata(cls, proto):
# type: (beam_fn_api_pb2.AiWorkerPoolMetadata) -> None
with cls._lock:
cls._metadata = AiWorkerPoolMetadata.from_proto(proto)

  @classmethod
  def get_metadata(cls) -> Optional[AiWorkerPoolMetadata]:
    with cls._lock:
      return cls._metadata

def get_ai_worker_pool_metadata() -> Optional[AiWorkerPoolMetadata]:
  """Returns the AI worker pool metadata received from the runner, if any."""
  return _AiMetadataHolder.get_metadata()

class ShortIdCache(object):
""" Cache for MonitoringInfo "short ids"
@@ -393,6 +431,13 @@ def task():
_LOGGER.debug(
"Currently using %s threads." % len(self._worker_thread_pool._workers))

def _request_ai_worker_pool_metadata(self, request):
# type: (beam_fn_api_pb2.InstructionRequest) -> None
_AiMetadataHolder.set_metadata(request.ai_worker_pool_metadata)
_LOGGER.info("received metadata for AI worker pool: %s", request.ai_worker_pool_metadata)
self._responses.put(
beam_fn_api_pb2.InstructionResponse(instruction_id=request.instruction_id))

def _request_sample_data(self, request):
# type: (beam_fn_api_pb2.InstructionRequest) -> None

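Note that no explicit registration is needed for the new _request_ai_worker_pool_metadata handler: the SDK harness resolves each InstructionRequest to a method named after the request's oneof field, roughly as below (simplified from the existing dispatch code, not part of this diff):

request_type = request.WhichOneof('request')  # e.g. 'ai_worker_pool_metadata'
getattr(worker, '_request_%s' % request_type)(request)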