Skip to content

Commit e14bb47

Browse files
Stream closure - cancel underlying stream
Signed-off-by: Danilo Najkov <danilo.najkov@databricks.com>
1 parent e9c5934 commit e14bb47

File tree

3 files changed

+69
-12
lines changed

3 files changed

+69
-12
lines changed

src/main/java/com/databricks/zerobus/ZerobusSdk.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import java.util.concurrent.Executors;
1111
import java.util.concurrent.ThreadFactory;
1212
import java.util.concurrent.atomic.AtomicInteger;
13+
import java.util.function.Supplier;
1314
import org.slf4j.Logger;
1415
import org.slf4j.LoggerFactory;
1516

@@ -160,7 +161,7 @@ public <RecordType extends Message> CompletableFuture<ZerobusStream<RecordType>>
160161
logger.debug("Creating stream for table: " + tableProperties.getTableName());
161162

162163
// Create a token supplier that generates a fresh token for each gRPC request
163-
java.util.function.Supplier<String> tokenSupplier =
164+
Supplier<String> tokenSupplier =
164165
() -> {
165166
try {
166167
return TokenFactory.getZerobusToken(
@@ -174,14 +175,15 @@ public <RecordType extends Message> CompletableFuture<ZerobusStream<RecordType>>
174175
}
175176
};
176177

177-
// Create gRPC stub once with token supplier - it will fetch fresh tokens as needed
178-
ZerobusGrpc.ZerobusStub stub =
179-
stubFactory.createStubWithTokenSupplier(
180-
serverEndpoint, tableProperties.getTableName(), tokenSupplier);
178+
// Create a stub supplier that generates a fresh stub with token supplier each time
179+
Supplier<ZerobusGrpc.ZerobusStub> stubSupplier =
180+
() ->
181+
stubFactory.createStubWithTokenSupplier(
182+
serverEndpoint, tableProperties.getTableName(), tokenSupplier);
181183

182184
ZerobusStream<RecordType> stream =
183185
new ZerobusStream<>(
184-
stub,
186+
stubSupplier,
185187
tableProperties,
186188
stubFactory,
187189
serverEndpoint,

src/main/java/com/databricks/zerobus/ZerobusStream.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.concurrent.ExecutorService;
1919
import java.util.concurrent.TimeoutException;
2020
import java.util.concurrent.atomic.AtomicBoolean;
21+
import java.util.function.Supplier;
2122
import org.slf4j.Logger;
2223
import org.slf4j.LoggerFactory;
2324

@@ -126,6 +127,7 @@ public class ZerobusStream<RecordType extends Message> {
126127
private static final int CREATE_STREAM_TIMEOUT_MS = 15000;
127128

128129
private ZerobusStub stub;
130+
private final Supplier<ZerobusStub> stubSupplier;
129131
final TableProperties<RecordType> tableProperties;
130132
private final ZerobusSdkStubFactory stubFactory;
131133
private final String serverEndpoint;
@@ -352,10 +354,10 @@ private CompletableFuture<Void> createStream() {
352354
() -> {
353355
CompletableFuture<Void> createStreamTry = new CompletableFuture<>();
354356

355-
// The stub was created once with a token supplier, so we don't recreate it here
356-
// The token supplier will provide a fresh token for each gRPC request
357+
// Get a fresh stub from the supplier
358+
stub = stubSupplier.get();
357359

358-
// Create the gRPC stream with the existing stub
360+
// Create the gRPC stream with the fresh stub
359361
streamCreatedEvent = Optional.of(new CompletableFuture<>());
360362
stream =
361363
Optional.of(
@@ -500,6 +502,9 @@ private void closeStream(boolean hardFailure, Optional<ZerobusException> excepti
500502
try {
501503
if (stream.isPresent()) {
502504
stream.get().onCompleted();
505+
if (hardFailure) {
506+
stream.get().cancel("Stream closed", null);
507+
}
503508
}
504509
} catch (Exception e) {
505510
// Ignore errors during stream cleanup - stream may already be closed
@@ -528,6 +533,7 @@ private void closeStream(boolean hardFailure, Optional<ZerobusException> excepti
528533
stream = Optional.empty();
529534
streamCreatedEvent = Optional.empty();
530535
streamId = Optional.empty();
536+
stub = null;
531537

532538
this.notifyAll();
533539
}
@@ -1073,6 +1079,7 @@ public void onNext(EphemeralStreamResponse response) {
10731079
String.format(
10741080
"Server will close the stream in %.3fms. Triggering stream recovery.",
10751081
durationMs));
1082+
streamFailureInfo.resetFailure(StreamFailureType.SERVER_CLOSED_STREAM);
10761083
handleStreamFailed(StreamFailureType.SERVER_CLOSED_STREAM, Optional.empty());
10771084
}
10781085
break;
@@ -1085,6 +1092,13 @@ public void onNext(EphemeralStreamResponse response) {
10851092

10861093
@Override
10871094
public void onError(Throwable t) {
1095+
synchronized (ZerobusStream.this) {
1096+
if (state == StreamState.CLOSED && !stream.isPresent()) {
1097+
logger.debug("Ignoring error on already closed stream: " + t.getMessage());
1098+
return;
1099+
}
1100+
}
1101+
10881102
Optional<Throwable> error = Optional.of(t);
10891103

10901104
if (t instanceof StatusRuntimeException) {
@@ -1336,7 +1350,7 @@ public void close() throws ZerobusException {
13361350
}
13371351

13381352
public ZerobusStream(
1339-
ZerobusStub stub,
1353+
Supplier<ZerobusStub> stubSupplier,
13401354
TableProperties<RecordType> tableProperties,
13411355
ZerobusSdkStubFactory stubFactory,
13421356
String serverEndpoint,
@@ -1347,7 +1361,8 @@ public ZerobusStream(
13471361
StreamConfigurationOptions options,
13481362
ExecutorService zerobusStreamExecutor,
13491363
ExecutorService ec) {
1350-
this.stub = stub;
1364+
this.stub = null;
1365+
this.stubSupplier = stubSupplier;
13511366
this.tableProperties = tableProperties;
13521367
this.stubFactory = stubFactory;
13531368
this.serverEndpoint = serverEndpoint;

src/test/java/com/databricks/zerobus/ZerobusSdkTest.java

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ public class ZerobusSdkTest {
3838
private ZerobusSdk zerobusSdk;
3939
private ZerobusSdkStubFactory zerobusSdkStubFactory;
4040
private org.mockito.MockedStatic<TokenFactory> tokenFactoryMock;
41+
private io.grpc.stub.ClientCallStreamObserver<EphemeralStreamRequest> spiedStream;
4142

4243
@BeforeEach
4344
public void setUp() {
@@ -76,7 +77,10 @@ public void setUp() {
7677
(StreamObserver<EphemeralStreamResponse>) invocation.getArgument(0);
7778

7879
mockedGrpcServer.initialize(ackSender);
79-
return mockedGrpcServer.getMessageReceiver();
80+
81+
// Spy on the message receiver to verify cancel() is called
82+
spiedStream = spy(mockedGrpcServer.getMessageReceiver());
83+
return spiedStream;
8084
})
8185
.when(zerobusStub)
8286
.ephemeralStream(any());
@@ -378,4 +382,40 @@ public void testCallbackExceptionHandling() throws Exception {
378382
stream.close();
379383
assertEquals(StreamState.CLOSED, stream.getState());
380384
}
385+
386+
@Test
387+
public void testGrpcStreamIsCancelledOnClose() throws Exception {
388+
// Test that the underlying gRPC stream is properly cancelled when stream.close() is called
389+
mockedGrpcServer.injectAckRecord(0);
390+
391+
TableProperties<CityPopulationTableRow> tableProperties =
392+
new TableProperties<>("test-table", CityPopulationTableRow.getDefaultInstance());
393+
StreamConfigurationOptions options =
394+
StreamConfigurationOptions.builder().setRecovery(false).build();
395+
396+
ZerobusStream<CityPopulationTableRow> stream =
397+
zerobusSdk.createStream(tableProperties, "client-id", "client-secret", options).get();
398+
399+
assertEquals(StreamState.OPENED, stream.getState());
400+
401+
// Ingest one record
402+
CompletableFuture<Void> writeCompleted =
403+
stream.ingestRecord(
404+
CityPopulationTableRow.newBuilder()
405+
.setCityName("test-city")
406+
.setPopulation(1000)
407+
.build());
408+
409+
writeCompleted.get(5, TimeUnit.SECONDS);
410+
411+
// Close the stream
412+
stream.close();
413+
assertEquals(StreamState.CLOSED, stream.getState());
414+
415+
// Verify that cancel() was called on the gRPC stream
416+
verify(spiedStream, times(1)).cancel(anyString(), any());
417+
418+
// Also verify onCompleted() was called
419+
verify(spiedStream, times(1)).onCompleted();
420+
}
381421
}

0 commit comments

Comments
 (0)