Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 141 additions & 2 deletions bitreq/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,138 @@
use std::collections::{hash_map, HashMap, VecDeque};
use std::sync::{Arc, Mutex};

#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
use crate::connection::tls_config::{TlsConfig, TlsConfigBuilder};
use crate::connection::AsyncConnection;
use crate::request::{OwnedConnectionParams as ConnectionKey, ParsedRequest};
use crate::{Error, Request, Response};

#[derive(Clone)]
pub(crate) struct ClientConfig {
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
pub(crate) tls: Option<TlsConfig>,
}

pub struct ClientBuilder {
capacity: usize,
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
tls_config: Option<TlsConfigBuilder>,
}

/// Builder for configuring a `Client` with custom settings.
///
/// # Example
///
/// ```no_run
/// # async fn example() -> Result<(), bitreq::Error> {
/// use bitreq::{Client, RequestExt};
///
/// let client = Client::builder().with_capacity(20).build()?;
///
/// let response = bitreq::get("https://example.com")
/// .send_async_with_client(&client)
/// .await?;
/// # Ok(())
/// # }
/// ```
impl ClientBuilder {
/// Creates a new `ClientBuilder` with a default pool capacity of 10.
pub fn new() -> Self {
Self {
capacity: 10,
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
tls_config: None,
}
}

/// Sets the maximum number of connections to keep in the pool.
pub fn with_capacity(mut self, capacity: usize) -> Self {
self.capacity = capacity;
self
}

/// Builds the `Client` with the configured settings.
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
pub fn build(self) -> Result<Client, Error> {
let build_config = if let Some(builder) = self.tls_config {
let tls_config = builder.build()?;
Some(ClientConfig { tls: Some(tls_config) })
} else {
None
};
let client_config = build_config.map(Arc::new);

Ok(Client {
r#async: Arc::new(Mutex::new(ClientImpl {
connections: HashMap::new(),
lru_order: VecDeque::new(),
capacity: self.capacity,
client_config,
})),
})
}

/// Builds the `Client` with the configured settings.
#[cfg(not(any(feature = "tokio-rustls", feature = "tokio-native-tls")))]
pub fn build(self) -> Result<Client, Error> {
Ok(Client {
r#async: Arc::new(Mutex::new(ClientImpl {
connections: HashMap::new(),
lru_order: VecDeque::new(),
capacity: self.capacity,
client_config: None,
})),
})
}

/// Adds a custom DER-encoded root certificate for TLS verification.
/// The certificate must be provided in DER format. This method accepts any type
/// that can be converted into a `Vec<u8>`.
/// The certificate is appended to the default trust store rather than replacing it.
/// The trust store used depends on the TLS backend: system certificates for native-tls,
/// Mozilla's root certificates(rustls-webpki) and/or system certificates(rustls-native-certs) for rustls.
///
/// # Example
///
/// ```no_run
/// # use bitreq::Client;
/// # async fn example() -> Result<(), bitreq::Error> {
/// let client = Client::builder()
/// .with_root_certificate(include_bytes!("../tests/test_cert.der"))?
/// .build()?;
/// # Ok(())
/// # }
/// ```
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
pub fn with_root_certificate<T: Into<Vec<u8>>>(mut self, cert_der: T) -> Result<Self, Error> {
let cert_der = cert_der.into();
if let Some(ref mut tls_config) = self.tls_config {
tls_config.append_certificate(cert_der)?;

return Ok(self);
}

self.tls_config = Some(TlsConfigBuilder::new(Some(cert_der))?);
Ok(self)
}

/// Disables default root certificates for TLS connections.
/// Returns [`Error::InvalidTlsConfig`] if TLS has not been configured.
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
pub fn disable_default_certificates(mut self) -> Result<Self, Error> {
match self.tls_config {
Some(ref mut tls_config) => tls_config.disable_default_certificates()?,
None => return Err(Error::InvalidTlsConfig),
};

Ok(self)
}
}

impl Default for ClientBuilder {
fn default() -> Self { Self::new() }
}

/// A client that caches connections for reuse.
///
/// The client maintains a pool of up to `capacity` connections, evicting
Expand All @@ -39,10 +167,11 @@ struct ClientImpl<T> {
connections: HashMap<ConnectionKey, Arc<T>>,
lru_order: VecDeque<ConnectionKey>,
capacity: usize,
client_config: Option<Arc<ClientConfig>>,
}

impl Client {
/// Creates a new `Client` with the specified connection cache capacity.
/// Creates a new `Client` with the specified connection pool capacity.
///
/// # Arguments
///
Expand All @@ -54,10 +183,14 @@ impl Client {
connections: HashMap::new(),
lru_order: VecDeque::new(),
capacity,
client_config: None,
})),
}
}

/// Create a builder for a client
pub fn builder() -> ClientBuilder { ClientBuilder::new() }

/// Sends a request asynchronously using a cached connection if available.
pub async fn send_async(&self, request: Request) -> Result<Response, Error> {
let parsed_request = ParsedRequest::new(request)?;
Expand All @@ -77,7 +210,13 @@ impl Client {
let conn = if let Some(conn) = conn_opt {
conn
} else {
let connection = AsyncConnection::new(key, parsed_request.timeout_at).await?;
let client_config = {
let state = self.r#async.lock().unwrap();
state.client_config.as_ref().map(Arc::clone)
};

let connection =
AsyncConnection::new(key, parsed_request.timeout_at, client_config).await?;
let connection = Arc::new(connection);

let mut state = self.r#async.lock().unwrap();
Expand Down
85 changes: 59 additions & 26 deletions bitreq/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,25 @@ use tokio::net::TcpStream as AsyncTcpStream;
#[cfg(feature = "async")]
use tokio::sync::Mutex as AsyncMutex;

#[cfg(feature = "async")]
use crate::client::ClientConfig;
use crate::request::{ConnectionParams, OwnedConnectionParams, ParsedRequest};
#[cfg(feature = "async")]
use crate::Response;
use crate::{Error, Method, ResponseLazy};

type UnsecuredStream = TcpStream;

#[cfg(feature = "rustls")]
#[cfg(any(feature = "rustls", feature = "native-tls"))]
mod rustls_stream;
#[cfg(feature = "rustls")]
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
pub(crate) mod tls_config;
#[cfg(any(feature = "rustls", feature = "native-tls"))]
type SecuredStream = rustls_stream::SecuredStream;

pub(crate) enum HttpStream {
Unsecured(UnsecuredStream, Option<Instant>),
#[cfg(feature = "rustls")]
#[cfg(any(feature = "rustls", feature = "native-tls"))]
Secured(Box<SecuredStream>, Option<Instant>),
#[cfg(feature = "async")]
Buffer(std::io::Cursor<Vec<u8>>),
Expand Down Expand Up @@ -81,7 +85,7 @@ impl Read for HttpStream {
timeout(inner, *timeout_at)?;
inner.read(buf)
}
#[cfg(feature = "rustls")]
#[cfg(any(feature = "rustls", feature = "native-tls"))]
HttpStream::Secured(inner, timeout_at) => {
timeout(inner.get_ref(), *timeout_at)?;
inner.read(buf)
Expand Down Expand Up @@ -111,7 +115,7 @@ impl Write for HttpStream {
set_socket_write_timeout(inner, *timeout_at)?;
inner.write(buf)
}
#[cfg(feature = "rustls")]
#[cfg(any(feature = "rustls", feature = "native-tls"))]
HttpStream::Secured(inner, timeout_at) => {
set_socket_write_timeout(inner.get_ref(), *timeout_at)?;
inner.write(buf)
Expand All @@ -137,7 +141,7 @@ impl Write for HttpStream {
set_socket_write_timeout(inner, *timeout_at)?;
inner.flush()
}
#[cfg(feature = "rustls")]
#[cfg(any(feature = "rustls", feature = "native-tls"))]
HttpStream::Secured(inner, timeout_at) => {
set_socket_write_timeout(inner.get_ref(), *timeout_at)?;
inner.flush()
Expand All @@ -158,13 +162,13 @@ impl Write for HttpStream {
}
}

#[cfg(feature = "tokio-rustls")]
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
type AsyncSecuredStream = rustls_stream::AsyncSecuredStream;

#[cfg(feature = "async")]
pub(crate) enum AsyncHttpStream {
Unsecured(AsyncTcpStream),
#[cfg(feature = "tokio-rustls")]
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
Secured(Box<AsyncSecuredStream>),
}

Expand All @@ -177,7 +181,7 @@ impl AsyncRead for AsyncHttpStream {
) -> Poll<io::Result<()>> {
match &mut *self {
AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_read(cx, buf),
#[cfg(feature = "tokio-rustls")]
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_read(cx, buf),
}
}
Expand All @@ -192,23 +196,23 @@ impl AsyncWrite for AsyncHttpStream {
) -> Poll<io::Result<usize>> {
match &mut *self {
AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_write(cx, buf),
#[cfg(feature = "tokio-rustls")]
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_write(cx, buf),
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_flush(cx),
#[cfg(feature = "tokio-rustls")]
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_flush(cx),
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_shutdown(cx),
#[cfg(feature = "tokio-rustls")]
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_shutdown(cx),
}
}
Expand Down Expand Up @@ -238,6 +242,7 @@ struct AsyncConnectionState {
/// Defaults to 60 seconds after open to align with nginx's default timeout of 75 seconds, but
/// can be overridden by the `Keep-Alive` header.
socket_new_requests_timeout: Mutex<Instant>,
client_config: Option<Arc<ClientConfig>>,
}

#[cfg(feature = "async")]
Expand Down Expand Up @@ -266,15 +271,15 @@ impl AsyncConnection {
pub(crate) async fn new(
params: ConnectionParams<'_>,
timeout_at: Option<Instant>,
client_config: Option<Arc<ClientConfig>>,
) -> Result<AsyncConnection, Error> {
let client_config_ref = &client_config;

let future = async move {
let socket = Self::connect(params).await?;

if params.https {
#[cfg(not(feature = "tokio-rustls"))]
return Err(Error::HttpsFeatureNotEnabled);
#[cfg(feature = "tokio-rustls")]
rustls_stream::wrap_async_stream(socket, params.host).await
Self::wrap_async_stream(socket, params.host, client_config_ref).await
} else {
Ok(AsyncHttpStream::Unsecured(socket))
}
Expand All @@ -295,9 +300,35 @@ impl AsyncConnection {
readable_request_id: AtomicUsize::new(0),
min_dropped_reader_id: AtomicUsize::new(usize::MAX),
socket_new_requests_timeout: Mutex::new(Instant::now() + Duration::from_secs(60)),
client_config,
}))))
}

/// Call the correct wrapper function depending on whether client_configs are present
#[cfg(any(feature = "tokio-rustls", feature = "tokio-native-tls"))]
async fn wrap_async_stream(
socket: AsyncTcpStream,
host: &str,
client_config: &Option<Arc<ClientConfig>>,
) -> Result<AsyncHttpStream, Error> {
if let Some(client_config) = client_config {
let tls_config = client_config.tls.as_ref().unwrap().clone();
rustls_stream::wrap_async_stream_with_configs(socket, host, tls_config).await
} else {
rustls_stream::wrap_async_stream(socket, host).await
}
}

/// Error treatment function, should not be called under normal circustances
#[cfg(not(any(feature = "tokio-rustls", feature = "tokio-native-tls")))]
async fn wrap_async_stream(
_socket: AsyncTcpStream,
_host: &str,
_client_config: &Option<Arc<ClientConfig>>,
) -> Result<AsyncHttpStream, Error> {
Err(Error::HttpsFeatureNotEnabled)
}

async fn tcp_connect(host: &str, port: u16) -> Result<AsyncTcpStream, Error> {
#[cfg(feature = "log")]
log::trace!("Looking up host {host}");
Expand Down Expand Up @@ -446,9 +477,13 @@ impl AsyncConnection {
retry_new_connection!(_internal);
};
(_internal) => {
let new_connection =
AsyncConnection::new(request.connection_params(), request.timeout_at)
.await?;
let config = conn.client_config.as_ref().map(Arc::clone);
let new_connection = AsyncConnection::new(
request.connection_params(),
request.timeout_at,
config,
)
.await?;
*self.0.lock().unwrap() = Arc::clone(&*new_connection.0.lock().unwrap());
core::mem::drop(read);
// Note that this cannot recurse infinitely as we'll always be able to send at
Expand Down Expand Up @@ -653,13 +688,10 @@ impl Connection {
let socket = Self::connect(params, timeout_at)?;

let stream = if params.https {
#[cfg(not(feature = "rustls"))]
#[cfg(not(any(feature = "rustls", feature = "native-tls")))]
return Err(Error::HttpsFeatureNotEnabled);
#[cfg(feature = "rustls")]
{
let tls = rustls_stream::wrap_stream(socket, params.host)?;
HttpStream::Secured(Box::new(tls), timeout_at)
}
#[cfg(any(feature = "rustls", feature = "native-tls"))]
rustls_stream::wrap_stream(socket, params.host)?
} else {
HttpStream::create_unsecured(socket, timeout_at)
};
Expand Down Expand Up @@ -806,7 +838,8 @@ async fn async_handle_redirects(
let new_connection;
if needs_new_connection {
new_connection =
AsyncConnection::new(request.connection_params(), request.timeout_at).await?;
AsyncConnection::new(request.connection_params(), request.timeout_at, None)
.await?;
connection = &new_connection;
}
connection.send(request).await
Expand Down
Loading
Loading