Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ assert_eq!(

# Testing

You can set the address and port of the test instance using `TOKIO_ZOOKEEPER_TEST_HOST` and `TOKIO_ZOOKEEPER_TEST_PORT` respectively.

The the default is `127.0.0.1:2181`.

1. Start a Zookeeper instance, e.g. using `docker run -p 2181:2181 zookeeper`
2. Run `cargo test`

Expand Down
84 changes: 73 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,18 @@ impl ZooKeeperBuilder {
let stream = tokio::net::TcpStream::connect(addr)
.await
.whatever_context("connect failed")?;
Ok((self.handshake(*addr, stream, tx).await?, rx))
Ok((self.handshake(*addr, stream, Some(tx)).await?, rx))
}

/// Connect to a ZooKeeper server instance at the given address, but without returning a
/// watcher stream.
///
/// See [`ZooKeeperBuilder::connect_without_watcher`].
pub async fn connect_without_watcher(self, addr: &SocketAddr) -> Result<ZooKeeper, Error> {
let stream = tokio::net::TcpStream::connect(addr)
.await
.whatever_context("connect failed")?;
self.handshake(*addr, stream, None).await
}

/// Set the ZooKeeper [session expiry
Expand All @@ -294,7 +305,7 @@ impl ZooKeeperBuilder {
self,
addr: SocketAddr,
stream: tokio::net::TcpStream,
default_watcher: futures::channel::mpsc::UnboundedSender<WatchedEvent>,
default_watcher: Option<futures::channel::mpsc::UnboundedSender<WatchedEvent>>,
) -> Result<ZooKeeper, Error> {
let request = proto::Request::Connect {
protocol_version: 0,
Expand Down Expand Up @@ -327,6 +338,16 @@ impl ZooKeeper {
ZooKeeperBuilder::default().connect(addr).await
}

/// Connect to a ZooKeeper server instance at the given address with default parameters,
/// but without returning a watcher stream.
///
/// See [`ZooKeeperBuilder::connect_without_watcher`].
pub async fn connect_without_watcher(addr: &SocketAddr) -> Result<Self, Error> {
ZooKeeperBuilder::default()
.connect_without_watcher(addr)
.await
}

/// Create a node with the given `path` with `data` as its contents.
///
/// The `mode` argument specifies additional options for the newly created node.
Expand Down Expand Up @@ -754,6 +775,8 @@ mod tests {
use super::*;

use futures::StreamExt;
use std::env;
use std::net::ToSocketAddrs;
use tracing::Level;

fn init_tracing_subscriber() {
Expand All @@ -762,12 +785,30 @@ mod tests {
.try_init();
}

// Use environment variables to override default connection otherwise
// default to localhost:127.0.0.1:2181
fn get_test_zookeeper_addr() -> SocketAddr {
let host =
env::var("TOKIO_ZOOKEEPER_TEST_HOST").unwrap_or_else(|_| "127.0.0.1".to_string());

let port: u16 = env::var("TOKIO_ZOOKEEPER_TEST_PORT")
.unwrap_or_else(|_| "2181".to_string())
.parse()
.expect("TOKIO_ZOOKEEPER_TEST_PORT must be a valid u16");

format!("{host}:{port}")
.to_socket_addrs()
.expect("Invalid host:port")
.next()
.expect("Host resolved but returned no addresses")
}

#[tokio::test]
async fn it_works() {
init_tracing_subscriber();
let builder = ZooKeeperBuilder::default();

let connect_addr = "127.0.0.1:2181".parse().unwrap();
let connect_addr = get_test_zookeeper_addr();
let (zk, w) = builder.connect(&connect_addr).await.unwrap();
let (exists_w, stat) = zk.with_watcher().exists("/foo").await.unwrap();
assert_eq!(stat, None);
Expand Down Expand Up @@ -873,7 +914,7 @@ mod tests {

#[tokio::test]
async fn example() {
let connect_addr = "127.0.0.1:2181".parse().unwrap();
let connect_addr = get_test_zookeeper_addr();
let (zk, default_watcher) = ZooKeeper::connect(&connect_addr).await.unwrap();

// let's first check if /example exists. the .watch() causes us to be notified
Expand Down Expand Up @@ -958,10 +999,9 @@ mod tests {
async fn acl_test() {
init_tracing_subscriber();
let builder = ZooKeeperBuilder::default();
let connect_addr = get_test_zookeeper_addr();

let (zk, _) = (builder.connect(&"127.0.0.1:2181".parse().unwrap()))
.await
.unwrap();
let (zk, _) = (builder.connect(&connect_addr)).await.unwrap();
let _ = zk
.create(
"/acl_test",
Expand Down Expand Up @@ -1021,11 +1061,9 @@ mod tests {
}
Result::<_, Error>::Ok(res)
}
let connect_addr = get_test_zookeeper_addr();

let (zk, _) = builder
.connect(&"127.0.0.1:2181".parse().unwrap())
.await
.unwrap();
let (zk, _) = builder.connect(&connect_addr).await.unwrap();

let res = zk
.multi()
Expand Down Expand Up @@ -1133,4 +1171,28 @@ mod tests {

drop(zk); // make Packetizer idle
}

#[tokio::test]
async fn connect_without_watcher_test() {
init_tracing_subscriber();
let connect_addr = get_test_zookeeper_addr();

let zk = ZooKeeper::connect_without_watcher(&connect_addr).await.unwrap();

let path = zk
.create(
"/no_watcher_test",
&b"Hello world"[..],
Acl::open_unsafe(),
CreateMode::Persistent,
)
.await
.unwrap();
assert_eq!(path.as_deref(), Ok("/no_watcher_test"));

let res = zk.delete("/no_watcher_test", None).await.unwrap();
assert_eq!(res, Ok(()));

drop(zk); // make Packetizer idle
}
}
11 changes: 7 additions & 4 deletions src/proto/active_packetizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ where
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context,
default_watcher: &mut mpsc::UnboundedSender<WatchedEvent>,
default_watcher: &mut Option<mpsc::UnboundedSender<WatchedEvent>>,
) -> Poll<Result<(), Error>>
where
S: AsyncRead,
Expand Down Expand Up @@ -362,8 +362,11 @@ where
.expect("tried to remove watcher that didn't exist");
}

// NOTE: ignoring error, because the user may not care about events
let _ = default_watcher.unbounded_send(e);
// Handle optional watcher stream for connect_without_watcher
if let Some(w) = &default_watcher {
// NOTE: ignoring error, because the user may not care about events
let _ = w.unbounded_send(e);
}
} else if xid == -2 {
// response to ping -- empty response
trace!("got response to heartbeat");
Expand Down Expand Up @@ -445,7 +448,7 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context,
exiting: bool,
default_watcher: &mut mpsc::UnboundedSender<WatchedEvent>,
default_watcher: &mut Option<mpsc::UnboundedSender<WatchedEvent>>,
) -> Poll<Result<(), Error>> {
let r = self.as_mut().poll_read(cx, default_watcher)?;

Expand Down
6 changes: 3 additions & 3 deletions src/proto/packetizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ where
state: PacketizerState<S>,

/// Watcher to send watch events to.
default_watcher: mpsc::UnboundedSender<WatchedEvent>,
default_watcher: Option<mpsc::UnboundedSender<WatchedEvent>>,

/// Incoming requests
rx: mpsc::UnboundedReceiver<(Request, oneshot::Sender<Result<Response, ZkError>>)>,
Expand All @@ -54,7 +54,7 @@ where
pub(crate) fn new(
addr: S::Addr,
stream: S,
default_watcher: mpsc::UnboundedSender<WatchedEvent>,
default_watcher: Option<mpsc::UnboundedSender<WatchedEvent>>,
) -> Enqueuer
where
S: Send + 'static + AsyncRead + AsyncWrite,
Expand Down Expand Up @@ -98,7 +98,7 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context,
exiting: bool,
default_watcher: &mut mpsc::UnboundedSender<WatchedEvent>,
default_watcher: &mut Option<mpsc::UnboundedSender<WatchedEvent>>,
) -> Poll<Result<(), Error>> {
let ap = match self.as_mut().project() {
PacketizerStateProj::Connected(ref mut ap) => {
Expand Down