Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT

private static final String RECOVERY_MARKER_STATE_NAME = "recoveryMarker";
private static final String MESSAGE_SEQUENCE_NUMBER_STATE_NAME = "messageSequenceNumber";
private static final String LAST_COMPLETED_SEQUENCE_NUMBER_STATE_NAME =
"lastCompletedSequenceNumber";
private static final String PENDING_INPUT_EVENT_STATE_NAME = "pendingInputEvents";

private final AgentPlan agentPlan;
Expand Down Expand Up @@ -191,6 +193,7 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT

private transient ActionStateStore actionStateStore;
private transient ValueState<Long> sequenceNumberKState;
private transient ValueState<Long> lastCompletedSequenceNumberKState;
private transient ListState<Object> recoveryMarkerOpState;
private transient Map<Long, Map<Object, Long>> checkpointIdToSeqNums;

Expand Down Expand Up @@ -288,6 +291,11 @@ public void open() throws Exception {
.getState(
new ValueStateDescriptor<>(
MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class));
lastCompletedSequenceNumberKState =
getRuntimeContext()
.getState(
new ValueStateDescriptor<>(
LAST_COMPLETED_SEQUENCE_NUMBER_STATE_NAME, Long.class));

// init agent processing related state
actionTasksKState =
Expand Down Expand Up @@ -578,11 +586,11 @@ private void processActionTaskForKey(Object key) throws Exception {
if (currentInputEventFinished) {
// Clean up sensory memory when a single run finished.
actionTask.getRunnerContext().clearSensoryMemory();
lastCompletedSequenceNumberKState.update(sequenceNumber);

// Once all sub-events and actions related to the current InputEvent are completed,
// we can proceed to process the next InputEvent.
int removedCount = removeFromListState(currentProcessingKeysOpState, key);
maybePruneState(key, sequenceNumber);
checkState(
removedCount == 1,
"Current processing key count for key "
Expand Down Expand Up @@ -789,8 +797,14 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
.applyToAllKeys(
VoidNamespace.INSTANCE,
VoidNamespaceSerializer.INSTANCE,
new ValueStateDescriptor<>(MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class),
(key, state) -> keyToSeqNum.put(key, state.value()));
new ValueStateDescriptor<>(
LAST_COMPLETED_SEQUENCE_NUMBER_STATE_NAME, Long.class),
(key, state) -> {
Long completedSequenceNumber = state.value();
if (completedSequenceNumber != null) {
keyToSeqNum.put(key, completedSequenceNumber);
}
});
checkpointIdToSeqNums.put(context.getCheckpointId(), keyToSeqNum);

super.snapshotState(context);
Expand Down Expand Up @@ -1067,12 +1081,6 @@ public void persist(
}
}

/**
 * Asks the action state store to prune state for the given key and sequence number.
 *
 * <p>This is a no-op when no action state store is set on this operator.
 *
 * @param key the key whose state should be pruned
 * @param sequenceNum the sequence number handed to the store's {@code pruneState}
 * @throws Exception if the store's {@code pruneState} call fails
 */
private void maybePruneState(Object key, long sequenceNum) throws Exception {
    if (actionStateStore == null) {
        return;
    }
    actionStateStore.pruneState(key, sequenceNum);
}

private void processEligibleWatermarks() throws Exception {
Watermark mark = keySegmentQueue.popOldestWatermark();
while (mark != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,62 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(false)),
}
}

/**
 * Checks when {@code pruneState} is invoked on the action state store: not once the input has
 * finished, not once the snapshot has been taken, but only after the checkpoint is reported
 * complete.
 */
@Test
void testDoesNotPruneBeforeCheckpointComplete() throws Exception {
    AgentPlan plan = TestAgent.getAgentPlan(false);
    RecordingActionStateStore recordingStore = new RecordingActionStateStore();
    KeySelector<Long, Long> keyByValue = value -> value;

    try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> harness =
            new KeyedOneInputStreamOperatorTestHarness<>(
                    new ActionExecutionOperatorFactory<>(plan, true, recordingStore),
                    keyByValue,
                    TypeInformation.of(Long.class))) {
        harness.open();
        ActionExecutionOperator<Long, Object> actionOperator =
                (ActionExecutionOperator<Long, Object>) harness.getOperator();

        // Run a single input to completion; finishing the input alone must not prune.
        harness.processElement(new StreamRecord<>(5L));
        actionOperator.waitInFlightEventsFinished();
        assertThat(recordingStore.getPrunedSeqNums()).isEmpty();

        // Taking the snapshot must not prune either.
        harness.snapshot(1L, 1L);
        assertThat(recordingStore.getPrunedSeqNums()).isEmpty();

        // After the checkpoint-complete notification, exactly one prune call with
        // sequence number 0 is expected.
        harness.notifyOfCompletedCheckpoint(1L);
        assertThat(recordingStore.getPrunedSeqNums()).containsExactly(0L);
    }
}

/**
 * Checks that, with one finished input and a second input for the same key still leaving a
 * mail in the task mailbox, completing a checkpoint results in a single prune call with
 * sequence number 0.
 */
@Test
void testDoesNotPruneSeqsInFlight() throws Exception {
    AgentPlan plan = TestAgent.getAgentPlan(false);
    RecordingActionStateStore recordingStore = new RecordingActionStateStore();
    KeySelector<Long, Long> keyByValue = value -> value;

    try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> harness =
            new KeyedOneInputStreamOperatorTestHarness<>(
                    new ActionExecutionOperatorFactory<>(plan, true, recordingStore),
                    keyByValue,
                    TypeInformation.of(Long.class))) {
        harness.open();
        ActionExecutionOperator<Long, Object> actionOperator =
                (ActionExecutionOperator<Long, Object>) harness.getOperator();

        // First input: run it to completion, then forget any prune calls recorded so far.
        harness.processElement(new StreamRecord<>(5L));
        actionOperator.waitInFlightEventsFinished();
        recordingStore.clearPruneCalls();

        // Second input for the same key: submitted without waiting for it to finish,
        // which leaves exactly one mail in the task mailbox.
        harness.processElement(new StreamRecord<>(5L));
        assertThat(harness.getTaskMailbox().size()).isEqualTo(1);

        harness.snapshot(1L, 1L);
        harness.notifyOfCompletedCheckpoint(1L);

        // Only sequence number 0 may be pruned by the completed checkpoint.
        assertThat(recordingStore.getPrunedSeqNums()).containsExactly(0L);
    }
}

@Test
void testEventLogBaseDirFromAgentConfig() throws Exception {
String baseLogDir = "/tmp/flink-agents-test";
Expand Down Expand Up @@ -461,7 +517,7 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(false)),
}

@Test
void testActionStateStoreCleanupAfterOutputEvent() throws Exception {
void testActionStateStoreCleanupAfterCheckpointComplete() throws Exception {
AgentPlan agentPlanWithStateStore = TestAgent.getAgentPlan(false);

try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness =
Expand Down Expand Up @@ -496,10 +552,66 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(true)),
actionStateStoreField.setAccessible(true);
InMemoryActionStateStore actionStateStore =
(InMemoryActionStateStore) actionStateStoreField.get(operator);
assertThat(actionStateStore.getKeyedActionStates()).isNotEmpty();

testHarness.snapshot(1L, 1L);
testHarness.notifyOfCompletedCheckpoint(1L);

assertThat(actionStateStore.getKeyedActionStates()).isEmpty();
}
}

/**
 * Restores the operator from a snapshot taken before an input was processed, replays that
 * input against the same action state store instance, and checks that the output is produced
 * while {@code TestAgent.DURABLE_CALL_COUNTER} stays at 1.
 */
@Test
void testEarlierCheckpointReplayKeepsDurableState() throws Exception {
    AgentPlan durablePlan = TestAgent.getDurableSyncAgentPlan();
    InMemoryActionStateStore sharedStore = new InMemoryActionStateStore(true);
    KeySelector<Long, Long> keyByValue = value -> value;
    OperatorSubtaskState earlySnapshot;

    TestAgent.DURABLE_CALL_COUNTER.set(0);

    // Phase 1: snapshot first, then process the input so the snapshot predates it.
    try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> firstRun =
            new KeyedOneInputStreamOperatorTestHarness<>(
                    new ActionExecutionOperatorFactory<>(durablePlan, true, sharedStore),
                    keyByValue,
                    TypeInformation.of(Long.class))) {
        firstRun.open();
        ActionExecutionOperator<Long, Object> firstOperator =
                (ActionExecutionOperator<Long, Object>) firstRun.getOperator();

        // This snapshot stands in for the checkpoint a failure recovery would restore from.
        earlySnapshot = firstRun.snapshot(1L, 1L);

        firstRun.processElement(new StreamRecord<>(7L));
        firstOperator.waitInFlightEventsFinished();

        assertThat(TestAgent.DURABLE_CALL_COUNTER.get()).isEqualTo(1);
        assertThat(sharedStore.getKeyedActionStates()).isNotEmpty();
    }

    // Phase 2: a fresh harness restored from the early snapshot, reusing the same store.
    try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> restoredRun =
            new KeyedOneInputStreamOperatorTestHarness<>(
                    new ActionExecutionOperatorFactory<>(durablePlan, true, sharedStore),
                    keyByValue,
                    TypeInformation.of(Long.class))) {
        restoredRun.initializeState(earlySnapshot);
        restoredRun.open();
        ActionExecutionOperator<Long, Object> restoredOperator =
                (ActionExecutionOperator<Long, Object>) restoredRun.getOperator();

        // Feed the identical input again, as a replay after recovery would.
        restoredRun.processElement(new StreamRecord<>(7L));
        restoredOperator.waitInFlightEventsFinished();

        List<StreamRecord<Object>> emitted =
                (List<StreamRecord<Object>>) restoredRun.getRecordOutput();
        assertThat(emitted).hasSize(1);
        assertThat(emitted.get(0).getValue()).isEqualTo(21L);

        // The counter must still read 1, i.e. it was not incremented during the replay.
        assertThat(TestAgent.DURABLE_CALL_COUNTER.get())
                .as("Durable supplier should not be re-executed during replay")
                .isEqualTo(1);
    }
}

@Test
void testActionStateStoreReplayIncurNoFunctionCall() throws Exception {
AgentPlan agentPlanWithStateStore = TestAgent.getAgentPlan(false);
Expand Down Expand Up @@ -1524,6 +1636,27 @@ public static AgentPlan getDurableAsyncExceptionAgentPlan() {
}
}

/**
 * {@link InMemoryActionStateStore} variant for tests that records every sequence number passed
 * to {@link #pruneState(Object, long)} instead of delegating to the superclass implementation.
 */
private static class RecordingActionStateStore extends InMemoryActionStateStore {
    /** Sequence numbers received by {@code pruneState}, in call order. */
    private final List<Long> recordedSeqNums = new java.util.ArrayList<>();

    private RecordingActionStateStore() {
        super(false);
    }

    /** Returns the live list of recorded sequence numbers (not a copy). */
    private List<Long> getPrunedSeqNums() {
        return recordedSeqNums;
    }

    /** Discards every sequence number recorded so far. */
    private void clearPruneCalls() {
        recordedSeqNums.clear();
    }

    /**
     * Records {@code seqNum}; the key is ignored and the superclass implementation is not
     * invoked.
     */
    @Override
    public void pruneState(Object key, long seqNum) {
        recordedSeqNums.add(seqNum);
    }
}

private static void assertMailboxSizeAndRun(TaskMailbox mailbox, int expectedSize)
throws Exception {
assertThat(mailbox.size()).isEqualTo(expectedSize);
Expand Down
Loading