Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sqlx-core/src/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ mod socket;
pub mod tls;

pub use socket::{
connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket, WriteBuffer,
connect_socket, connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket,
WriteBuffer,
};
14 changes: 14 additions & 0 deletions sqlx-core/src/net/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,20 @@ pub async fn connect_tcp<Ws: WithSocket>(
}
}

/// Connect using a pre-connected socket that implements [`Socket`].
///
/// This allows using custom transport layers (e.g., vsock, QUIC, or any
/// `AsyncRead + AsyncWrite` type) with SQLx database connections.
///
/// The socket will be passed through the `with_socket` handler, which
/// typically performs TLS upgrade negotiation.
pub async fn connect_socket<S: Socket, Ws: WithSocket>(
socket: S,
with_socket: Ws,
) -> crate::Result<Ws::Output> {
Ok(with_socket.with_socket(socket).await)
}

/// Open a TCP socket to `host` and `port`.
///
/// If `host` is a hostname, attempt to connect to each address it resolves to.
Expand Down
18 changes: 16 additions & 2 deletions sqlx-mysql/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,29 @@ impl MySqlConnection {

let stream = handshake?;

Ok(Self {
Ok(Self::establish_with_stream(stream, options))
}

pub(crate) async fn establish_with_socket<S: Socket>(
socket: S,
options: &MySqlConnectOptions,
) -> Result<Self, Error> {
let do_handshake = DoHandshake::new(options)?;
let stream = do_handshake.with_socket(socket).await?;

Ok(Self::establish_with_stream(stream, options))
}

fn establish_with_stream(stream: MySqlStream, options: &MySqlConnectOptions) -> Self {
Self {
inner: Box::new(MySqlConnectionInner {
stream,
transaction_depth: 0,
status_flags: Default::default(),
cache_statement: StatementCache::new(options.statement_cache_capacity),
log_settings: options.log_settings.clone(),
}),
})
}
}
}

Expand Down
32 changes: 32 additions & 0 deletions sqlx-mysql/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,38 @@ pub(crate) struct MySqlConnectionInner {
}

impl MySqlConnection {
/// Connect to a MySQL database using a pre-connected socket.
///
/// This allows using custom transport layers such as vsock, QUIC,
/// or any type that implements [`sqlx_core::net::Socket`].
///
/// The provided socket will go through TLS upgrade negotiation based on the
/// SSL mode configured in `options`.
///
/// # Example
///
/// ```rust,ignore
/// use sqlx::mysql::{MySqlConnectOptions, MySqlConnection};
///
/// # async fn example() -> sqlx::Result<()> {
/// let socket: tokio::net::TcpStream = todo!();
/// let options = MySqlConnectOptions::new()
/// .username("root")
/// .database("mydb");
///
/// let _conn = MySqlConnection::connect_socket(socket, &options).await?;
/// # Ok(())
/// # }
/// ```
pub async fn connect_socket<S: sqlx_core::net::Socket>(
socket: S,
options: &MySqlConnectOptions,
) -> Result<Self, Error> {
let mut conn = Self::establish_with_socket(socket, options).await?;
options.configure_session(&mut conn).await?;
Ok(conn)
}

pub(crate) fn in_transaction(&self) -> bool {
self.inner
.status_flags
Expand Down
33 changes: 20 additions & 13 deletions sqlx-mysql/src/options/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,26 @@ impl ConnectOptions for MySqlConnectOptions {
{
let mut conn = MySqlConnection::establish(self).await?;

// After the connection is established, we initialize by configuring a few
// connection parameters
self.configure_session(&mut conn).await?;

Ok(conn)
}

fn log_statements(mut self, level: LevelFilter) -> Self {
self.log_settings.log_statements(level);
self
}

fn log_slow_statements(mut self, level: LevelFilter, duration: Duration) -> Self {
self.log_settings.log_slow_statements(level, duration);
self
}
}

impl MySqlConnectOptions {
/// After the connection is established, initialize by configuring
/// connection parameters (sql_mode, time_zone, charset).
pub(crate) async fn configure_session(&self, conn: &mut MySqlConnection) -> Result<(), Error> {
// https://mariadb.com/kb/en/sql-mode/

// PIPES_AS_CONCAT - Allows using the pipe character (ASCII 124) as string concatenation operator.
Expand Down Expand Up @@ -88,16 +105,6 @@ impl ConnectOptions for MySqlConnectOptions {
.await?;
}

Ok(conn)
}

fn log_statements(mut self, level: LevelFilter) -> Self {
self.log_settings.log_statements(level);
self
}

fn log_slow_statements(mut self, level: LevelFilter, duration: Duration) -> Self {
self.log_settings.log_slow_statements(level, duration);
self
Ok(())
}
}
18 changes: 16 additions & 2 deletions sqlx-postgres/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::io::StatementId;
use crate::message::{
Authentication, BackendKeyData, BackendMessageFormat, Password, ReadyForQuery, Startup,
};
use crate::net::Socket;
use crate::{PgConnectOptions, PgConnection};

use super::PgConnectionInner;
Expand All @@ -16,9 +17,22 @@ use super::PgConnectionInner;

impl PgConnection {
pub(crate) async fn establish(options: &PgConnectOptions) -> Result<Self, Error> {
// Upgrade to TLS if we were asked to and the server supports it
let mut stream = PgStream::connect(options).await?;
let stream = PgStream::connect(options).await?;
Self::establish_with_stream(stream, options).await
}

pub(crate) async fn establish_with_socket<S: Socket>(
socket: S,
options: &PgConnectOptions,
) -> Result<Self, Error> {
let stream = PgStream::connect_socket(socket, options).await?;
Self::establish_with_stream(stream, options).await
}

async fn establish_with_stream(
mut stream: PgStream,
options: &PgConnectOptions,
) -> Result<Self, Error> {
// To begin a session, a frontend opens a connection to the server
// and sends a startup message.

Expand Down
30 changes: 30 additions & 0 deletions sqlx-postgres/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,36 @@ pub(crate) struct TableColumns {
}

impl PgConnection {
/// Connect to a PostgreSQL database using a pre-connected socket.
///
/// This allows using custom transport layers such as vsock, QUIC,
/// or any type that implements [`sqlx_core::net::Socket`].
///
/// The provided socket will go through TLS upgrade negotiation based on the
/// SSL mode configured in `options`.
///
/// # Example
///
/// ```rust,ignore
/// use sqlx::postgres::{PgConnectOptions, PgConnection};
///
/// # async fn example() -> sqlx::Result<()> {
/// let socket: tokio::net::TcpStream = todo!();
/// let options = PgConnectOptions::new()
/// .username("postgres")
/// .database("mydb");
///
/// let _conn = PgConnection::connect_socket(socket, &options).await?;
/// # Ok(())
/// # }
/// ```
pub async fn connect_socket<S: sqlx_core::net::Socket>(
socket: S,
options: &PgConnectOptions,
) -> Result<Self, Error> {
Self::establish_with_socket(socket, options).await
}

/// the version number of the server in `libpq` format
pub fn server_version_num(&self) -> Option<u32> {
self.inner.stream.server_version_num
Expand Down
16 changes: 16 additions & 0 deletions sqlx-postgres/src/connection/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,22 @@ impl PgStream {
})
}

pub(super) async fn connect_socket<S: Socket>(
socket: S,
options: &PgConnectOptions,
) -> Result<Self, Error> {
let socket = net::connect_socket(socket, MaybeUpgradeTls(options)).await?;

let socket = socket?;

Ok(Self {
inner: BufferedSocket::new(socket),
notifications: None,
parameter_statuses: BTreeMap::default(),
server_version_num: None,
})
}

#[inline(always)]
pub(crate) fn write_msg(&mut self, message: impl FrontendMessage) -> Result<(), Error> {
self.write(EncodeMessage(message))
Expand Down
6 changes: 6 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ pub mod decode {

pub use self::decode::Decode;

/// Networking traits for custom transport implementations.
pub mod net {
pub use sqlx_core::io::ReadBuf;
pub use sqlx_core::net::Socket;
}

/// Types and traits for the `query` family of functions and macros.
pub mod query {
pub use sqlx_core::query::{Map, Query};
Expand Down
Loading