Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 85 additions & 44 deletions PowerSync/PowerSync.Common/Client/PowerSyncDatabase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ public class PowerSyncDatabase : EventStream<PowerSyncDBEvent>, IPowerSyncDataba
protected IBucketStorageAdapter BucketStorageAdapter;

protected CancellationTokenSource? syncStreamStatusCts;
protected CancellationTokenSource watchSubscriptionCts = new();

protected SyncStatus CurrentStatus;

Expand Down Expand Up @@ -365,6 +366,16 @@ public async Task Connect(IPowerSyncBackendConnector connector, PowerSyncConnect
await syncStreamImplementation.Connect(options);
}

/// <summary>
/// Unsubscribe from all currently watched queries.
/// </summary>
public void UnsubscribeAllQueries()
{
watchSubscriptionCts.Cancel();
watchSubscriptionCts.Dispose();
watchSubscriptionCts = new();
}

public async Task Disconnect()
{
await WaitForReady();
Expand Down Expand Up @@ -415,7 +426,7 @@ await Database.WriteTransaction(async tx =>

if (Closed) return;


UnsubscribeAllQueries();
await Disconnect();
base.Close();
syncStreamImplementation?.Close();
Expand Down Expand Up @@ -671,64 +682,61 @@ public async Task<T> WriteTransaction<T>(Func<ITransaction, Task<T>> fn, DBLockO
/// Use <see cref="SQLWatchOptions.ThrottleMs"/> to specify the minimum interval between queries.
/// Source tables are automatically detected using <c>EXPLAIN QUERY PLAN</c>.
/// </summary>
public Task Watch<T>(string query, object?[]? parameters, WatchHandler<T> handler, SQLWatchOptions? options = null)
=> WatchInternal(query, parameters, handler, options, GetAll<T>);
public Task<IDisposable> Watch<T>(string query, object?[]? parameters, WatchHandler<T> handler, SQLWatchOptions? options = null)
=> Task.Run(() => WatchInternal(query, parameters, handler, options, GetAll<T>));

/// <summary>
/// Executes a read query every time the source tables are modified.
/// <para />
/// Use <see cref="SQLWatchOptions.ThrottleMs"/> to specify the minimum interval between queries.
/// Source tables are automatically detected using <c>EXPLAIN QUERY PLAN</c>.
/// </summary>
public Task Watch(string query, object?[]? parameters, WatchHandler<dynamic> handler, SQLWatchOptions? options = null)
=> WatchInternal(query, parameters, handler, options, GetAll);
public Task<IDisposable> Watch(string query, object?[]? parameters, WatchHandler<dynamic> handler, SQLWatchOptions? options = null)
=> Task.Run(() => WatchInternal(query, parameters, handler, options, GetAll));

private Task WatchInternal<T>(
private async Task<IDisposable> WatchInternal<T>(
string query,
object?[]? parameters,
WatchHandler<T> handler,
SQLWatchOptions? options,
Func<string, object?[]?, Task<T[]>> getter
)
{
var tcs = new TaskCompletionSource<bool>();
Task.Run(async () =>
try
{
try
{
var resolvedTables = await ResolveTables(query, parameters, options);
var result = await getter(query, parameters);
handler.OnResult(result);
var resolvedTables = await ResolveTables(query, parameters, options);
var result = await getter(query, parameters);
handler.OnResult(result);

OnChange(new WatchOnChangeHandler
var subscription = OnChange(new WatchOnChangeHandler
{
OnChange = async (change) =>
{
OnChange = async (change) =>
try
{
try
{
var result = await getter(query, parameters);
handler.OnResult(result);
}
catch (Exception ex)
{
handler.OnError?.Invoke(ex);
}
},
OnError = handler.OnError
}, new SQLWatchOptions
{
Tables = resolvedTables,
Signal = options?.Signal,
ThrottleMs = options?.ThrottleMs
});
tcs.SetResult(true);
}
catch (Exception ex)
var result = await getter(query, parameters);
handler.OnResult(result);
}
catch (Exception ex)
{
handler.OnError?.Invoke(ex);
}
},
OnError = handler.OnError
}, new SQLWatchOptions
{
handler.OnError?.Invoke(ex);
}
});
return tcs.Task;
Tables = resolvedTables,
Signal = options?.Signal,
ThrottleMs = options?.ThrottleMs
});

return subscription;
}
catch (Exception ex)
{
handler.OnError?.Invoke(ex);
throw;
}
}

private class ExplainedResult
Expand Down Expand Up @@ -776,7 +784,7 @@ public async Task<string[]> ResolveTables(string sql, object?[]? parameters = nu
/// This is preferred over <see cref="Watch"/> when multiple queries need to be performed
/// together in response to data changes.
/// </summary>
public void OnChange(WatchOnChangeHandler handler, SQLWatchOptions? options = null)
public IDisposable OnChange(WatchOnChangeHandler handler, SQLWatchOptions? options = null)
{
var resolvedOptions = options ?? new SQLWatchOptions();

Expand Down Expand Up @@ -811,13 +819,27 @@ void flushTableUpdates()
}
});

CancellationTokenSource linkedCts;
if (options?.Signal.HasValue == true)
{
options.Signal.Value.Register(() =>
{
cts.Cancel();
});
// Cancel on global CTS cancellation or user token cancellation
linkedCts = CancellationTokenSource.CreateLinkedTokenSource(
watchSubscriptionCts.Token,
options.Signal.Value
);
}
else
{
// Cancel on global CTS cancellation
linkedCts = watchSubscriptionCts;
}

var registration = linkedCts.Token.Register(() =>
{
cts.Cancel();
});

return new WatchSubscription(cts, registration);
}

private static void HandleTableChanges(HashSet<string> changedTables, HashSet<string> watchedTables, Action<string[]> onDetectedChanges)
Expand Down Expand Up @@ -879,3 +901,22 @@ public class WatchOnChangeHandler
public Func<WatchOnChangeEvent, Task> OnChange { get; set; } = null!;
public Action<Exception>? OnError { get; set; }
}

public class WatchSubscription(CancellationTokenSource cts, CancellationTokenRegistration registration) : IDisposable
{
private readonly CancellationTokenSource _cts = cts;
private readonly CancellationTokenRegistration _registration = registration;
private bool _disposed;

public bool Disposed { get { return _disposed; } }

public void Dispose()
{
if (_disposed) return;
_disposed = true;

_registration.Dispose();
_cts.Cancel();
_cts.Dispose();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -568,4 +568,107 @@ await tx.Execute(

await watched.Task;
}

[Fact(Timeout = 2000)]
public async void WatchDisposableSubscriptionTest()
{
int callCount = 0;

var subscription = await db.Watch("select id, description, make from assets", null, new()
{
OnResult = (results) => callCount++,
OnError = (ex) => Assert.Fail("An exception occurred: " + ex.ToString())
});
Thread.Sleep(200);
Assert.Equal(1, callCount);

// Bump callCount to 2
await db.Execute(
"insert into assets(id, description, make) values (?, ?, ?)",
[Guid.NewGuid().ToString(), "some desc", "some make"]
);
Thread.Sleep(200);
Assert.Equal(2, callCount);

subscription.Dispose();
await db.Execute(
"insert into assets(id, description, make) values (?, ?, ?)",
[Guid.NewGuid().ToString(), "some desc", "some make"]
);
Thread.Sleep(200);
Assert.Equal(2, callCount);
}

[Fact(Timeout = 2500)]
public async void WatchDisposableCustomTokenTest()
{
var customTokenSource = new CancellationTokenSource();
int callCount = 0;

using var subscription = await db.Watch("select id, description, make from assets", null, new()
{
OnResult = (results) => callCount++,
OnError = (ex) => Assert.Fail("An exception occurred: " + ex.ToString())
}, new()
{
Signal = customTokenSource.Token
});
Thread.Sleep(200);
Assert.Equal(1, callCount);

await db.Execute(
"insert into assets(id, description, make) values (?, ?, ?)",
[Guid.NewGuid().ToString(), "some desc", "some make"]
);
Thread.Sleep(200);
Assert.Equal(2, callCount);

customTokenSource.Cancel();
await db.Execute(
"insert into assets(id, description, make) values (?, ?, ?)",
[Guid.NewGuid().ToString(), "some desc", "some make"]
);
Thread.Sleep(200);
Assert.Equal(2, callCount); // Same value
}

[Fact(Timeout = 2000)]
public async void WatchMultipleCancelledTest()
{
int callCount = 0;
var watchHandlerFactory = () => new WatchHandler<IdResult>
{
OnResult = (result) => callCount++,
OnError = (ex) => Assert.Fail("An exception occurred: " + ex.ToString()),
};

var query1 = await db.Watch("select id from assets", null, watchHandlerFactory());
var query2 = await db.Watch("select id from customers", null, watchHandlerFactory());
Thread.Sleep(200);
Assert.Equal(2, callCount);

await db.Execute(
"insert into assets(id, description, make) values (?, ?, ?)",
[Guid.NewGuid().ToString(), "some desc", "some make"]
);
await db.Execute(
"insert into assets(id, description, make) values (?, ?, ?)",
[Guid.NewGuid().ToString(), "some desc", "some make"]
);
Thread.Sleep(200);
Assert.Equal(4, callCount);

db.UnsubscribeAllQueries();

await db.Execute(
"insert into assets(id, description, make) values (?, ?, ?)",
[Guid.NewGuid().ToString(), "some desc", "some make"]
);
await db.Execute(
"insert into assets(id, description, make) values (?, ?, ?)",
[Guid.NewGuid().ToString(), "some desc", "some make"]
);
Thread.Sleep(200);
Assert.Equal(4, callCount);
}
}