@@ -20,8 +20,10 @@
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.streaming.api.functions.async.AsyncBatchFunction;
import org.apache.flink.streaming.api.functions.async.AsyncFunction;
import org.apache.flink.streaming.api.functions.async.AsyncRetryStrategy;
import org.apache.flink.streaming.api.operators.async.AsyncBatchWaitOperatorFactory;
import org.apache.flink.streaming.api.operators.async.AsyncWaitOperator;
import org.apache.flink.streaming.api.operators.async.AsyncWaitOperatorFactory;
import org.apache.flink.util.Preconditions;
@@ -319,4 +321,51 @@ public static <IN, OUT> SingleOutputStreamOperator<OUT> orderedWaitWithRetry(
OutputMode.ORDERED,
asyncRetryStrategy);
}

// ================================================================================
// Batch Async Operations
// ================================================================================

/**
* Adds an AsyncBatchWaitOperator to process elements in batches. Output records may be emitted
* in a different order than the corresponding input records (unordered mode).
*
* <p>This method is particularly useful for high-latency workloads, such as machine learning
* model inference, where batching can significantly improve throughput.
*
* <p>The operator buffers incoming elements and triggers the async batch function when the
* buffer reaches {@code maxBatchSize}. Remaining elements are flushed when the input ends.
*
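* <p>Example usage (a minimal sketch; {@code MyBatchFunction} stands for a hypothetical
* user-defined {@link AsyncBatchFunction} implementation):
*
* <pre>{@code
* DataStream<String> input = ...;
* SingleOutputStreamOperator<String> result =
*     AsyncDataStream.unorderedWaitBatch(input, new MyBatchFunction(), 32);
* }</pre>
*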
* @param in Input {@link DataStream}
* @param func {@link AsyncBatchFunction} to process batches of elements
* @param maxBatchSize Maximum number of elements to batch before triggering async invocation
* @param <IN> Type of input record
* @param <OUT> Type of output record
* @return A new {@link SingleOutputStreamOperator}
*/
public static <IN, OUT> SingleOutputStreamOperator<OUT> unorderedWaitBatch(
DataStream<IN> in, AsyncBatchFunction<IN, OUT> func, int maxBatchSize) {
Preconditions.checkArgument(maxBatchSize > 0, "maxBatchSize must be greater than 0");

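// Derive the output type: OUT is the second type parameter of AsyncBatchFunction; for
// lambdas it is resolved from the ResultFuture parameter (indices {1, 0}).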
TypeInformation<OUT> outTypeInfo =
TypeExtractor.getUnaryOperatorReturnType(
func,
AsyncBatchFunction.class,
0,
1,
new int[] {1, 0},
in.getType(),
Utils.getCallLocationName(),
true);

// create transform
AsyncBatchWaitOperatorFactory<IN, OUT> operatorFactory =
new AsyncBatchWaitOperatorFactory<>(
in.getExecutionEnvironment().clean(func), maxBatchSize);

return in.transform("async batch wait operator", outTypeInfo, operatorFactory);
}

// TODO: Add orderedWaitBatch in follow-up PR
// TODO: Add time-based batching support in follow-up PR
}
@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.streaming.api.functions.async;

import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.api.common.functions.Function;

import java.io.Serializable;
import java.util.List;

/**
* A function to trigger Async I/O operations in batches.
*
* <p>For each batch of inputs, an async I/O operation can be triggered via {@link
* #asyncInvokeBatch}, and once it has completed, the results can be collected by calling {@link
* ResultFuture#complete}. This is particularly useful for high-latency inference workloads where
* batching can significantly improve throughput.
*
* <p>Unlike {@link AsyncFunction}, which processes one element at a time, this interface allows
* processing multiple elements together, which is beneficial for scenarios such as:
*
* <ul>
* <li>Machine learning model inference where batching improves GPU utilization
* <li>External service calls that support batch APIs
* <li>Database queries that can be batched for efficiency
* </ul>
*
* <p>Example usage:
*
* <pre>{@code
* public class BatchInferenceFunction implements AsyncBatchFunction<String, String> {
*
* public void asyncInvokeBatch(List<String> inputs, ResultFuture<String> resultFuture) {
* // Submit batch inference request
* CompletableFuture.supplyAsync(() -> {
* List<String> results = modelService.batchInference(inputs);
* return results;
* }).thenAccept(results -> resultFuture.complete(results));
* }
* }
* }</pre>
*
* @param <IN> The type of the input elements.
* @param <OUT> The type of the returned elements.
*/
@PublicEvolving
public interface AsyncBatchFunction<IN, OUT> extends Function, Serializable {

/**
* Trigger async operation for a batch of stream inputs.
*
* <p>The implementation should process all inputs in the batch and complete the result future
* with all corresponding outputs. The number of outputs does not need to match the number of
* inputs; it depends on the specific use case.
*
* @param inputs a batch of elements coming from upstream tasks
* @param resultFuture to be completed with the result data for the entire batch
* @throws Exception in case of a user code error. An exception will make the task fail and
* trigger fail-over process.
*/
void asyncInvokeBatch(List<IN> inputs, ResultFuture<OUT> resultFuture) throws Exception;

// TODO: Add timeout handling in follow-up PR
// TODO: Add open/close lifecycle methods in follow-up PR
}
@@ -0,0 +1,230 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.streaming.api.operators.async;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.operators.MailboxExecutor;
import org.apache.flink.streaming.api.functions.async.AsyncBatchFunction;
import org.apache.flink.streaming.api.functions.async.CollectionSupplier;
import org.apache.flink.streaming.api.functions.async.ResultFuture;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.Preconditions;

import javax.annotation.Nonnull;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

/**
* The {@link AsyncBatchWaitOperator} batches incoming stream records and invokes the {@link
* AsyncBatchFunction} when the batch size reaches the configured maximum.
*
* <p>This operator implements unordered semantics only: results are emitted as soon as they are
* available, regardless of input order. This is suitable for AI inference workloads where order
* does not matter.
*
* <p>Key behaviors:
*
* <ul>
* <li>Buffer incoming records until batch size is reached
* <li>Flush remaining records when end of input is signaled
* <li>Emit all results from the batch function to downstream
* </ul>
*
* <p>This is a minimal implementation for the first PR. Future enhancements may include:
*
* <ul>
* <li>Ordered mode support
* <li>Time-based batching with timers
* <li>Timeout handling
* <li>Retry logic
* <li>Metrics
* </ul>
*
* @param <IN> Input type for the operator.
* @param <OUT> Output type for the operator.
*/
@Internal
public class AsyncBatchWaitOperator<IN, OUT> extends AbstractStreamOperator<OUT>
implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {

private static final long serialVersionUID = 1L;

/** The async batch function to invoke. */
private final AsyncBatchFunction<IN, OUT> asyncBatchFunction;

/** Maximum batch size before triggering async invocation. */
private final int maxBatchSize;

/** Buffer for incoming stream records. */
private transient List<IN> buffer;

/** Mailbox executor for processing async results on the main thread. */
private final transient MailboxExecutor mailboxExecutor;

/**
* Counter for in-flight async operations. It is only accessed from the task's mailbox thread
* (incremented in {@link #flushBuffer()}, decremented in mails executed by the mailbox
* executor), so no synchronization is needed.
*/
private transient int inFlightCount;

public AsyncBatchWaitOperator(
@Nonnull StreamOperatorParameters<OUT> parameters,
@Nonnull AsyncBatchFunction<IN, OUT> asyncBatchFunction,
int maxBatchSize,
@Nonnull MailboxExecutor mailboxExecutor) {
Preconditions.checkArgument(maxBatchSize > 0, "maxBatchSize must be greater than 0");
this.asyncBatchFunction = Preconditions.checkNotNull(asyncBatchFunction);
this.maxBatchSize = maxBatchSize;
this.mailboxExecutor = Preconditions.checkNotNull(mailboxExecutor);

// Setup the operator using parameters
setup(parameters.getContainingTask(), parameters.getStreamConfig(), parameters.getOutput());
}

@Override
public void open() throws Exception {
super.open();
this.buffer = new ArrayList<>(maxBatchSize);
this.inFlightCount = 0;
}

@Override
public void processElement(StreamRecord<IN> element) throws Exception {
buffer.add(element.getValue());

if (buffer.size() >= maxBatchSize) {
flushBuffer();
}
}

/** Flush the current buffer by invoking the async batch function. */
private void flushBuffer() throws Exception {
if (buffer.isEmpty()) {
return;
}

// Create a copy of the buffer and clear it for new incoming elements
List<IN> batch = new ArrayList<>(buffer);
buffer.clear();

// Increment in-flight counter
inFlightCount++;

// Create result handler for this batch
BatchResultHandler resultHandler = new BatchResultHandler();

// Invoke the async batch function
asyncBatchFunction.asyncInvokeBatch(batch, resultHandler);
}

@Override
public void endInput() throws Exception {
// Flush any remaining elements in the buffer
flushBuffer();

// Wait for all in-flight async operations to complete
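// yield() runs queued mails on the task thread (e.g. processResults or the decrement mail
// from completeExceptionally), so inFlightCount can eventually reach zero here.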
while (inFlightCount > 0) {
mailboxExecutor.yield();
}
}

@Override
public void close() throws Exception {
super.close();
}

/** Returns the current buffer size. Visible for testing. */
int getBufferSize() {
return buffer != null ? buffer.size() : 0;
}

/** A handler for the results of a batch async invocation. */
private class BatchResultHandler implements ResultFuture<OUT> {

/** Guard against multiple completions. */
private final AtomicBoolean completed = new AtomicBoolean(false);

@Override
public void complete(Collection<OUT> results) {
Preconditions.checkNotNull(
results, "Results must not be null, use empty collection to emit nothing");

if (!completed.compareAndSet(false, true)) {
return;
}

// Process results in the mailbox thread
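// complete() may be called from an arbitrary callback thread; emitting records and
// mutating inFlightCount must happen on the task (mailbox) thread.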
mailboxExecutor.execute(
() -> processResults(results), "AsyncBatchWaitOperator#processResults");
}

@Override
public void completeExceptionally(Throwable error) {
if (!completed.compareAndSet(false, true)) {
return;
}

// Signal failure through the containing task
getContainingTask()
.getEnvironment()
.failExternally(new Exception("Async batch operation failed.", error));

// Decrement in-flight counter in mailbox thread
mailboxExecutor.execute(
() -> inFlightCount--, "AsyncBatchWaitOperator#decrementInFlight");
}

@Override
public void complete(CollectionSupplier<OUT> supplier) {
Preconditions.checkNotNull(
supplier, "Supplier must not be null, return empty collection to emit nothing");

if (!completed.compareAndSet(false, true)) {
return;
}

mailboxExecutor.execute(
() -> {
try {
processResults(supplier.get());
} catch (Throwable t) {
getContainingTask()
.getEnvironment()
.failExternally(
new Exception("Async batch operation failed.", t));
inFlightCount--;
}
},
"AsyncBatchWaitOperator#processResultsFromSupplier");
}

private void processResults(Collection<OUT> results) {
// Emit all results downstream
for (OUT result : results) {
output.collect(new StreamRecord<>(result));
}
// Decrement in-flight counter
inFlightCount--;
}
}
}