Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/main/java/com/databricks/zerobus/ZerobusSdk.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -160,7 +161,7 @@ public <RecordType extends Message> CompletableFuture<ZerobusStream<RecordType>>
logger.debug("Creating stream for table: " + tableProperties.getTableName());

// Create a token supplier that generates a fresh token for each gRPC request
java.util.function.Supplier<String> tokenSupplier =
Supplier<String> tokenSupplier =
() -> {
try {
return TokenFactory.getZerobusToken(
Expand All @@ -174,14 +175,15 @@ public <RecordType extends Message> CompletableFuture<ZerobusStream<RecordType>>
}
};

// Create gRPC stub once with token supplier - it will fetch fresh tokens as needed
ZerobusGrpc.ZerobusStub stub =
stubFactory.createStubWithTokenSupplier(
serverEndpoint, tableProperties.getTableName(), tokenSupplier);
// Create a stub supplier that generates a fresh stub with token supplier each time
Supplier<ZerobusGrpc.ZerobusStub> stubSupplier =
() ->
stubFactory.createStubWithTokenSupplier(
serverEndpoint, tableProperties.getTableName(), tokenSupplier);

ZerobusStream<RecordType> stream =
new ZerobusStream<>(
stub,
stubSupplier,
tableProperties,
stubFactory,
serverEndpoint,
Expand Down
25 changes: 20 additions & 5 deletions src/main/java/com/databricks/zerobus/ZerobusStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -126,6 +127,7 @@ public class ZerobusStream<RecordType extends Message> {
private static final int CREATE_STREAM_TIMEOUT_MS = 15000;

private ZerobusStub stub;
private final Supplier<ZerobusStub> stubSupplier;
final TableProperties<RecordType> tableProperties;
private final ZerobusSdkStubFactory stubFactory;
private final String serverEndpoint;
Expand Down Expand Up @@ -352,10 +354,10 @@ private CompletableFuture<Void> createStream() {
() -> {
CompletableFuture<Void> createStreamTry = new CompletableFuture<>();

// The stub was created once with a token supplier, so we don't recreate it here
// The token supplier will provide a fresh token for each gRPC request
// Get a fresh stub from the supplier
stub = stubSupplier.get();

// Create the gRPC stream with the existing stub
// Create the gRPC stream with the fresh stub
streamCreatedEvent = Optional.of(new CompletableFuture<>());
stream =
Optional.of(
Expand Down Expand Up @@ -500,6 +502,9 @@ private void closeStream(boolean hardFailure, Optional<ZerobusException> excepti
try {
if (stream.isPresent()) {
stream.get().onCompleted();
if (hardFailure) {
stream.get().cancel("Stream closed", null);
}
}
} catch (Exception e) {
// Ignore errors during stream cleanup - stream may already be closed
Expand Down Expand Up @@ -528,6 +533,7 @@ private void closeStream(boolean hardFailure, Optional<ZerobusException> excepti
stream = Optional.empty();
streamCreatedEvent = Optional.empty();
streamId = Optional.empty();
stub = null;

this.notifyAll();
}
Expand Down Expand Up @@ -1073,6 +1079,7 @@ public void onNext(EphemeralStreamResponse response) {
String.format(
"Server will close the stream in %.3fms. Triggering stream recovery.",
durationMs));
streamFailureInfo.resetFailure(StreamFailureType.SERVER_CLOSED_STREAM);
handleStreamFailed(StreamFailureType.SERVER_CLOSED_STREAM, Optional.empty());
}
break;
Expand All @@ -1085,6 +1092,13 @@ public void onNext(EphemeralStreamResponse response) {

@Override
public void onError(Throwable t) {
synchronized (ZerobusStream.this) {
if (state == StreamState.CLOSED && !stream.isPresent()) {
logger.debug("Ignoring error on already closed stream: " + t.getMessage());
return;
}
}

Optional<Throwable> error = Optional.of(t);

if (t instanceof StatusRuntimeException) {
Expand Down Expand Up @@ -1336,7 +1350,7 @@ public void close() throws ZerobusException {
}

public ZerobusStream(
ZerobusStub stub,
Supplier<ZerobusStub> stubSupplier,
TableProperties<RecordType> tableProperties,
ZerobusSdkStubFactory stubFactory,
String serverEndpoint,
Expand All @@ -1347,7 +1361,8 @@ public ZerobusStream(
StreamConfigurationOptions options,
ExecutorService zerobusStreamExecutor,
ExecutorService ec) {
this.stub = stub;
this.stub = null;
this.stubSupplier = stubSupplier;
this.tableProperties = tableProperties;
this.stubFactory = stubFactory;
this.serverEndpoint = serverEndpoint;
Expand Down
42 changes: 41 additions & 1 deletion src/test/java/com/databricks/zerobus/ZerobusSdkTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public class ZerobusSdkTest {
private ZerobusSdk zerobusSdk;
private ZerobusSdkStubFactory zerobusSdkStubFactory;
private org.mockito.MockedStatic<TokenFactory> tokenFactoryMock;
private io.grpc.stub.ClientCallStreamObserver<EphemeralStreamRequest> spiedStream;

@BeforeEach
public void setUp() {
Expand Down Expand Up @@ -76,7 +77,10 @@ public void setUp() {
(StreamObserver<EphemeralStreamResponse>) invocation.getArgument(0);

mockedGrpcServer.initialize(ackSender);
return mockedGrpcServer.getMessageReceiver();

// Spy on the message receiver to verify cancel() is called
spiedStream = spy(mockedGrpcServer.getMessageReceiver());
return spiedStream;
})
.when(zerobusStub)
.ephemeralStream(any());
Expand Down Expand Up @@ -378,4 +382,40 @@ public void testCallbackExceptionHandling() throws Exception {
stream.close();
assertEquals(StreamState.CLOSED, stream.getState());
}

@Test
public void testGrpcStreamIsCancelledOnClose() throws Exception {
// Test that the underlying gRPC stream is properly cancelled when stream.close() is called
mockedGrpcServer.injectAckRecord(0);

TableProperties<CityPopulationTableRow> tableProperties =
new TableProperties<>("test-table", CityPopulationTableRow.getDefaultInstance());
StreamConfigurationOptions options =
StreamConfigurationOptions.builder().setRecovery(false).build();

ZerobusStream<CityPopulationTableRow> stream =
zerobusSdk.createStream(tableProperties, "client-id", "client-secret", options).get();

assertEquals(StreamState.OPENED, stream.getState());

// Ingest one record
CompletableFuture<Void> writeCompleted =
stream.ingestRecord(
CityPopulationTableRow.newBuilder()
.setCityName("test-city")
.setPopulation(1000)
.build());

writeCompleted.get(5, TimeUnit.SECONDS);

// Close the stream
stream.close();
assertEquals(StreamState.CLOSED, stream.getState());

// Verify that cancel() was called on the gRPC stream
verify(spiedStream, times(1)).cancel(anyString(), any());

// Also verify onCompleted() was called
verify(spiedStream, times(1)).onCompleted();
}
}