// // Copyright (c) 2012 Krueger Systems, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
//
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;

namespace SQLite
{
    /// <summary>
    /// Task-returning wrapper over a pooled SQLite connection. Every operation runs on
    /// the default (thread-pool) task scheduler while holding the pooled connection's
    /// lock, so operations from all instances that map to the same pooled connection
    /// are serialized.
    /// </summary>
    public partial class SQLiteAsyncConnection
    {
        SQLiteConnectionString _connectionString;
        SQLiteOpenFlags _openFlags;

        /// <summary>
        /// Opens (via the shared pool) <paramref name="databasePath"/> with
        /// <c>ReadWrite | Create</c> flags.
        /// </summary>
        public SQLiteAsyncConnection (string databasePath, bool storeDateTimeAsTicks = false)
            : this (databasePath, SQLiteOpenFlags.ReadWrite | SQLiteOpenFlags.Create, storeDateTimeAsTicks)
        {
        }

        /// <summary>
        /// Records the path and flags; the underlying connection is obtained lazily from
        /// <see cref="SQLiteConnectionPool.Shared"/> on each operation.
        /// </summary>
        public SQLiteAsyncConnection (string databasePath, SQLiteOpenFlags openFlags, bool storeDateTimeAsTicks = false)
        {
            _openFlags = openFlags;
            _connectionString = new SQLiteConnectionString (databasePath, storeDateTimeAsTicks);
        }

        SQLiteConnectionWithLock GetConnection ()
        {
            return SQLiteConnectionPool.Shared.GetConnection (_connectionString, _openFlags);
        }

        /// <summary>
        /// Runs <paramref name="work"/> on a thread-pool thread while holding the pooled
        /// connection's lock, and returns its result through the task.
        /// </summary>
        /// <remarks>
        /// TaskScheduler.Default is passed explicitly: the short StartNew overloads use
        /// TaskScheduler.Current, which would run the database work on a UI/custom
        /// scheduler when this method is called from a task running on one.
        /// </remarks>
        Task<TResult> RunLockedAsync<TResult> (Func<SQLiteConnectionWithLock, TResult> work)
        {
            return Task.Factory.StartNew (() => {
                var conn = GetConnection ();
                using (conn.Lock ()) {
                    return work (conn);
                }
            }, CancellationToken.None, TaskCreationOptions.None, TaskScheduler.Default);
        }

        /// <summary>
        /// Void counterpart of <see cref="RunLockedAsync{TResult}"/>.
        /// </summary>
        Task RunLockedAsync (Action<SQLiteConnectionWithLock> work)
        {
            return Task.Factory.StartNew (() => {
                var conn = GetConnection ();
                using (conn.Lock ()) {
                    work (conn);
                }
            }, CancellationToken.None, TaskCreationOptions.None, TaskScheduler.Default);
        }

        /// <summary>
        /// Wraps <paramref name="body"/> in BeginTransaction/Commit on
        /// <paramref name="conn"/>; if the body or the commit throws, rolls back and
        /// rethrows the original exception.
        /// </summary>
        static void CommitOrRollback (SQLiteConnectionWithLock conn, Action body)
        {
            conn.BeginTransaction ();
            try {
                body ();
                conn.Commit ();
            } catch (Exception) {
                conn.Rollback ();
                throw;
            }
        }

        public Task<CreateTablesResult> CreateTableAsync<T> ()
            where T : new ()
        {
            return CreateTablesAsync (typeof (T));
        }

        public Task<CreateTablesResult> CreateTablesAsync<T, T2> ()
            where T : new ()
            where T2 : new ()
        {
            return CreateTablesAsync (typeof (T), typeof (T2));
        }

        public Task<CreateTablesResult> CreateTablesAsync<T, T2, T3> ()
            where T : new ()
            where T2 : new ()
            where T3 : new ()
        {
            return CreateTablesAsync (typeof (T), typeof (T2), typeof (T3));
        }

        public Task<CreateTablesResult> CreateTablesAsync<T, T2, T3, T4> ()
            where T : new ()
            where T2 : new ()
            where T3 : new ()
            where T4 : new ()
        {
            return CreateTablesAsync (typeof (T), typeof (T2), typeof (T3), typeof (T4));
        }

        public Task<CreateTablesResult> CreateTablesAsync<T, T2, T3, T4, T5> ()
            where T : new ()
            where T2 : new ()
            where T3 : new ()
            where T4 : new ()
            where T5 : new ()
        {
            return CreateTablesAsync (typeof (T), typeof (T2), typeof (T3), typeof (T4), typeof (T5));
        }

        /// <summary>
        /// Calls CreateTable for each type, in order, under a single lock acquisition and
        /// records each returned value in <see cref="CreateTablesResult.Results"/> keyed by type.
        /// </summary>
        public Task<CreateTablesResult> CreateTablesAsync (params Type[] types)
        {
            return RunLockedAsync (conn => {
                var result = new CreateTablesResult ();
                foreach (Type type in types) {
                    int aResult = conn.CreateTable (type);
                    result.Results[type] = aResult;
                }
                return result;
            });
        }

        public Task<int> DropTableAsync<T> ()
            where T : new ()
        {
            return RunLockedAsync (conn => conn.DropTable<T> ());
        }

        /// <summary>Task result is whatever the synchronous Insert returns.</summary>
        public Task<int> InsertAsync (object item)
        {
            return RunLockedAsync (conn => conn.Insert (item));
        }

        /// <summary>Task result is whatever the synchronous Update returns.</summary>
        public Task<int> UpdateAsync (object item)
        {
            return RunLockedAsync (conn => conn.Update (item));
        }

        /// <summary>Task result is whatever the synchronous Delete returns.</summary>
        public Task<int> DeleteAsync (object item)
        {
            return RunLockedAsync (conn => conn.Delete (item));
        }

        public Task<T> GetAsync<T> (object pk)
            where T : new ()
        {
            return RunLockedAsync (conn => conn.Get<T> (pk));
        }

        public Task<T> FindAsync<T> (object pk)
            where T : new ()
        {
            return RunLockedAsync (conn => conn.Find<T> (pk));
        }

        public Task<T> GetAsync<T> (Expression<Func<T, bool>> predicate)
            where T : new ()
        {
            return RunLockedAsync (conn => conn.Get<T> (predicate));
        }

        public Task<T> FindAsync<T> (Expression<Func<T, bool>> predicate)
            where T : new ()
        {
            return RunLockedAsync (conn => conn.Find<T> (predicate));
        }

        public Task<int> ExecuteAsync (string query, params object[] args)
        {
            return RunLockedAsync (conn => conn.Execute (query, args));
        }

        public Task<int> InsertAllAsync (IEnumerable items)
        {
            return RunLockedAsync (conn => conn.InsertAll (items));
        }

        public Task<int> UpdateAllAsync (IEnumerable items)
        {
            return RunLockedAsync (conn => conn.UpdateAll (items));
        }

        /// <summary>
        /// Runs <paramref name="action"/> inside a transaction while the connection lock
        /// is held; the action receives this async wrapper.
        /// </summary>
        [Obsolete ("Will cause a deadlock if any call in action ends up in a different thread. Use RunInTransactionAsync(Action<SQLiteConnection>) instead.")]
        public Task RunInTransactionAsync (Action<SQLiteAsyncConnection> action)
        {
            return RunLockedAsync (conn => CommitOrRollback (conn, () => action (this)));
        }

        /// <summary>
        /// Runs <paramref name="action"/> inside a transaction while the connection lock
        /// is held; the action receives the locked synchronous connection and should use
        /// it directly rather than the async API.
        /// </summary>
        public Task RunInTransactionAsync (Action<SQLiteConnection> action)
        {
            return RunLockedAsync (conn => CommitOrRollback (conn, () => action (conn)));
        }

        public AsyncTableQuery<T> Table<T> ()
            where T : new ()
        {
            //
            // This isn't async as the underlying connection doesn't go out to the database
            // until the query is performed. The Async methods are on the query itself.
            //
            var conn = GetConnection ();
            return new AsyncTableQuery<T> (conn.Table<T> ());
        }

        public Task<T> ExecuteScalarAsync<T> (string sql, params object[] args)
        {
            return RunLockedAsync (conn => {
                var command = conn.CreateCommand (sql, args);
                return command.ExecuteScalar<T> ();
            });
        }

        public Task<List<T>> QueryAsync<T> (string sql, params object[] args)
            where T : new ()
        {
            return RunLockedAsync (conn => conn.Query<T> (sql, args));
        }
    }

    //
    // TODO: Bind to SQLiteAsyncConnection.GetConnection instead so that delayed
    // execution can still work after a Pool.Reset.
    //
    /// <summary>
    /// Composable query whose terminal *Async methods execute on the default task
    /// scheduler under the lock of the connection the inner query was built from.
    /// </summary>
    public class AsyncTableQuery<T>
        where T : new ()
    {
        readonly TableQuery<T> _innerQuery;

        public AsyncTableQuery (TableQuery<T> innerQuery)
        {
            _innerQuery = innerQuery;
        }

        /// <summary>
        /// Evaluates <paramref name="work"/> against the inner query on a thread-pool
        /// thread while holding the inner query's connection lock.
        /// </summary>
        Task<TResult> RunLockedAsync<TResult> (Func<TableQuery<T>, TResult> work)
        {
            return Task.Factory.StartNew (() => {
                using (((SQLiteConnectionWithLock)_innerQuery.Connection).Lock ()) {
                    return work (_innerQuery);
                }
            }, CancellationToken.None, TaskCreationOptions.None, TaskScheduler.Default);
        }

        public AsyncTableQuery<T> Where (Expression<Func<T, bool>> predExpr)
        {
            return new AsyncTableQuery<T> (_innerQuery.Where (predExpr));
        }

        public AsyncTableQuery<T> Skip (int n)
        {
            return new AsyncTableQuery<T> (_innerQuery.Skip (n));
        }

        public AsyncTableQuery<T> Take (int n)
        {
            return new AsyncTableQuery<T> (_innerQuery.Take (n));
        }

        public AsyncTableQuery<T> OrderBy<U> (Expression<Func<T, U>> orderExpr)
        {
            return new AsyncTableQuery<T> (_innerQuery.OrderBy<U> (orderExpr));
        }

        public AsyncTableQuery<T> OrderByDescending<U> (Expression<Func<T, U>> orderExpr)
        {
            return new AsyncTableQuery<T> (_innerQuery.OrderByDescending<U> (orderExpr));
        }

        public Task<List<T>> ToListAsync ()
        {
            return RunLockedAsync (q => q.ToList ());
        }

        public Task<int> CountAsync ()
        {
            return RunLockedAsync (q => q.Count ());
        }

        public Task<T> ElementAtAsync (int index)
        {
            return RunLockedAsync (q => q.ElementAt (index));
        }

        public Task<T> FirstAsync ()
        {
            return RunLockedAsync (q => q.First ());
        }

        public Task<T> FirstOrDefaultAsync ()
        {
            return RunLockedAsync (q => q.FirstOrDefault ());
        }
    }

    /// <summary>
    /// Per-type values returned by CreateTable during a CreateTablesAsync call.
    /// </summary>
    public class CreateTablesResult
    {
        public Dictionary<Type, int> Results { get; private set; }

        internal CreateTablesResult ()
        {
            this.Results = new Dictionary<Type, int> ();
        }
    }

    /// <summary>
    /// Process-wide cache holding one locked connection per connection-string key.
    /// </summary>
    class SQLiteConnectionPool
    {
        class Entry
        {
            public SQLiteConnectionString ConnectionString { get; private set; }
            public SQLiteConnectionWithLock Connection { get; private set; }

            public Entry (SQLiteConnectionString connectionString, SQLiteOpenFlags openFlags)
            {
                ConnectionString = connectionString;
                Connection = new SQLiteConnectionWithLock (connectionString, openFlags);
            }

            /// <summary>Disposes the connection and drops the reference to it.</summary>
            public void OnApplicationSuspended ()
            {
                // NOTE(review): the connection's Lock() is not taken here, so an
                // operation currently running on this connection may see it disposed.
                Connection.Dispose ();
                Connection = null;
            }
        }

        readonly Dictionary<string, Entry> _entries = new Dictionary<string, Entry> ();
        readonly object _entriesLock = new object ();

        static readonly SQLiteConnectionPool _shared = new SQLiteConnectionPool ();

        /// <summary>
        /// Gets the singleton instance of the connection pool.
        /// </summary>
        public static SQLiteConnectionPool Shared {
            get {
                return _shared;
            }
        }

        /// <summary>
        /// Returns the pooled connection for <paramref name="connectionString"/>,
        /// creating it on first use.
        /// </summary>
        /// <remarks>
        /// <paramref name="openFlags"/> is only used when the entry is created; a later
        /// call with the same key but different flags gets the existing connection.
        /// </remarks>
        public SQLiteConnectionWithLock GetConnection (SQLiteConnectionString connectionString, SQLiteOpenFlags openFlags)
        {
            lock (_entriesLock) {
                Entry entry;
                string key = connectionString.ConnectionString;

                if (!_entries.TryGetValue (key, out entry)) {
                    entry = new Entry (connectionString, openFlags);
                    _entries[key] = entry;
                }

                return entry.Connection;
            }
        }

        /// <summary>
        /// Closes all connections managed by this pool and empties it; subsequent
        /// GetConnection calls open fresh connections. Existing AsyncTableQuery objects
        /// still reference the old (disposed) connections.
        /// </summary>
        public void Reset ()
        {
            lock (_entriesLock) {
                foreach (var entry in _entries.Values) {
                    entry.OnApplicationSuspended ();
                }
                _entries.Clear ();
            }
        }

        /// <summary>
        /// Call this method when the application is suspended.
        /// </summary>
        /// <remarks>Behaviour here is to close any open connections.</remarks>
        public void ApplicationSuspended ()
        {
            Reset ();
        }
    }

    /// <summary>
    /// SQLiteConnection paired with a private monitor used to serialize access.
    /// </summary>
    class SQLiteConnectionWithLock : SQLiteConnection
    {
        readonly object _lockPoint = new object ();

        public SQLiteConnectionWithLock (SQLiteConnectionString connectionString, SQLiteOpenFlags openFlags)
            : base (connectionString.DatabasePath, openFlags, connectionString.StoreDateTimeAsTicks)
        {
        }

        /// <summary>
        /// Enters this connection's monitor (blocking until available); disposing the
        /// returned object exits it. Monitor entry is reentrant on the same thread.
        /// </summary>
        public IDisposable Lock ()
        {
            return new LockWrapper (_lockPoint);
        }

        private class LockWrapper : IDisposable
        {
            readonly object _lockPoint;

            public LockWrapper (object lockPoint)
            {
                _lockPoint = lockPoint;
                Monitor.Enter (_lockPoint);
            }

            public void Dispose ()
            {
                Monitor.Exit (_lockPoint);
            }
        }
    }
}