Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
public class AsyncWorkflowPollTask
implements AsyncPoller.PollTaskAsync<WorkflowTask>, DisableNormalPolling {
private static final Logger log = LoggerFactory.getLogger(AsyncWorkflowPollTask.class);
private final TrackingSlotSupplier<?> slotSupplier;
private final TrackingSlotSupplier<WorkflowSlotInfo> slotSupplier;
private final WorkflowServiceStubs service;
private final Scope metricsScope;
private final Scope pollerMetricScope;
Expand Down Expand Up @@ -150,6 +150,7 @@ public CompletableFuture<WorkflowTask> poll(SlotPermit permit)
.inc(1);
return null;
}
slotSupplier.markSlotUsed(new WorkflowSlotInfo(r, pollRequest), permit);
pollerMetricScope
.counter(MetricsType.WORKFLOW_TASK_QUEUE_POLL_SUCCEED_COUNTER)
.inc(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;

import com.google.common.util.concurrent.Futures;
import com.google.protobuf.ByteString;
import com.uber.m3.tally.RootScopeBuilder;
import com.uber.m3.tally.Scope;
Expand All @@ -17,6 +18,8 @@
import io.temporal.serviceclient.WorkflowServiceStubs;
import io.temporal.worker.tuning.*;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -124,4 +127,64 @@ public void supplierIsCalledAppropriately() {
assertEquals(1, trackingSS.getUsedSlots().size());
}
}

@Test
public void asyncPollerSupplierIsCalledAppropriately() throws Exception {
WorkflowServiceStubs client = mock(WorkflowServiceStubs.class);
when(client.getServerCapabilities())
.thenReturn(() -> GetSystemInfoResponse.Capabilities.newBuilder().build());

WorkflowServiceGrpc.WorkflowServiceFutureStub futureStub =
mock(WorkflowServiceGrpc.WorkflowServiceFutureStub.class);
when(client.futureStub()).thenReturn(futureStub);
when(futureStub.withOption(any(), any())).thenReturn(futureStub);

SlotSupplier<WorkflowSlotInfo> mockSupplier = mock(SlotSupplier.class);
Scope metricsScope =
new RootScopeBuilder()
.reporter(reporter)
.reportEvery(com.uber.m3.util.Duration.ofMillis(1));
TrackingSlotSupplier<WorkflowSlotInfo> trackingSS =
new TrackingSlotSupplier<>(mockSupplier, metricsScope);

PollWorkflowTaskQueueResponse pollResponse =
PollWorkflowTaskQueueResponse.newBuilder()
.setTaskToken(ByteString.copyFrom("token", UTF_8))
.setWorkflowExecution(
WorkflowExecution.newBuilder().setWorkflowId(WORKFLOW_ID).setRunId(RUN_ID).build())
.setWorkflowType(WorkflowType.newBuilder().setName(WORKFLOW_TYPE).build())
.build();

if (throwOnPoll) {
when(futureStub.pollWorkflowTaskQueue(any()))
.thenReturn(Futures.immediateFailedFuture(new RuntimeException("Poll failed")));
} else {
when(futureStub.pollWorkflowTaskQueue(any()))
.thenReturn(Futures.immediateFuture(pollResponse));
}

AsyncWorkflowPollTask pollTask =
new AsyncWorkflowPollTask(
client,
"default",
TASK_QUEUE,
null,
"",
new WorkerVersioningOptions("", false, null),
trackingSS,
metricsScope,
() -> GetSystemInfoResponse.Capabilities.newBuilder().build());

SlotPermit permit = new SlotPermit();

CompletableFuture<WorkflowTask> future = pollTask.poll(permit);
if (throwOnPoll) {
assertThrows(ExecutionException.class, future::get);
assertEquals(0, trackingSS.getUsedSlots().size());
} else {
WorkflowTask task = future.get();
assertNotNull(task);
assertEquals(1, trackingSS.getUsedSlots().size());
}
}
}