Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions crates/sqlx-sqlite-conn-mgr/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tracing::error;

/// Analysis limit for PRAGMA optimize on close.
/// SQLite recommends 100-1000 for older versions; 3.46.0+ handles automatically.
/// See: https://www.sqlite.org/lang_analyze.html#recommended_usage_pattern
const OPTIMIZE_ANALYSIS_LIMIT: u32 = 400;

/// SQLite database with connection pooling for concurrent reads and optional exclusive writes.
///
/// Once the database is opened it can be used for read-only operations by calling `read_pool()`.
Expand Down Expand Up @@ -144,16 +149,11 @@ impl SqliteDatabase {
drop(conn); // Close immediately after creating the file
}

// Enable PRAGMA optimize on close as recommended by SQLite for long-lived databases.
// SQLite recommends analysis_limit values between 100-1000 for older versions;
// SQLite 3.46.0+ handles limits automatically.
// https://www.sqlite.org/lang_analyze.html#recommended_usage_pattern
//
// Create read pool with read-only connections
let read_options = SqliteConnectOptions::new()
.filename(&path)
.read_only(true)
.optimize_on_close(true, 400);
.optimize_on_close(true, OPTIMIZE_ANALYSIS_LIMIT);

let read_pool = SqlitePoolOptions::new()
.max_connections(config.max_read_connections)
Expand All @@ -168,7 +168,7 @@ impl SqliteDatabase {
let write_options = SqliteConnectOptions::new()
.filename(&path)
.read_only(false)
.optimize_on_close(true, 400);
.optimize_on_close(true, OPTIMIZE_ANALYSIS_LIMIT);

let write_conn = SqlitePoolOptions::new()
.max_connections(1)
Expand Down Expand Up @@ -250,8 +250,12 @@ impl SqliteDatabase {
// Acquire connection from pool (max=1 ensures exclusive access)
let mut conn = self.write_conn.acquire().await?;

// Initialize WAL mode on first use (idempotent and safe)
if !self.wal_initialized.load(Ordering::SeqCst) {
// Initialize WAL mode on first use (atomic check-and-set)
if self
.wal_initialized
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
.is_ok()
{
sqlx::query("PRAGMA journal_mode = WAL")
.execute(&mut *conn)
.await?;
Expand All @@ -260,8 +264,6 @@ impl SqliteDatabase {
sqlx::query("PRAGMA synchronous = NORMAL")
.execute(&mut *conn)
.await?;

self.wal_initialized.store(true, Ordering::SeqCst);
}

// Return WriteGuard wrapping the pool connection
Expand Down
20 changes: 11 additions & 9 deletions src/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub struct TransactionToken {
pub transaction_id: String,
}

/// Actions that can be taken on a pausable transaction
/// Actions that can be taken on an interruptible transaction
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
pub enum TransactionAction {
Expand Down Expand Up @@ -264,12 +264,18 @@ pub async fn close_all(db_instances: State<'_, DbInstances>) -> Result<()> {
// Collect all wrappers to close
let wrappers: Vec<DatabaseWrapper> = instances.drain().map(|(_, v)| v).collect();

// Close each connection
// Close each connection, continuing on errors to ensure all get closed
let mut last_error = None;
for wrapper in wrappers {
wrapper.close().await?;
if let Err(e) = wrapper.close().await {
last_error = Some(e);
}
}

Ok(())
match last_error {
Some(e) => Err(e),
None => Ok(()),
}
}

/// Close database connection and remove all database files
Expand Down Expand Up @@ -343,12 +349,8 @@ pub async fn execute_interruptible_transaction(
q.execute(&mut *writer).await?;
}

// Create abort handle for transaction cleanup on app exit
let abort_handle = tokio::spawn(std::future::pending::<()>()).abort_handle();

// Store transaction state
let tx =
ActiveInterruptibleTransaction::new(db.clone(), transaction_id.clone(), writer, abort_handle);
let tx = ActiveInterruptibleTransaction::new(db.clone(), transaction_id.clone(), writer);

active_txs.insert(db.clone(), tx).await?;

Expand Down
3 changes: 3 additions & 0 deletions src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ use crate::Error;
///
/// This function handles the type conversion from SQLite's native types
/// to JSON-compatible representations.
///
/// Note: BLOB values are returned as base64-encoded strings since JSON
/// has no native binary type. Boolean values are stored as INTEGER in SQLite.
pub fn to_json(value: SqliteValueRef) -> Result<JsonValue, Error> {
if value.is_null() {
return Ok(JsonValue::Null);
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ pub struct MigrationEvent {
pub db_path: String,
/// Status: "running", "completed", "failed"
pub status: String,
/// Total number of migrations in the migrator (on "completed"), not just newly applied
/// Total number of migrations defined in the migrator (on "completed"), not just newly applied
#[serde(skip_serializing_if = "Option::is_none")]
pub migration_count: Option<usize>,
/// Error message (on "failed")
Expand Down
22 changes: 4 additions & 18 deletions src/transactions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;

use indexmap::IndexMap;
use serde::Deserialize;
Expand All @@ -20,23 +19,14 @@ pub struct ActiveInterruptibleTransaction {
db_path: String,
transaction_id: String,
writer: WriteGuard,
abort_handle: AbortHandle,
created_at: Instant,
}

impl ActiveInterruptibleTransaction {
pub fn new(
db_path: String,
transaction_id: String,
writer: WriteGuard,
abort_handle: AbortHandle,
) -> Self {
pub fn new(db_path: String, transaction_id: String, writer: WriteGuard) -> Self {
Self {
db_path,
transaction_id,
writer,
abort_handle,
created_at: Instant::now(),
}
}

Expand All @@ -48,10 +38,6 @@ impl ActiveInterruptibleTransaction {
&self.transaction_id
}

pub fn created_at(&self) -> Instant {
self.created_at
}

pub fn validate_token(&self, token_id: &str) -> Result<()> {
if self.transaction_id != token_id {
return Err(Error::InvalidTransactionToken);
Expand Down Expand Up @@ -157,15 +143,15 @@ impl ActiveInterruptibleTransactions {
let mut txs = self.0.write().await;
debug!("Aborting {} active interruptible transaction(s)", txs.len());

for (db_path, tx) in txs.iter() {
for db_path in txs.keys() {
debug!(
"Aborting interruptible transaction for database: {}",
"Dropping interruptible transaction for database: {}",
db_path
);
tx.abort_handle.abort();
}

// Clear all transactions to drop WriteGuards and release locks
// Dropping triggers auto-rollback via Drop trait
txs.clear();
}

Expand Down
13 changes: 7 additions & 6 deletions src/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ impl DatabaseWrapper {
}

/// Execute a SELECT query expecting zero or one result
///
/// Returns an error if the query returns more than one row.
pub async fn fetch_one(
&self,
query: String,
Expand All @@ -185,11 +187,7 @@ impl DatabaseWrapper {
// Use read pool for queries
let pool = self.inner.read_pool()?;

// Add LIMIT 2 to detect if query returns multiple rows
// We only need to fetch up to 2 rows to know if there's more than 1
let limited_query = format!("{} LIMIT 2", query.trim_end_matches(';'));

let mut q = sqlx::query(&limited_query);
let mut q = sqlx::query(&query);
for value in values {
q = bind_value(q, value);
}
Expand Down Expand Up @@ -274,7 +272,10 @@ pub(crate) fn bind_value<'a>(
}
}

/// Resolve database file path relative to app config directory
/// Resolve database file path relative to app config directory.
///
/// Paths are joined to `app_config_dir()` (e.g., `Library/Application Support/${bundleIdentifier}` on iOS).
/// Special paths like `:memory:` are passed through unchanged.
fn resolve_database_path<R: Runtime>(path: &str, app: &AppHandle<R>) -> Result<PathBuf, Error> {
let app_path = app
.path()
Expand Down