Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ export interface SqliteDatabase {
): Promise<SqliteExecuteResult>;
run(sql: string, params?: SqliteBindings): Promise<void>;
query(sql: string, params?: SqliteBindings): Promise<SqliteQueryResult>;
writeMode<T>(callback: () => Promise<T>): Promise<T>;
nativeMetrics?():
| SqliteNativeMetrics
| Promise<SqliteNativeMetrics | null>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@ import type {
import { db } from "./mod";

class FakeSqliteDatabase implements SqliteDatabase {
writeModeDepth = 0;
executeCalls: {
sql: string;
params?: SqliteBindings;
writeMode: boolean;
}[] = [];

async exec(): Promise<void> {}
Expand All @@ -24,7 +22,6 @@ class FakeSqliteDatabase implements SqliteDatabase {
this.executeCalls.push({
sql,
params,
writeMode: this.writeModeDepth > 0,
});
return {
columns: [],
Expand All @@ -43,15 +40,6 @@ class FakeSqliteDatabase implements SqliteDatabase {
return { columns, rows };
}

async writeMode<T>(callback: () => Promise<T>): Promise<T> {
this.writeModeDepth++;
try {
return await callback();
} finally {
this.writeModeDepth--;
}
}

async close(): Promise<void> {}
}

Expand All @@ -73,7 +61,7 @@ function testProviderContext(
}

describe("db", () => {
test("runs onMigrate through sqlite write mode", async () => {
test("runs onMigrate inside a sqlite savepoint", async () => {
const nativeDb = new FakeSqliteDatabase();
const provider = db({
onMigrate: async (client) => {
Expand All @@ -90,15 +78,59 @@ describe("db", () => {
await provider.onMigrate(client);

expect(nativeDb.executeCalls).toEqual([
{
sql: "SAVEPOINT __rivet_on_migrate",
params: undefined,
},
{
sql: "CREATE TABLE items(id INTEGER PRIMARY KEY, value TEXT)",
params: undefined,
writeMode: true,
},
{
sql: "SELECT COUNT(*) AS count FROM items",
params: undefined,
writeMode: true,
},
{
sql: "RELEASE SAVEPOINT __rivet_on_migrate",
params: undefined,
},
]);
});

test("rolls back the migration savepoint when onMigrate fails", async () => {
const nativeDb = new FakeSqliteDatabase();
const provider = db({
onMigrate: async (client) => {
await client.execute(
"CREATE TABLE items(id INTEGER PRIMARY KEY, value TEXT)",
);
throw new Error("migration failed");
},
});
const client = await provider.createClient(
testProviderContext(nativeDb),
);

await expect(provider.onMigrate(client)).rejects.toThrow(
"migration failed",
);

expect(nativeDb.executeCalls).toEqual([
{
sql: "SAVEPOINT __rivet_on_migrate",
params: undefined,
},
{
sql: "CREATE TABLE items(id INTEGER PRIMARY KEY, value TEXT)",
params: undefined,
},
{
sql: "ROLLBACK TO SAVEPOINT __rivet_on_migrate",
params: undefined,
},
{
sql: "RELEASE SAVEPOINT __rivet_on_migrate",
params: undefined,
},
]);
});
Expand Down
35 changes: 15 additions & 20 deletions rivetkit-typescript/packages/rivetkit/src/common/database/mod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@ interface DatabaseFactoryConfig {
onMigrate?: (db: RawAccess) => Promise<void> | void;
}

type RawAccessWithWriteMode = RawAccess & {
__rivetWriteMode: <T>(callback: () => Promise<T> | T) => Promise<T>;
};

function hasMultipleStatements(query: string): boolean {
const trimmed = query.trim().replace(/;+$/, "").trimEnd();
return trimmed.includes(";");
Expand Down Expand Up @@ -38,7 +34,7 @@ export function db({
}
};

const client: RawAccessWithWriteMode = {
const client: RawAccess = {
execute: async <
TRow extends Record<string, unknown> = Record<
string,
Expand Down Expand Up @@ -103,17 +99,12 @@ export function db({
}
},
nativeMetrics: () => db.nativeMetrics?.() ?? null,
__rivetWriteMode: async <T>(
callback: () => Promise<T> | T,
): Promise<T> => {
return await db.writeMode(async () => await callback());
},
};
return client;
},
onMigrate: async (client) => {
if (onMigrate) {
await dbWriteMode(client, () => onMigrate(client));
await withMigrationSavepoint(client, () => onMigrate(client));
}
},
};
Expand Down Expand Up @@ -145,17 +136,21 @@ async function execMultiStatement<TRow extends Record<string, unknown>>(
return results as TRow[];
}

async function dbWriteMode<T>(
async function withMigrationSavepoint<T>(
client: RawAccess,
callback: () => Promise<T> | T,
): Promise<T> {
const maybeClient = client as RawAccess & {
__rivetWriteMode?: <TInner>(
callback: () => Promise<TInner> | TInner,
) => Promise<TInner>;
};
if (maybeClient.__rivetWriteMode) {
return await maybeClient.__rivetWriteMode(callback);
await client.execute("SAVEPOINT __rivet_on_migrate");
try {
const result = await callback();
await client.execute("RELEASE SAVEPOINT __rivet_on_migrate");
return result;
} catch (error) {
try {
await client.execute("ROLLBACK TO SAVEPOINT __rivet_on_migrate");
} finally {
await client.execute("RELEASE SAVEPOINT __rivet_on_migrate");
}
throw error;
}
return await callback();
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,28 +104,6 @@ describe("wrapJsNativeDatabase", () => {
});
});

test("keeps write mode on the normal native execute lane", async () => {
const native = new FakeNativeDatabase();
const db = wrapJsNativeDatabase(native);

const query = db.writeMode(async () => {
const promise = db.query("SELECT 1");
expect(native.executeCalls).toMatchObject([
{ sql: "SELECT 1", write: false },
]);
native.resolveNext({
columns: ["value"],
rows: [[1]],
});
return await promise;
});

await expect(query).resolves.toEqual({
columns: ["value"],
rows: [[1]],
});
});

test("normalizes supported sqlite bind values", async () => {
const native = new FakeNativeDatabase();
const db = wrapJsNativeDatabase(native);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,9 +346,6 @@ export function wrapJsNativeDatabase(
const { columns, rows } = await executeNative(sql, params);
return { columns, rows };
},
async writeMode<T>(callback: () => Promise<T>): Promise<T> {
return await callback();
},
nativeMetrics(): SqliteNativeMetrics | null {
return normalizeNativeMetrics(database.metrics?.());
},
Expand Down
35 changes: 17 additions & 18 deletions rivetkit-typescript/packages/rivetkit/src/db/drizzle.ts
Original file line number Diff line number Diff line change
Expand Up @@ -163,19 +163,14 @@ export function db<TSchema extends DrizzleSchema = Record<string, never>>({
await nativeDb.close();
}
};
(
drizzleDb as DrizzleDatabase<TSchema> & {
__rivetWriteMode: <T>(
callback: () => Promise<T> | T,
) => Promise<T>;
}
).__rivetWriteMode = async (callback) =>
await nativeDb.writeMode(async () => await callback());

return drizzleDb;
},
onMigrate: async (client) => {
await dbWriteMode(client, async () => {
if (!migrations && !onMigrate) {
return;
}
await withMigrationSavepoint(client, async () => {
if (migrations) {
await runMigrations(client, migrations);
}
Expand All @@ -187,19 +182,23 @@ export function db<TSchema extends DrizzleSchema = Record<string, never>>({
};
}

async function dbWriteMode<T>(
async function withMigrationSavepoint<T>(
client: RawAccess,
callback: () => Promise<T> | T,
): Promise<T> {
const maybeClient = client as RawAccess & {
__rivetWriteMode?: <TInner>(
callback: () => Promise<TInner> | TInner,
) => Promise<TInner>;
};
if (maybeClient.__rivetWriteMode) {
return await maybeClient.__rivetWriteMode(callback);
await client.execute("SAVEPOINT __rivet_on_migrate");
try {
const result = await callback();
await client.execute("RELEASE SAVEPOINT __rivet_on_migrate");
return result;
} catch (error) {
try {
await client.execute("ROLLBACK TO SAVEPOINT __rivet_on_migrate");
} finally {
await client.execute("RELEASE SAVEPOINT __rivet_on_migrate");
}
throw error;
}
return await callback();
}

async function runMigrations<TSchema extends DrizzleSchema>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ ${wasmModuleSource}
interface SqliteDatabase {
\trun(sql: string, params?: unknown[]): Promise<void>;
\tquery(sql: string, params?: unknown[]): Promise<{ rows: unknown[][] }>;
\twriteMode<T>(callback: () => Promise<T>): Promise<T>;
}

interface RegistryConfig {
Expand All @@ -137,29 +136,23 @@ const rawSqlDatabaseProvider = {
};

async function ensureCounterTable(db: SqliteDatabase) {
\tawait db.writeMode(async () => {
\t\tawait db.run(
\t\t\t"CREATE TABLE IF NOT EXISTS platform_counter (id INTEGER PRIMARY KEY CHECK (id = 1), count INTEGER NOT NULL)",
\t\t);
\t});
\tawait db.run(
\t\t"CREATE TABLE IF NOT EXISTS platform_counter (id INTEGER PRIMARY KEY CHECK (id = 1), count INTEGER NOT NULL)",
\t);
}

async function ensureLifecycleTable(db: SqliteDatabase) {
\tawait db.writeMode(async () => {
\t\tawait db.run(
\t\t\t"CREATE TABLE IF NOT EXISTS platform_counter_lifecycle (event TEXT PRIMARY KEY, count INTEGER NOT NULL)",
\t\t);
\t});
\tawait db.run(
\t\t"CREATE TABLE IF NOT EXISTS platform_counter_lifecycle (event TEXT PRIMARY KEY, count INTEGER NOT NULL)",
\t);
}

async function recordLifecycleEvent(db: SqliteDatabase, event: string) {
\tawait ensureLifecycleTable(db);
\tawait db.writeMode(async () => {
\t\tawait db.run(
\t\t\t"INSERT INTO platform_counter_lifecycle (event, count) VALUES (?, 1) ON CONFLICT(event) DO UPDATE SET count = count + 1",
\t\t\t[event],
\t\t);
\t});
\tawait db.run(
\t\t"INSERT INTO platform_counter_lifecycle (event, count) VALUES (?, 1) ON CONFLICT(event) DO UPDATE SET count = count + 1",
\t\t[event],
\t);
}

async function readCounter(db: SqliteDatabase): Promise<number> {
Expand Down Expand Up @@ -201,12 +194,10 @@ const sqliteCounter = actor({
\t\tincrement: async (ctx, amount = 1) => {
\t\t\tconst db = ctx.sql as SqliteDatabase;
\t\t\tawait ensureCounterTable(db);
\t\t\tawait db.writeMode(async () => {
\t\t\t\tawait db.run(
\t\t\t\t\t"INSERT INTO platform_counter (id, count) VALUES (?, ?) ON CONFLICT(id) DO UPDATE SET count = count + excluded.count",
\t\t\t\t\t[COUNTER_ID, amount],
\t\t\t\t);
\t\t\t});
\t\t\tawait db.run(
\t\t\t\t"INSERT INTO platform_counter (id, count) VALUES (?, ?) ON CONFLICT(id) DO UPDATE SET count = count + excluded.count",
\t\t\t\t[COUNTER_ID, amount],
\t\t\t);

\t\t\treturn await readCounter(db);
\t\t},
Expand Down
Loading
Loading