diff --git a/packages/sqlite_async/lib/src/common/sqlite_database.dart b/packages/sqlite_async/lib/src/common/sqlite_database.dart index 3cb12bb..3201135 100644 --- a/packages/sqlite_async/lib/src/common/sqlite_database.dart +++ b/packages/sqlite_async/lib/src/common/sqlite_database.dart @@ -39,6 +39,14 @@ mixin SqliteDatabaseMixin implements SqliteConnection, SqliteQueries { /// /// Use this to access the database in background isolates. IsolateConnectionFactory isolateConnectionFactory(); + + /// Locks all underlying connections making up this database, and gives [block] access to all of them at once. + /// This can be useful to run the same statement on all connections. For instance, + /// ATTACHing a database, that is expected to be available in all connections. + Future withAllConnections( + Future Function( + SqliteWriteContext writer, List readers) + block); } /// A SQLite database instance. diff --git a/packages/sqlite_async/lib/src/impl/single_connection_database.dart b/packages/sqlite_async/lib/src/impl/single_connection_database.dart index 4cd3144..7ca4357 100644 --- a/packages/sqlite_async/lib/src/impl/single_connection_database.dart +++ b/packages/sqlite_async/lib/src/impl/single_connection_database.dart @@ -57,4 +57,12 @@ final class SingleConnectionDatabase return connection.writeLock(callback, lockTimeout: lockTimeout, debugContext: debugContext); } + + @override + Future withAllConnections( + Future Function( + SqliteWriteContext writer, List readers) + block) { + return writeLock((_) => block(connection, [])); + } } diff --git a/packages/sqlite_async/lib/src/impl/stub_sqlite_database.dart b/packages/sqlite_async/lib/src/impl/stub_sqlite_database.dart index 29db641..ee254f3 100644 --- a/packages/sqlite_async/lib/src/impl/stub_sqlite_database.dart +++ b/packages/sqlite_async/lib/src/impl/stub_sqlite_database.dart @@ -64,4 +64,12 @@ class SqliteDatabaseImpl Future getAutoCommit() { throw UnimplementedError(); } + + @override + Future withAllConnections( + Future Function( + SqliteWriteContext writer, List readers) + block) { + throw UnimplementedError(); + } } diff --git a/packages/sqlite_async/lib/src/native/database/connection_pool.dart b/packages/sqlite_async/lib/src/native/database/connection_pool.dart index 9521b34..8dab27e 100644 --- a/packages/sqlite_async/lib/src/native/database/connection_pool.dart +++ b/packages/sqlite_async/lib/src/native/database/connection_pool.dart @@ -31,6 +31,8 @@ class SqliteConnectionPool with SqliteQueries implements SqliteConnection { final MutexImpl mutex; + int _runningWithAllConnectionsCount = 0; + @override bool closed = false; @@ -88,6 +90,14 @@ class SqliteConnectionPool with SqliteQueries implements SqliteConnection { return; } + if (_availableReadConnections.isEmpty && + _runningWithAllConnectionsCount > 0) { + // Wait until [withAllConnections] is done. Otherwise we could spawn a new + // reader while the user is configuring all the connections, + // e.g. a global open factory configuration shared across all connections. + return; + } + var nextItem = _queue.removeFirst(); while (nextItem.completer.isCompleted) { // This item already timed out - try the next one if available @@ -232,6 +242,66 @@ class SqliteConnectionPool with SqliteQueries implements SqliteConnection { await connection.refreshSchema(); } } + + Future withAllConnections( + Future Function( + SqliteWriteContext writer, List readers) + block) async { + try { + _runningWithAllConnectionsCount++; + + final blockCompleter = Completer(); + final (write, reads) = await _lockAllConns(blockCompleter); + + try { + final res = await block(write, reads); + blockCompleter.complete(res); + return res; + } catch (e, st) { + blockCompleter.completeError(e, st); + rethrow; + } + } finally { + _runningWithAllConnectionsCount--; + + // Continue processing any pending read requests that may have been queued while + // the block was running. + Timer.run(_nextRead); + } + } + + /// Locks all connections, returning the acquired contexts. + /// We pass a completer that would be called after the locks are taken. + Future<(SqliteWriteContext, List)> _lockAllConns( + Completer lockCompleter) async { + final List> readLockedCompleters = []; + final Completer writeLockedCompleter = Completer(); + + // Take the write lock + writeLock((ctx) { + writeLockedCompleter.complete(ctx); + return lockCompleter.future; + }); + + // Take all the read locks + for (final readConn in _allReadConnections) { + final completer = Completer(); + readLockedCompleters.add(completer); + + readConn.readLock((ctx) { + completer.complete(ctx); + return lockCompleter.future; + }); + } + + // Wait after all locks are taken + final [writer as SqliteWriteContext, ...readers] = await Future.wait([ + writeLockedCompleter.future, + ...readLockedCompleters.map((e) => e.future) + ]); + + return (writer, readers); + } } typedef ReadCallback = Future Function(SqliteReadContext tx); diff --git a/packages/sqlite_async/lib/src/native/database/native_sqlite_database.dart b/packages/sqlite_async/lib/src/native/database/native_sqlite_database.dart index 7bea111..22cacf3 100644 --- a/packages/sqlite_async/lib/src/native/database/native_sqlite_database.dart +++ b/packages/sqlite_async/lib/src/native/database/native_sqlite_database.dart @@ -171,4 +171,12 @@ class SqliteDatabaseImpl Future refreshSchema() { return _pool.refreshSchema(); } + + @override + Future withAllConnections( + Future Function( + SqliteWriteContext writer, List readers) + block) { + return _pool.withAllConnections(block); + } } diff --git a/packages/sqlite_async/lib/src/web/database.dart b/packages/sqlite_async/lib/src/web/database.dart index cfaf987..f2dc998 100644 --- a/packages/sqlite_async/lib/src/web/database.dart +++ b/packages/sqlite_async/lib/src/web/database.dart @@ -171,6 +171,14 @@ class WebDatabase await isInitialized; return _database.fileSystem.flush(); } + + @override + Future withAllConnections( + Future Function( + SqliteWriteContext writer, List readers) + block) { + return writeLock((_) => block(this, [])); + } } final class _UnscopedContext extends UnscopedContext { diff --git a/packages/sqlite_async/lib/src/web/database/web_sqlite_database.dart b/packages/sqlite_async/lib/src/web/database/web_sqlite_database.dart index c6d1b75..69f01ab 100644 --- a/packages/sqlite_async/lib/src/web/database/web_sqlite_database.dart +++ b/packages/sqlite_async/lib/src/web/database/web_sqlite_database.dart @@ -178,4 +178,12 @@ class SqliteDatabaseImpl Future exposeEndpoint() async { return await _connection.exposeEndpoint(); } + + @override + Future withAllConnections( + Future Function( + SqliteWriteContext writer, List readers) + block) { + return writeLock((_) => block(_connection, [])); + } } diff --git a/packages/sqlite_async/test/basic_test.dart b/packages/sqlite_async/test/basic_test.dart index 6a315da..e2914b3 100644 --- a/packages/sqlite_async/test/basic_test.dart +++ b/packages/sqlite_async/test/basic_test.dart @@ -7,6 +7,7 @@ import 'utils/test_utils_impl.dart'; final testUtils = TestUtils(); const _isDart2Wasm = bool.fromEnvironment('dart.tool.dart2wasm'); +const _isWeb = identical(0, 0.0) || _isDart2Wasm; void main() { group('Shared Basic Tests', () { @@ -301,6 +302,49 @@ void main() { 'Web locks are managed with a shared worker, which does not support timeouts', ) }); + + test('with all connections', () async { + final maxReaders = _isWeb ? 0 : 3; + + final db = SqliteDatabase.withFactory( + await testUtils.testFactory(path: path), + maxReaders: maxReaders, + ); + await db.initialize(); + await createTables(db); + + // Warm up to spawn the max readers + await Future.wait([for (var i = 0; i < 10; i++) db.get('SELECT $i')]); + + bool finishedWithAllConns = false; + + late Future readsCalledWhileWithAllConnsRunning; + + final parentZone = Zone.current; + await db.withAllConnections((writer, readers) async { + expect(readers.length, maxReaders); + + // Run some reads during the block that they should run after the block finishes and releases + // all locks + // Need a root zone here to avoid recursive lock errors. + readsCalledWhileWithAllConnsRunning = + Future(parentZone.bindCallback(() async { + await Future.wait( + [1, 2, 3, 4, 5, 6, 7, 8].map((i) async { + await db.readLock((c) async { + expect(finishedWithAllConns, isTrue); + await Future.delayed(const Duration(milliseconds: 100)); + }); + }), + ); + })); + + await Future.delayed(const Duration(milliseconds: 200)); + finishedWithAllConns = true; + }); + + await readsCalledWhileWithAllConnsRunning; + }); }); } diff --git a/packages/sqlite_async/test/native/basic_test.dart b/packages/sqlite_async/test/native/basic_test.dart index dec1fed..3f348e6 100644 --- a/packages/sqlite_async/test/native/basic_test.dart +++ b/packages/sqlite_async/test/native/basic_test.dart @@ -2,12 +2,16 @@ library; import 'dart:async'; +import 'dart:io'; import 'dart:math'; +import 'package:collection/collection.dart'; +import 'package:path/path.dart' show join; import 'package:sqlite3/common.dart' as sqlite; import 'package:sqlite_async/sqlite_async.dart'; import 'package:test/test.dart'; +import '../utils/abstract_test_utils.dart'; import '../utils/test_utils_impl.dart'; final testUtils = TestUtils(); @@ -100,6 +104,126 @@ void main() { print("${DateTime.now()} done"); }); + test('prevent opening new readers while in withAllConnections', () async { + final sharedStateDir = Directory.systemTemp.createTempSync(); + addTearDown(() => sharedStateDir.deleteSync(recursive: true)); + + final File sharedStateFile = + File(join(sharedStateDir.path, 'shared-state.txt')); + + sharedStateFile.writeAsStringSync('initial'); + + final db = SqliteDatabase.withFactory( + _TestSqliteOpenFactoryWithSharedStateFile( + path: path, sharedStateFilePath: sharedStateFile.path), + maxReaders: 3); + await db.initialize(); + await createTables(db); + + // The writer saw 'initial' in the file when opening the connection + expect( + await db + .writeLock((c) => c.get('SELECT file_contents_on_open() AS state')), + {'state': 'initial'}, + ); + + final withAllConnectionsCompleter = Completer(); + + final withAllConnsFut = db.withAllConnections((writer, readers) async { + expect(readers.length, 0); // No readers yet + + // Simulate some work until the file is updated + await Future.delayed(const Duration(milliseconds: 200)); + sharedStateFile.writeAsStringSync('updated'); + + await withAllConnectionsCompleter.future; + }); + + // Start a reader that gets the contents of the shared file + bool readFinished = false; + final someReadFut = + db.get('SELECT file_contents_on_open() AS state', []).then((r) { + readFinished = true; + return r; + }); + + // The withAllConnections should prevent the reader from opening + await Future.delayed(const Duration(milliseconds: 100)); + expect(readFinished, isFalse); + + // Free all the locks + withAllConnectionsCompleter.complete(); + await withAllConnsFut; + + final readerInfo = await someReadFut; + expect(readFinished, isTrue); + // The read should see the updated value in the file. This checks + // that a reader doesn't spawn while running withAllConnections + expect(readerInfo, {'state': 'updated'}); + }); + + test('with all connections', () async { + final maxReaders = 3; + + final db = SqliteDatabase.withFactory( + await testUtils.testFactory(path: path), + maxReaders: maxReaders, + ); + await db.initialize(); + await createTables(db); + + Future readWithRandomDelay( + SqliteReadContext ctx, int id) async { + return await ctx.get( + 'SELECT ? as i, test_sleep(?) as sleep, test_connection_name() as connection', + [id, 5 + Random().nextInt(10)]); + } + + // Warm up to spawn the max readers + await Future.wait( + [1, 2, 3, 4, 5, 6, 7, 8].map((i) => readWithRandomDelay(db, i)), + ); + + bool finishedWithAllConns = false; + + late Future readsCalledWhileWithAllConnsRunning; + + print("${DateTime.now()} start"); + await db.withAllConnections((writer, readers) async { + expect(readers.length, maxReaders); + + // Run some reads during the block that they should run after the block finishes and releases + // all locks + readsCalledWhileWithAllConnsRunning = Future.wait( + [1, 2, 3, 4, 5, 6, 7, 8].map((i) async { + final r = await db.readLock((c) async { + expect(finishedWithAllConns, isTrue); + return await readWithRandomDelay(c, i); + }); + print( + "${DateTime.now()} After withAllConnections, started while running $r"); + }), + ); + + await Future.wait([ + writer.execute( + "INSERT OR REPLACE INTO test_data(id, description) SELECT ? as i, test_sleep(?) || ' ' || test_connection_name() || ' 1 ' || datetime() as connection RETURNING *", + [ + 123, + 5 + Random().nextInt(20) + ]).then((value) => + print("${DateTime.now()} withAllConnections writer done $value")), + ...readers + .mapIndexed((i, r) => readWithRandomDelay(r, i).then((results) { + print( + "${DateTime.now()} withAllConnections readers done $results"); + })) + ]); + }).then((_) => finishedWithAllConns = true); + + await readsCalledWhileWithAllConnsRunning; + }); + test('read-only transactions', () async { final db = await testUtils.setupDatabase(path: path); await createTables(db); @@ -379,3 +503,31 @@ class _InvalidPragmaOnOpenFactory extends DefaultSqliteOpenFactory { ]; } } + +class _TestSqliteOpenFactoryWithSharedStateFile + extends TestDefaultSqliteOpenFactory { + final String sharedStateFilePath; + + _TestSqliteOpenFactoryWithSharedStateFile( + {required super.path, required this.sharedStateFilePath}); + + @override + sqlite.CommonDatabase open(SqliteOpenOptions options) { + final File sharedStateFile = File(sharedStateFilePath); + final String sharedState = sharedStateFile.readAsStringSync(); + + final db = super.open(options); + + // Function to return the contents of the shared state file at the time of opening + // so that we know at which point the factory was called. + db.createFunction( + functionName: 'file_contents_on_open', + argumentCount: const sqlite.AllowedArgumentCount(0), + function: (args) { + return sharedState; + }, + ); + + return db; + } +}