Skip to content

Commit c94ea3d

Browse files
committed
Fix adding custom collation sequences
1 parent 223e7d9 commit c94ea3d

File tree

2 files changed

+53
-138
lines changed

2 files changed

+53
-138
lines changed

Sources/SQLite/SQLiteDatabase.swift

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,17 @@ public final class SQLiteDatabase: DatabaseProtocol, @unchecked Sendable {
99
public static let suspendNotification = GRDB.Database.suspendNotification
1010
public static let resumeNotification = GRDB.Database.resumeNotification
1111

12+
public static let unicodeCompare =
13+
GRDB.DatabaseCollation.unicodeCompare.name
14+
public static let caseInsensitiveCompare =
15+
GRDB.DatabaseCollation.caseInsensitiveCompare.name
16+
public static let localizedCaseInsensitiveCompare =
17+
GRDB.DatabaseCollation.localizedCaseInsensitiveCompare.name
18+
public static let localizedCompare =
19+
GRDB.DatabaseCollation.localizedCompare.name
20+
public static let localizedStandardCompare =
21+
GRDB.DatabaseCollation.localizedStandardCompare.name
22+
1223
public let path: String
1324
public let sqliteVersion: String
1425

@@ -33,7 +44,10 @@ public final class SQLiteDatabase: DatabaseProtocol, @unchecked Sendable {
3344

3445
public static func makeShared(
3546
path: String,
36-
busyTimeout: TimeInterval = 5
47+
busyTimeout: TimeInterval = 5,
48+
collationSequences: [
49+
String: @Sendable (String, String) -> ComparisonResult
50+
] = [:]
3751
) throws -> SQLiteDatabase {
3852
guard path != ":memory:" else {
3953
throw SQLiteError.SQLITE_IOERR
@@ -58,7 +72,8 @@ public final class SQLiteDatabase: DatabaseProtocol, @unchecked Sendable {
5872
do {
5973
database = try SQLiteDatabase(
6074
path: url.path,
61-
busyTimeout: busyTimeout
75+
busyTimeout: busyTimeout,
76+
collationSequences: collationSequences
6277
)
6378
} catch {
6479
databaseError = error
@@ -79,8 +94,18 @@ public final class SQLiteDatabase: DatabaseProtocol, @unchecked Sendable {
7994
return db
8095
}
8196

82-
public init(path: String = ":memory:", busyTimeout: TimeInterval = 5) throws {
83-
database = try Self.open(at: path, busyTimeout: busyTimeout)
97+
public init(
98+
path: String = ":memory:",
99+
busyTimeout: TimeInterval = 5,
100+
collationSequences: [
101+
String: @Sendable (String, String) -> ComparisonResult,
102+
] = [:]
103+
) throws {
104+
database = try Self.open(
105+
at: path,
106+
busyTimeout: busyTimeout,
107+
collationSequences: collationSequences
108+
)
84109
self.path = path
85110
let sqliteVersion = try Self.getSQLiteVersion(database)
86111
self.sqliteVersion = sqliteVersion.description
@@ -587,33 +612,6 @@ public extension SQLiteDatabase {
587612
}
588613
}
589614

590-
// MARK: - Collating sequences
591-
592-
public extension SQLiteDatabase {
593-
func addCollation(
594-
named name: String,
595-
comparator: @escaping @Sendable (String, String) -> ComparisonResult
596-
) throws {
597-
let collation = DatabaseCollation(
598-
name,
599-
function: comparator
600-
)
601-
try database
602-
.writer
603-
.barrierWriteWithoutTransaction { $0.add(collation: collation) }
604-
}
605-
606-
func removeCollation(named name: String) throws {
607-
let collation = DatabaseCollation(
608-
name,
609-
function: { _, _ in .orderedSame }
610-
)
611-
try database
612-
.writer
613-
.barrierWriteWithoutTransaction { $0.remove(collation: collation) }
614-
}
615-
}
616-
617615
// MARK: - Pragmas
618616

619617
public extension SQLiteDatabase {
@@ -748,7 +746,10 @@ extension SQLiteDatabase {
748746
private extension SQLiteDatabase {
749747
class func open(
750748
at path: String,
751-
busyTimeout: TimeInterval
749+
busyTimeout: TimeInterval,
750+
collationSequences: [
751+
String: @Sendable (String, String) -> ComparisonResult
752+
]
752753
) throws -> Database {
753754
let isInMemory: Bool = {
754755
let p = path.lowercased()
@@ -763,6 +764,16 @@ private extension SQLiteDatabase {
763764
ProcessInfo.processInfo.processorCount,
764765
6
765766
)
767+
if !collationSequences.isEmpty {
768+
config.prepareDatabase { db in
769+
for (name, comparator) in collationSequences {
770+
db.add(collation: DatabaseCollation(
771+
name,
772+
function: comparator
773+
))
774+
}
775+
}
776+
}
766777

767778
guard !isInMemory else {
768779
do {

Tests/SQLiteTests/SQLiteDatabaseTests.swift

Lines changed: 10 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -94,112 +94,10 @@ final class SQLiteDatabaseTests: XCTestCase {
9494
}
9595
}
9696

97-
func testAddAndRemoveCollation() throws {
98-
struct Entity: Hashable, SQLiteTransformable {
99-
let id: String
100-
let string: String?
101-
102-
init(_ id: Int, _ string: String? = nil) {
103-
self.id = String(id)
104-
self.string = string
105-
}
106-
107-
init(row: SQLiteRow) throws {
108-
id = try row.value(for: "id")
109-
string = row.optionalValue(for: "string")
110-
}
111-
112-
var asArguments: SQLiteArguments {
113-
[
114-
"id": .text(id),
115-
"string": string.map { .text($0) } ?? .null,
116-
]
117-
}
118-
}
119-
120-
let apple = Entity(1, "Apple")
121-
let banana = Entity(2, "banana")
122-
let zebra = Entity(3, "Zebra")
123-
let null1 = Entity(4)
124-
let null2 = Entity(5)
125-
126-
try database.inTransaction { db in
127-
try db.write(_createTableWithIDAsStringAndNullableString)
128-
try [apple, banana, zebra, null1, null2]
129-
.forEach { entity in
130-
try db.write(
131-
_insertIDAndString,
132-
arguments: entity.asArguments
133-
)
134-
}
135-
}
136-
137-
let selectDefaultSorted: SQL = """
138-
SELECT * FROM test ORDER BY string;
139-
"""
140-
141-
let selectCustomCaseSensitiveSorted: SQL = """
142-
SELECT * FROM test ORDER BY string COLLATE CUSTOM;
143-
"""
144-
145-
let selectCustomCaseInsensitiveSorted: SQL = """
146-
SELECT * FROM test ORDER BY string COLLATE CUSTOM_NOCASE;
147-
"""
148-
149-
let defaultSorted: [Entity] = try database.read(selectDefaultSorted)
150-
XCTAssertEqual(
151-
defaultSorted,
152-
[null1, null2, apple, zebra, banana]
153-
)
154-
155-
XCTAssertThrowsError(
156-
try database.read(selectCustomCaseSensitiveSorted)
157-
) { error in
158-
guard case SQLiteError.SQLITE_ERROR_MISSING_COLLSEQ = error else {
159-
XCTFail("Should have thrown SQLITE_ERROR")
160-
return
161-
}
162-
}
163-
164-
try database.addCollation(named: "CUSTOM") { $0.compare($1) }
165-
let customSorted: [Entity] = try database.read(selectCustomCaseSensitiveSorted)
166-
XCTAssertEqual(
167-
customSorted,
168-
[null1, null2, apple, zebra, banana]
169-
)
170-
171-
try database.addCollation(
172-
named: "CUSTOM_NOCASE"
173-
) { $0.caseInsensitiveCompare($1) }
174-
175-
let customNoCaseSorted: [Entity] = try database
176-
.read(selectCustomCaseInsensitiveSorted)
177-
XCTAssertEqual(
178-
customNoCaseSorted,
179-
[null1, null2, apple, banana, zebra]
180-
)
181-
182-
try database.removeCollation(named: "CUSTOM_NOCASE")
183-
XCTAssertThrowsError(
184-
try database.read(selectCustomCaseInsensitiveSorted)
185-
) { error in
186-
guard case SQLiteError.SQLITE_ERROR_MISSING_COLLSEQ = error else {
187-
XCTFail("Should have thrown SQLITE_ERROR")
188-
return
189-
}
190-
}
191-
let customSortedAfterRemovingNoCase: [Entity] = try database
192-
.read(selectCustomCaseSensitiveSorted)
193-
XCTAssertEqual(
194-
customSortedAfterRemovingNoCase,
195-
[null1, null2, apple, zebra, banana]
196-
)
197-
}
198-
19997
func testCustomLocalizedCollation() throws {
200-
try database.addCollation(named: "LOCALIZED") { lhs, rhs in
201-
lhs.localizedStandardCompare(rhs)
202-
}
98+
database = try SQLiteDatabase(collationSequences: [
99+
"CUSTOM_LOCALIZED": { $0.localizedStandardCompare($1) },
100+
])
203101

204102
// NOTE: ([toInsert], [binary sort], [localized sort])
205103
let cases: [([String], [String], [String])] = [
@@ -252,10 +150,16 @@ final class SQLiteDatabaseTests: XCTestCase {
252150
XCTAssertEqual(binarySorted, binarySort)
253151

254152
let localizedSorted: [String] = try database
255-
.read("SELECT * FROM test ORDER BY string COLLATE LOCALIZED;")
153+
.read("SELECT * FROM test ORDER BY string COLLATE CUSTOM_LOCALIZED;")
256154
.compactMap { $0["string"]?.stringValue }
257155
XCTAssertEqual(localizedSorted, localizedSort)
258156

157+
let grdbName = SQLiteDatabase.localizedStandardCompare
158+
let grdbStandardSorted: [String] = try database
159+
.read("SELECT * FROM test ORDER BY string COLLATE \(grdbName);")
160+
.compactMap { $0["string"]?.stringValue }
161+
XCTAssertEqual(grdbStandardSorted, localizedSort)
162+
259163
try database.write("DROP TABLE test;")
260164
}
261165
}

0 commit comments

Comments
 (0)