Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 37 additions & 32 deletions Sources/SQLite/SQLiteDatabase.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@ public final class SQLiteDatabase: DatabaseProtocol, @unchecked Sendable {
public static let suspendNotification = GRDB.Database.suspendNotification
public static let resumeNotification = GRDB.Database.resumeNotification

public static let unicodeCompare =
GRDB.DatabaseCollation.unicodeCompare.name
public static let caseInsensitiveCompare =
GRDB.DatabaseCollation.caseInsensitiveCompare.name
public static let localizedCaseInsensitiveCompare =
GRDB.DatabaseCollation.localizedCaseInsensitiveCompare.name
public static let localizedCompare =
GRDB.DatabaseCollation.localizedCompare.name
public static let localizedStandardCompare =
GRDB.DatabaseCollation.localizedStandardCompare.name

public let path: String
public let sqliteVersion: String

Expand All @@ -33,7 +44,8 @@ public final class SQLiteDatabase: DatabaseProtocol, @unchecked Sendable {

public static func makeShared(
path: String,
busyTimeout: TimeInterval = 5
busyTimeout: TimeInterval = 5,
collationSequences: [String: @Sendable (String, String) -> ComparisonResult] = [:]
) throws -> SQLiteDatabase {
guard path != ":memory:" else {
throw SQLiteError.SQLITE_IOERR
Expand All @@ -58,7 +70,8 @@ public final class SQLiteDatabase: DatabaseProtocol, @unchecked Sendable {
do {
database = try SQLiteDatabase(
path: url.path,
busyTimeout: busyTimeout
busyTimeout: busyTimeout,
collationSequences: collationSequences
)
} catch {
databaseError = error
Expand All @@ -79,8 +92,16 @@ public final class SQLiteDatabase: DatabaseProtocol, @unchecked Sendable {
return db
}

public init(path: String = ":memory:", busyTimeout: TimeInterval = 5) throws {
database = try Self.open(at: path, busyTimeout: busyTimeout)
public init(
path: String = ":memory:",
busyTimeout: TimeInterval = 5,
collationSequences: [String: @Sendable (String, String) -> ComparisonResult] = [:]
) throws {
database = try Self.open(
at: path,
busyTimeout: busyTimeout,
collationSequences: collationSequences
)
self.path = path
let sqliteVersion = try Self.getSQLiteVersion(database)
self.sqliteVersion = sqliteVersion.description
Expand Down Expand Up @@ -587,33 +608,6 @@ public extension SQLiteDatabase {
}
}

// MARK: - Collating sequences

public extension SQLiteDatabase {
func addCollation(
named name: String,
comparator: @escaping @Sendable (String, String) -> ComparisonResult
) throws {
let collation = DatabaseCollation(
name,
function: comparator
)
try database
.writer
.barrierWriteWithoutTransaction { $0.add(collation: collation) }
}

func removeCollation(named name: String) throws {
let collation = DatabaseCollation(
name,
function: { _, _ in .orderedSame }
)
try database
.writer
.barrierWriteWithoutTransaction { $0.remove(collation: collation) }
}
}

// MARK: - Pragmas

public extension SQLiteDatabase {
Expand Down Expand Up @@ -748,7 +742,8 @@ extension SQLiteDatabase {
private extension SQLiteDatabase {
class func open(
at path: String,
busyTimeout: TimeInterval
busyTimeout: TimeInterval,
collationSequences: [String: @Sendable (String, String) -> ComparisonResult]
) throws -> Database {
let isInMemory: Bool = {
let p = path.lowercased()
Expand All @@ -763,6 +758,16 @@ private extension SQLiteDatabase {
ProcessInfo.processInfo.processorCount,
6
)
if !collationSequences.isEmpty {
config.prepareDatabase { db in
for (name, comparator) in collationSequences {
db.add(collation: DatabaseCollation(
name,
function: comparator
))
}
}
}

guard !isInMemory else {
do {
Expand Down
116 changes: 10 additions & 106 deletions Tests/SQLiteTests/SQLiteDatabaseTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -94,112 +94,10 @@ final class SQLiteDatabaseTests: XCTestCase {
}
}

func testAddAndRemoveCollation() throws {
struct Entity: Hashable, SQLiteTransformable {
let id: String
let string: String?

init(_ id: Int, _ string: String? = nil) {
self.id = String(id)
self.string = string
}

init(row: SQLiteRow) throws {
id = try row.value(for: "id")
string = row.optionalValue(for: "string")
}

var asArguments: SQLiteArguments {
[
"id": .text(id),
"string": string.map { .text($0) } ?? .null,
]
}
}

let apple = Entity(1, "Apple")
let banana = Entity(2, "banana")
let zebra = Entity(3, "Zebra")
let null1 = Entity(4)
let null2 = Entity(5)

try database.inTransaction { db in
try db.write(_createTableWithIDAsStringAndNullableString)
try [apple, banana, zebra, null1, null2]
.forEach { entity in
try db.write(
_insertIDAndString,
arguments: entity.asArguments
)
}
}

let selectDefaultSorted: SQL = """
SELECT * FROM test ORDER BY string;
"""

let selectCustomCaseSensitiveSorted: SQL = """
SELECT * FROM test ORDER BY string COLLATE CUSTOM;
"""

let selectCustomCaseInsensitiveSorted: SQL = """
SELECT * FROM test ORDER BY string COLLATE CUSTOM_NOCASE;
"""

let defaultSorted: [Entity] = try database.read(selectDefaultSorted)
XCTAssertEqual(
defaultSorted,
[null1, null2, apple, zebra, banana]
)

XCTAssertThrowsError(
try database.read(selectCustomCaseSensitiveSorted)
) { error in
guard case SQLiteError.SQLITE_ERROR_MISSING_COLLSEQ = error else {
XCTFail("Should have thrown SQLITE_ERROR")
return
}
}

try database.addCollation(named: "CUSTOM") { $0.compare($1) }
let customSorted: [Entity] = try database.read(selectCustomCaseSensitiveSorted)
XCTAssertEqual(
customSorted,
[null1, null2, apple, zebra, banana]
)

try database.addCollation(
named: "CUSTOM_NOCASE"
) { $0.caseInsensitiveCompare($1) }

let customNoCaseSorted: [Entity] = try database
.read(selectCustomCaseInsensitiveSorted)
XCTAssertEqual(
customNoCaseSorted,
[null1, null2, apple, banana, zebra]
)

try database.removeCollation(named: "CUSTOM_NOCASE")
XCTAssertThrowsError(
try database.read(selectCustomCaseInsensitiveSorted)
) { error in
guard case SQLiteError.SQLITE_ERROR_MISSING_COLLSEQ = error else {
XCTFail("Should have thrown SQLITE_ERROR")
return
}
}
let customSortedAfterRemovingNoCase: [Entity] = try database
.read(selectCustomCaseSensitiveSorted)
XCTAssertEqual(
customSortedAfterRemovingNoCase,
[null1, null2, apple, zebra, banana]
)
}

func testCustomLocalizedCollation() throws {
try database.addCollation(named: "LOCALIZED") { lhs, rhs in
lhs.localizedStandardCompare(rhs)
}
database = try SQLiteDatabase(collationSequences: [
"CUSTOM_LOCALIZED": { $0.localizedStandardCompare($1) },
])

// NOTE: ([toInsert], [binary sort], [localized sort])
let cases: [([String], [String], [String])] = [
Expand Down Expand Up @@ -252,10 +150,16 @@ final class SQLiteDatabaseTests: XCTestCase {
XCTAssertEqual(binarySorted, binarySort)

let localizedSorted: [String] = try database
.read("SELECT * FROM test ORDER BY string COLLATE LOCALIZED;")
.read("SELECT * FROM test ORDER BY string COLLATE CUSTOM_LOCALIZED;")
.compactMap { $0["string"]?.stringValue }
XCTAssertEqual(localizedSorted, localizedSort)

let grdbName = SQLiteDatabase.localizedStandardCompare
let grdbStandardSorted: [String] = try database
.read("SELECT * FROM test ORDER BY string COLLATE \(grdbName);")
.compactMap { $0["string"]?.stringValue }
XCTAssertEqual(grdbStandardSorted, localizedSort)

try database.write("DROP TABLE test;")
}
}
Expand Down