From 4f241917b600aa67c637329ac7713f59a908a30f Mon Sep 17 00:00:00 2001 From: Master Kwoth Date: Sun, 7 Feb 2016 21:51:05 +0100 Subject: [PATCH] Holy SQL. --- NadekoBot/Classes/DBHandler.cs | 68 + NadekoBot/Classes/NadekoStats.cs | 48 +- .../Classes/_DataModels/AnnouncementModel.cs | 21 + NadekoBot/Classes/_DataModels/CommandModel.cs | 19 + NadekoBot/Classes/_DataModels/IDataModel.cs | 14 + NadekoBot/Classes/_DataModels/RequestModel.cs | 17 + NadekoBot/Classes/_DataModels/StatsModel.cs | 16 + .../Classes/_DataModels/TypingArticleModel.cs | 13 + NadekoBot/Commands/RequestsCommand.cs | 68 +- NadekoBot/Commands/ServerGreetCommand.cs | 107 +- NadekoBot/Commands/SpeedTyping.cs | 12 +- NadekoBot/NadekoBot.cs | 4 +- NadekoBot/NadekoBot.csproj | 15 +- NadekoBot/SQLite.cs | 3278 +++++++++++++++++ NadekoBot/SQLiteAsync.cs | 503 +++ NadekoBot/packages.config | 1 + 16 files changed, 4048 insertions(+), 156 deletions(-) create mode 100644 NadekoBot/Classes/DBHandler.cs create mode 100644 NadekoBot/Classes/_DataModels/AnnouncementModel.cs create mode 100644 NadekoBot/Classes/_DataModels/CommandModel.cs create mode 100644 NadekoBot/Classes/_DataModels/IDataModel.cs create mode 100644 NadekoBot/Classes/_DataModels/RequestModel.cs create mode 100644 NadekoBot/Classes/_DataModels/StatsModel.cs create mode 100644 NadekoBot/Classes/_DataModels/TypingArticleModel.cs create mode 100644 NadekoBot/SQLite.cs create mode 100644 NadekoBot/SQLiteAsync.cs diff --git a/NadekoBot/Classes/DBHandler.cs b/NadekoBot/Classes/DBHandler.cs new file mode 100644 index 00000000..91812e8e --- /dev/null +++ b/NadekoBot/Classes/DBHandler.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using System.IO; +using SQLite; +using NadekoBot.Classes._DataModels; + +namespace NadekoBot.Classes { + class DBHandler { + private static readonly DBHandler _instance = new DBHandler(); + public static DBHandler Instance => _instance; + + private string _filePath { get; } = "data/nadekobot.sqlite"; + + static DBHandler() { } + public DBHandler() { + using (var _conn = new SQLiteConnection(_filePath)) { + _conn.CreateTable(); + _conn.CreateTable(); + _conn.CreateTable(); + _conn.CreateTable(); + _conn.CreateTable(); + } + } + + internal void InsertData(T o) where T : IDataModel { + using (var _conn = new SQLiteConnection(_filePath)) { + _conn.Insert(o, typeof(T)); + } + } + + internal void UpdateData(T o) where T : IDataModel { + using (var _conn = new SQLiteConnection(_filePath)) { + _conn.Update(o, typeof(T)); + } + } + + internal List GetAllRows() where T : IDataModel, new() { + using (var _conn = new SQLiteConnection(_filePath)) { + return _conn.Table().Where(t => true).ToList(); + } + } + + internal T Delete(int Id) where T : IDataModel, new() { + using (var _conn = new SQLiteConnection(_filePath)) { + var found = _conn.Table().Where(t => t.Id == Id).FirstOrDefault(); + if (found != null) + _conn.Delete(found.Id); + return found; + } + } + + /// + /// Updates an existing object or creates a new one + /// + internal void Save(T o) where T : IDataModel, new() { + using (var _conn = new SQLiteConnection(_filePath)) { + var found = _conn.Table().Where(t => t.Id == o.Id).FirstOrDefault(); + if (found == null) + _conn.Insert(o, typeof(T)); + else + _conn.Update(o, typeof(T)); + } + } + } +} diff --git a/NadekoBot/Classes/NadekoStats.cs b/NadekoBot/Classes/NadekoStats.cs index 9ccca97e..521e6d46 100644 --- a/NadekoBot/Classes/NadekoStats.cs +++ b/NadekoBot/Classes/NadekoStats.cs @@ -40,6 +40,9 @@ namespace NadekoBot Console.WriteLine("Logging enabled."); } + public TimeSpan GetUptime() => + DateTime.Now - Process.GetCurrentProcess().StartTime; + public string GetUptimeString() { var time = (DateTime.Now - Process.GetCurrentProcess().StartTime); return time.Days + " days, " + time.Hours + " hours, and " + time.Minutes + " minutes."; @@ -72,18 +75,22 @@ namespace NadekoBot private async Task StartCollecting() { while (true) { - await Task.Delay(new TimeSpan(1, 0, 0)); + await Task.Delay(new TimeSpan(0, 30, 0)); try { - var obj = new ParseObject("Stats"); - obj["OnlineUsers"] = await Task.Run(() => NadekoBot.client.Servers.Sum(x => x.Users.Count())); - obj["RealOnlineUsers"] = await Task.Run(() => NadekoBot - .client.Servers + var onlineUsers = await Task.Run(() => NadekoBot.client.Servers.Sum(x => x.Users.Count())); + var realOnlineUsers = await Task.Run(() => NadekoBot.client.Servers .Sum(x => x.Users.Where(u => u.Status == UserStatus.Online).Count())); - obj["ConnectedServers"] = NadekoBot.client.Servers.Count(); + var connectedServers = NadekoBot.client.Servers.Count(); - await obj.SaveAsync(); + Classes.DBHandler.Instance.InsertData(new Classes._DataModels.Stats { + OnlineUsers = onlineUsers, + RealOnlineUsers = realOnlineUsers, + Uptime = GetUptime(), + ConnectedServers = connectedServers, + DateAdded = DateTime.Now + }); } catch (Exception) { - Console.WriteLine("Parse exception in StartCollecting"); + Console.WriteLine("DB Exception in stats collecting."); break; } } @@ -93,22 +100,19 @@ namespace NadekoBot { try { _commandsRan++; - var obj = new ParseObject("CommandsRan"); - - obj["ServerId"] = e.Server.Id; - obj["ServerName"] = e.Server.Name; - - obj["ChannelId"] = e.Channel.Id; - obj["ChannelName"] = e.Channel.Name; - - obj["UserId"] = e.User.Id; - obj["UserName"] = e.User.Name; - - obj["CommandName"] = e.Command.Text; - obj.SaveAsync(); + Classes.DBHandler.Instance.InsertData(new Classes._DataModels.Command { + ServerId = (long)e.Server.Id, + ServerName = e.Server.Name, + ChannelId = (long)e.Channel.Id, + ChannelName =e.Channel.Name, + UserId = (long)e.User.Id, + UserName = e.User.Name, + CommandName = e.Command.Text, + DateAdded = DateTime.Now + }); } catch (Exception) { Console.WriteLine("Parse error in ran command."); } } } -} +} \ No newline at end of file diff --git a/NadekoBot/Classes/_DataModels/AnnouncementModel.cs b/NadekoBot/Classes/_DataModels/AnnouncementModel.cs new file mode 100644 index 00000000..25db6842 --- /dev/null +++ b/NadekoBot/Classes/_DataModels/AnnouncementModel.cs @@ -0,0 +1,21 @@ +using SQLite; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace NadekoBot.Classes._DataModels { + class Announcement : IDataModel { + public long ServerId { get; set; } = 0; + public bool Greet { get; set; } = false; + public bool GreetPM { get; set; } = false; + public long GreetChannelId { get; set; } = 0; + public string GreetText { get; set; } = "Welcome %user%!"; + public bool Bye { get; set; } = false; + public bool ByePM { get; set; } = false; + public long ByeChannelId { get; set; } = 0; + public string ByeText { get; set; } = "%user% has left the server."; + public DateTime DateAdded { get; set; } = DateTime.Now; + } +} diff --git a/NadekoBot/Classes/_DataModels/CommandModel.cs b/NadekoBot/Classes/_DataModels/CommandModel.cs new file mode 100644 index 00000000..d035a7e9 --- /dev/null +++ b/NadekoBot/Classes/_DataModels/CommandModel.cs @@ -0,0 +1,19 @@ +using SQLite; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace NadekoBot.Classes._DataModels { + class Command : IDataModel { + public long UserId { get; set; } + public string UserName { get; set; } + public long ServerId { get; set; } + public string ServerName { get; set; } + public long ChannelId { get; set; } + public string ChannelName { get; set; } + public string CommandName { get; set; } + public DateTime DateAdded { get; set; } + } +} diff --git a/NadekoBot/Classes/_DataModels/IDataModel.cs b/NadekoBot/Classes/_DataModels/IDataModel.cs new file mode 100644 index 00000000..891cc04e --- /dev/null +++ b/NadekoBot/Classes/_DataModels/IDataModel.cs @@ -0,0 +1,14 @@ +using SQLite; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace NadekoBot.Classes._DataModels { + class IDataModel { + [PrimaryKey, AutoIncrement] + public int Id { get; set; } + public IDataModel() { } + } +} diff --git a/NadekoBot/Classes/_DataModels/RequestModel.cs b/NadekoBot/Classes/_DataModels/RequestModel.cs new file mode 100644 index 00000000..3f70b868 --- /dev/null +++ b/NadekoBot/Classes/_DataModels/RequestModel.cs @@ -0,0 +1,17 @@ +using SQLite; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace NadekoBot.Classes._DataModels { + class Request : IDataModel { + public string UserName { get; set; } + public long UserId { get; set; } + public string ServerName { get; set; } + public long ServerId { get; set; } + public string RequestText { get; set; } + public DateTime DateAdded { get; set; } + } +} diff --git a/NadekoBot/Classes/_DataModels/StatsModel.cs b/NadekoBot/Classes/_DataModels/StatsModel.cs new file mode 100644 index 00000000..e7e3869e --- /dev/null +++ b/NadekoBot/Classes/_DataModels/StatsModel.cs @@ -0,0 +1,16 @@ +using SQLite; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace NadekoBot.Classes._DataModels { + class Stats : IDataModel { + public int ConnectedServers { get; set; } + public int OnlineUsers { get; set; } + public TimeSpan Uptime { get; set; } + public int RealOnlineUsers { get; set; } + public DateTime DateAdded { get; set; } + } +} diff --git a/NadekoBot/Classes/_DataModels/TypingArticleModel.cs b/NadekoBot/Classes/_DataModels/TypingArticleModel.cs new file mode 100644 index 00000000..385d682f --- /dev/null +++ b/NadekoBot/Classes/_DataModels/TypingArticleModel.cs @@ -0,0 +1,13 @@ +using SQLite; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace NadekoBot.Classes._DataModels { + class TypingArticle : IDataModel { + public string Text { get; set; } + public DateTime DateAdded { get; set; } + } +} diff --git a/NadekoBot/Commands/RequestsCommand.cs b/NadekoBot/Commands/RequestsCommand.cs index 321260eb..55275b52 100644 --- a/NadekoBot/Commands/RequestsCommand.cs +++ b/NadekoBot/Commands/RequestsCommand.cs @@ -1,76 +1,42 @@ using System; using System.Threading.Tasks; using Discord.Commands; -using Parse; using NadekoBot.Extensions; namespace NadekoBot.Commands { class RequestsCommand : DiscordCommand { public void SaveRequest(CommandEventArgs e, string text) { - - var obj = new ParseObject("Requests"); - - obj["ServerId"] = e.Server.Id; - obj["ServerName"] = e.Server.Name; - obj["UserId"] = e.User.Id; - obj["UserName"] = e.User.Name; - obj["Request"] = text; - - obj.SaveAsync(); + Classes.DBHandler.Instance.InsertData(new Classes._DataModels.Request { + RequestText = text, + UserName = e.User.Name, + UserId = (long)e.User.Id, + ServerId = (long)e.Server.Id, + ServerName = e.Server.Name, + DateAdded = DateTime.Now + }); } // todo what if it's too long? public string GetRequests() { - var task = ParseObject.GetQuery("Requests") - .FindAsync().Result; + var task = Classes.DBHandler.Instance.GetAllRows(); string str = "Here are all current requests for NadekoBot:\n\n"; int i = 1; foreach (var reqObj in task) { - str += (i++) + ". by **" + reqObj["UserName"] + "** from **" + reqObj["ServerName"] + "** at " + reqObj.CreatedAt.Value.ToLocalTime() + "\n"; - str += "**" + reqObj["Request"] + "**\n----------\n"; + str += $"{reqObj.Id}. by **{reqObj.UserName}** from **{reqObj.ServerName}** at {reqObj.DateAdded.ToLocalTime()}\n" + + $"**{reqObj.RequestText}**\n----------\n"; } return str + "\n__Type [@NadekoBot clr] to clear all of my messages.__"; } - public bool DeleteRequest(int requestNumber) { - var task = ParseObject.GetQuery("Requests") - .FindAsync().Result; - int i = 1; - foreach (var reqObj in task) { - if (i == requestNumber) { - reqObj.DeleteAsync(); - return true; - } - i++; - } - return false; - } + public bool DeleteRequest(int requestNumber) => + Classes.DBHandler.Instance.Delete(requestNumber) != null; - public class ResolveRequestObject { - public ulong Id; - public ulong ServerId; - public string Text; - } /// - /// Resolves a request with a number and returns that users id. + /// Delete a request with a number and returns that request object. /// /// RequestObject of the request. Null if none - public ResolveRequestObject ResolveRequest(int requestNumber) { - var task = ParseObject.GetQuery("Requests") - .FindAsync().Result; - int i = 1; - foreach (var reqObj in task) { - if (i == requestNumber) { - var txt = reqObj.Get("Request"); - var id = reqObj.Get("UserId"); - var sid = reqObj.Get("ServerId"); - reqObj.DeleteAsync(); - return new ResolveRequestObject { Id = id, Text = txt, ServerId = sid }; - } - i++; - } - return null; - } + public Classes._DataModels.Request ResolveRequest(int requestNumber) => + Classes.DBHandler.Instance.Delete(requestNumber); public override Func DoFunc() { throw new NotImplementedException(); @@ -130,7 +96,7 @@ namespace NadekoBot.Commands { var sc = ResolveRequest(int.Parse(e.Args[0])); if (sc != null) { await e.Send(e.User.Mention + " Request resolved, notice sent."); - await client.GetServer(sc.ServerId).GetUser(sc.Id).Send("**This request of yours has been resolved:**\n" + sc.Text); + await client.GetServer((ulong)sc.ServerId).GetUser((ulong)sc.UserId).Send("**This request of yours has been resolved:**\n" + sc.RequestText); } else { await e.Send("No request on that number."); } diff --git a/NadekoBot/Commands/ServerGreetCommand.cs b/NadekoBot/Commands/ServerGreetCommand.cs index e7a71fd5..d249f0d1 100644 --- a/NadekoBot/Commands/ServerGreetCommand.cs +++ b/NadekoBot/Commands/ServerGreetCommand.cs @@ -36,12 +36,11 @@ namespace NadekoBot.Commands { NadekoBot.client.UserJoined += UserJoined; NadekoBot.client.UserLeft += UserLeft; - var data = new ParseQuery("Announcements") - .FindAsync() - .Result; + List data = Classes.DBHandler.Instance.GetAllRows(); + if (data.Any()) - foreach (var po in data) - AnnouncementsDictionary.TryAdd(po.Get("serverId"), new AnnounceControls(po.Get("serverId")).Initialize(po)); + foreach (var obj in data) + AnnouncementsDictionary.TryAdd((ulong)obj.ServerId, new AnnounceControls(obj)); } private async void UserLeft(object sender, UserEventArgs e) { @@ -85,67 +84,59 @@ namespace NadekoBot.Commands { } public class AnnounceControls { - private ParseObject ParseObj = null; - - private bool greet; + private Classes._DataModels.Announcement _model { get; } public bool Greet { - get { return greet; } - set { greet = value; Save(); } + get { return _model.Greet; } + set { _model.Greet = value; Save(); } } - private ulong greetChannel; - public ulong GreetChannel { - get { return greetChannel; } - set { greetChannel = value; } + get { return (ulong)_model.GreetChannelId; } + set { _model.GreetChannelId = (long)value; Save(); } } - private bool greetPM; - public bool GreetPM { - get { return greetPM; } - set { greetPM = value; Save(); } + get { return _model.GreetPM; } + set { _model.GreetPM = value; Save(); } } - private bool byePM; - public bool ByePM { - get { return byePM; } - set { byePM = value; Save(); } + get { return _model.ByePM; } + set { _model.ByePM = value; Save(); } } - private string greetText = "Welcome to the server %user%"; public string GreetText { - get { return greetText; } - set { greetText = value; Save(); } + get { return _model.GreetText; } + set { _model.GreetText = value; Save(); } } - private bool bye; - public bool Bye { - get { return bye; } - set { bye = value; Save(); } + get { return _model.Bye; } + set { _model.Bye = value; Save(); } } - - private ulong byeChannel; - public ulong ByeChannel { - get { return byeChannel; } - set { byeChannel = value; } + get { return (ulong)_model.ByeChannelId; } + set { _model.ByeChannelId = (long)value; Save(); } } - private string byeText = "%user% has left the server"; public string ByeText { - get { return byeText; } - set { byeText = value; Save(); } + get { return _model.ByeText; } + set { _model.ByeText = value; Save(); } } - - public ulong ServerId { get; } + public ulong ServerId { + get { return (ulong)_model.ServerId; } + set { _model.ServerId = (long)value; } + } + + public AnnounceControls(Classes._DataModels.Announcement model) { + this._model = model; + } public AnnounceControls(ulong serverId) { - this.ServerId = serverId; + this._model = new Classes._DataModels.Announcement(); + ServerId = serverId; } internal bool ToggleBye(ulong id) { @@ -169,39 +160,7 @@ namespace NadekoBot.Commands { internal bool ToggleByePM() => ByePM = !ByePM; private void Save() { - ParseObject p = null; - if (this.ParseObj != null) - p = ParseObj; - else - p = ParseObj = new ParseObject("Announcements"); - p["greet"] = greet; - p["greetPM"] = greetPM; - p["greetText"] = greetText; - p["greetChannel"] = greetChannel; - - p["bye"] = bye; - p["byePM"] = byePM; - p["byeText"] = byeText; - p["byeChannel"] = byeChannel; - - p["serverId"] = ServerId; - - p.SaveAsync(); - } - - internal AnnounceControls Initialize(ParseObject po) { - greet = po.Get("greet"); - greetPM = po.ContainsKey("greetPM") ? po.Get("greetPM") : false; - greetText = po.Get("greetText"); - greetChannel = po.Get("greetChannel"); - - bye = po.Get("bye"); - byePM = po.ContainsKey("byePM") ? po.Get("byePM") : false; - byeText = po.Get("byeText"); - byeChannel = po.Get("byeChannel"); - - this.ParseObj = po; - return this; + Classes.DBHandler.Instance.Save(_model); } } @@ -268,7 +227,7 @@ namespace NadekoBot.Commands { AnnouncementsDictionary[e.Server.Id].ByeText = e.GetArg("msg"); await e.Send("New bye message set."); if (!AnnouncementsDictionary[e.Server.Id].Bye) - await e.Send("Enable bye messsages by typing `.bye`, and set the bye message using `.byemsg`"); + await e.Send("Enable bye messsages by typing `.bye`."); }); cgb.CreateCommand(".byepm") diff --git a/NadekoBot/Commands/SpeedTyping.cs b/NadekoBot/Commands/SpeedTyping.cs index 267b792f..98564331 100644 --- a/NadekoBot/Commands/SpeedTyping.cs +++ b/NadekoBot/Commands/SpeedTyping.cs @@ -151,13 +151,13 @@ namespace NadekoBot.Commands { cgb.CreateCommand("typeadd") .Description("Adds a new article to the typing contest. Owner only.") .Parameter("text",ParameterType.Unparsed) - .Do(async e => { - if (e.User.Id != NadekoBot.OwnerID || e.GetArg("text") == null) return; + .Do(e => { + if (e.User.Id != NadekoBot.OwnerID || string.IsNullOrWhiteSpace(e.GetArg("text"))) return; - var obj = new ParseObject("TypingArticles"); - obj["text"] = e.GetArg("text"); - await obj.SaveAsync(); - await e.Send("Added new typing article."); + Classes.DBHandler.Instance.InsertData(new Classes._DataModels.TypingArticle { + Text = e.GetArg("text"), + DateAdded = DateTime.Now + }); }); //todo add user submissions diff --git a/NadekoBot/NadekoBot.cs b/NadekoBot/NadekoBot.cs index 89fed16e..2bb396e5 100644 --- a/NadekoBot/NadekoBot.cs +++ b/NadekoBot/NadekoBot.cs @@ -173,4 +173,6 @@ namespace NadekoBot { } } } -} \ No newline at end of file +} + +//95520984584429568 meany \ No newline at end of file diff --git a/NadekoBot/NadekoBot.csproj b/NadekoBot/NadekoBot.csproj index b2729ef1..8bc35565 100644 --- a/NadekoBot/NadekoBot.csproj +++ b/NadekoBot/NadekoBot.csproj @@ -30,6 +30,8 @@ false true true + + AnyCPU @@ -40,6 +42,7 @@ DEBUG;TRACE prompt 4 + true AnyCPU @@ -127,16 +130,21 @@ True - - + + + + + + + @@ -153,6 +161,7 @@ + @@ -170,6 +179,8 @@ Resources.resx + + diff --git a/NadekoBot/SQLite.cs b/NadekoBot/SQLite.cs new file mode 100644 index 00000000..e74447f6 --- /dev/null +++ b/NadekoBot/SQLite.cs @@ -0,0 +1,3278 @@ +// +// Copyright (c) 2009-2012 Krueger Systems, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. +// +#if WINDOWS_PHONE && !USE_WP8_NATIVE_SQLITE +#define USE_CSHARP_SQLITE +#endif + +using System; +using System.Diagnostics; +using System.Runtime.InteropServices; +using System.Collections.Generic; +using System.Reflection; +using System.Linq; +using System.Linq.Expressions; +using System.Threading; + +#if USE_CSHARP_SQLITE +using Sqlite3 = Community.CsharpSqlite.Sqlite3; +using Sqlite3DatabaseHandle = Community.CsharpSqlite.Sqlite3.sqlite3; +using Sqlite3Statement = Community.CsharpSqlite.Sqlite3.Vdbe; +#elif USE_WP8_NATIVE_SQLITE +using Sqlite3 = Sqlite.Sqlite3; +using Sqlite3DatabaseHandle = Sqlite.Database; +using Sqlite3Statement = Sqlite.Statement; +#else +using Sqlite3DatabaseHandle = System.IntPtr; +using Sqlite3Statement = System.IntPtr; +#endif + +namespace SQLite +{ + public class SQLiteException : Exception + { + public SQLite3.Result Result { get; private set; } + + protected SQLiteException (SQLite3.Result r,string message) : base(message) + { + Result = r; + } + + public static SQLiteException New (SQLite3.Result r, string message) + { + return new SQLiteException (r, message); + } + } + + public class NotNullConstraintViolationException : SQLiteException + { + public IEnumerable Columns { get; protected set; } + + protected NotNullConstraintViolationException (SQLite3.Result r, string message) + : this (r, message, null, null) + { + + } + + protected NotNullConstraintViolationException (SQLite3.Result r, string message, TableMapping mapping, object obj) + : base (r, message) + { + if (mapping != null && obj != null) { + this.Columns = from c in mapping.Columns + where c.IsNullable == false && c.GetValue (obj) == null + select c; + } + } + + public static new NotNullConstraintViolationException New (SQLite3.Result r, string message) + { + return new NotNullConstraintViolationException (r, message); + } + + public static NotNullConstraintViolationException New (SQLite3.Result r, string message, TableMapping mapping, object obj) + { + return new NotNullConstraintViolationException (r, message, mapping, obj); + } + + public static NotNullConstraintViolationException New (SQLiteException exception, TableMapping mapping, object obj) + { + return new NotNullConstraintViolationException (exception.Result, exception.Message, mapping, obj); + } + } + + [Flags] + public enum SQLiteOpenFlags { + ReadOnly = 1, ReadWrite = 2, Create = 4, + NoMutex = 0x8000, FullMutex = 0x10000, + SharedCache = 0x20000, PrivateCache = 0x40000, + ProtectionComplete = 0x00100000, + ProtectionCompleteUnlessOpen = 0x00200000, + ProtectionCompleteUntilFirstUserAuthentication = 0x00300000, + ProtectionNone = 0x00400000 + } + + [Flags] + public enum CreateFlags + { + None = 0, + ImplicitPK = 1, // create a primary key for field called 'Id' (Orm.ImplicitPkName) + ImplicitIndex = 2, // create an index for fields ending in 'Id' (Orm.ImplicitIndexSuffix) + AllImplicit = 3, // do both above + + AutoIncPK = 4 // force PK field to be auto inc + } + + /// + /// Represents an open connection to a SQLite database. + /// + public partial class SQLiteConnection : IDisposable + { + private bool _open; + private TimeSpan _busyTimeout; + private Dictionary _mappings = null; + private Dictionary _tables = null; + private System.Diagnostics.Stopwatch _sw; + private long _elapsedMilliseconds = 0; + + private int _transactionDepth = 0; + private Random _rand = new Random (); + + public Sqlite3DatabaseHandle Handle { get; private set; } + internal static readonly Sqlite3DatabaseHandle NullHandle = default(Sqlite3DatabaseHandle); + + public string DatabasePath { get; private set; } + + public bool TimeExecution { get; set; } + + public bool Trace { get; set; } + + public bool StoreDateTimeAsTicks { get; private set; } + + /// + /// Constructs a new SQLiteConnection and opens a SQLite database specified by databasePath. + /// + /// + /// Specifies the path to the database file. + /// + /// + /// Specifies whether to store DateTime properties as ticks (true) or strings (false). You + /// absolutely do want to store them as Ticks in all new projects. The default of false is + /// only here for backwards compatibility. There is a *significant* speed advantage, with no + /// down sides, when setting storeDateTimeAsTicks = true. + /// + public SQLiteConnection (string databasePath, bool storeDateTimeAsTicks = false) + : this (databasePath, SQLiteOpenFlags.ReadWrite | SQLiteOpenFlags.Create, storeDateTimeAsTicks) + { + } + + /// + /// Constructs a new SQLiteConnection and opens a SQLite database specified by databasePath. + /// + /// + /// Specifies the path to the database file. + /// + /// + /// Specifies whether to store DateTime properties as ticks (true) or strings (false). You + /// absolutely do want to store them as Ticks in all new projects. The default of false is + /// only here for backwards compatibility. There is a *significant* speed advantage, with no + /// down sides, when setting storeDateTimeAsTicks = true. + /// + public SQLiteConnection (string databasePath, SQLiteOpenFlags openFlags, bool storeDateTimeAsTicks = false) + { + if (string.IsNullOrEmpty (databasePath)) + throw new ArgumentException ("Must be specified", "databasePath"); + + DatabasePath = databasePath; + +#if NETFX_CORE + SQLite3.SetDirectory(/*temp directory type*/2, Windows.Storage.ApplicationData.Current.TemporaryFolder.Path); +#endif + + Sqlite3DatabaseHandle handle; + +#if SILVERLIGHT || USE_CSHARP_SQLITE + var r = SQLite3.Open (databasePath, out handle, (int)openFlags, IntPtr.Zero); +#else + // open using the byte[] + // in the case where the path may include Unicode + // force open to using UTF-8 using sqlite3_open_v2 + var databasePathAsBytes = GetNullTerminatedUtf8 (DatabasePath); + var r = SQLite3.Open (databasePathAsBytes, out handle, (int) openFlags, IntPtr.Zero); +#endif + + Handle = handle; + if (r != SQLite3.Result.OK) { + throw SQLiteException.New (r, String.Format ("Could not open database file: {0} ({1})", DatabasePath, r)); + } + _open = true; + + StoreDateTimeAsTicks = storeDateTimeAsTicks; + + BusyTimeout = TimeSpan.FromSeconds (0.1); + } + + static SQLiteConnection () + { + if (_preserveDuringLinkMagic) { + var ti = new ColumnInfo (); + ti.Name = "magic"; + } + } + + public void EnableLoadExtension(int onoff) + { + SQLite3.Result r = SQLite3.EnableLoadExtension(Handle, onoff); + if (r != SQLite3.Result.OK) { + string msg = SQLite3.GetErrmsg (Handle); + throw SQLiteException.New (r, msg); + } + } + + static byte[] GetNullTerminatedUtf8 (string s) + { + var utf8Length = System.Text.Encoding.UTF8.GetByteCount (s); + var bytes = new byte [utf8Length + 1]; + utf8Length = System.Text.Encoding.UTF8.GetBytes(s, 0, s.Length, bytes, 0); + return bytes; + } + + /// + /// Used to list some code that we want the MonoTouch linker + /// to see, but that we never want to actually execute. + /// + static bool _preserveDuringLinkMagic; + + /// + /// Sets a busy handler to sleep the specified amount of time when a table is locked. + /// The handler will sleep multiple times until a total time of has accumulated. + /// + public TimeSpan BusyTimeout { + get { return _busyTimeout; } + set { + _busyTimeout = value; + if (Handle != NullHandle) { + SQLite3.BusyTimeout (Handle, (int)_busyTimeout.TotalMilliseconds); + } + } + } + + /// + /// Returns the mappings from types to tables that the connection + /// currently understands. + /// + public IEnumerable TableMappings { + get { + return _tables != null ? _tables.Values : Enumerable.Empty (); + } + } + + /// + /// Retrieves the mapping that is automatically generated for the given type. + /// + /// + /// The type whose mapping to the database is returned. + /// + /// + /// Optional flags allowing implicit PK and indexes based on naming conventions + /// + /// + /// The mapping represents the schema of the columns of the database and contains + /// methods to set and get properties of objects. + /// + public TableMapping GetMapping(Type type, CreateFlags createFlags = CreateFlags.None) + { + if (_mappings == null) { + _mappings = new Dictionary (); + } + TableMapping map; + if (!_mappings.TryGetValue (type.FullName, out map)) { + map = new TableMapping (type, createFlags); + _mappings [type.FullName] = map; + } + return map; + } + + /// + /// Retrieves the mapping that is automatically generated for the given type. + /// + /// + /// The mapping represents the schema of the columns of the database and contains + /// methods to set and get properties of objects. + /// + public TableMapping GetMapping () + { + return GetMapping (typeof (T)); + } + + private struct IndexedColumn + { + public int Order; + public string ColumnName; + } + + private struct IndexInfo + { + public string IndexName; + public string TableName; + public bool Unique; + public List Columns; + } + + /// + /// Executes a "drop table" on the database. This is non-recoverable. + /// + public int DropTable() + { + var map = GetMapping (typeof (T)); + + var query = string.Format("drop table if exists \"{0}\"", map.TableName); + + return Execute (query); + } + + /// + /// Executes a "create table if not exists" on the database. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// + /// The number of entries added to the database schema. + /// + public int CreateTable(CreateFlags createFlags = CreateFlags.None) + { + return CreateTable(typeof (T), createFlags); + } + + /// + /// Executes a "create table if not exists" on the database. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// Type to reflect to a database table. + /// Optional flags allowing implicit PK and indexes based on naming conventions. + /// + /// The number of entries added to the database schema. + /// + public int CreateTable(Type ty, CreateFlags createFlags = CreateFlags.None) + { + if (_tables == null) { + _tables = new Dictionary (); + } + TableMapping map; + if (!_tables.TryGetValue (ty.FullName, out map)) { + map = GetMapping (ty, createFlags); + _tables.Add (ty.FullName, map); + } + var query = "create table if not exists \"" + map.TableName + "\"(\n"; + + var decls = map.Columns.Select (p => Orm.SqlDecl (p, StoreDateTimeAsTicks)); + var decl = string.Join (",\n", decls.ToArray ()); + query += decl; + query += ")"; + + var count = Execute (query); + + if (count == 0) { //Possible bug: This always seems to return 0? + // Table already exists, migrate it + MigrateTable (map); + } + + var indexes = new Dictionary (); + foreach (var c in map.Columns) { + foreach (var i in c.Indices) { + var iname = i.Name ?? map.TableName + "_" + c.Name; + IndexInfo iinfo; + if (!indexes.TryGetValue (iname, out iinfo)) { + iinfo = new IndexInfo { + IndexName = iname, + TableName = map.TableName, + Unique = i.Unique, + Columns = new List () + }; + indexes.Add (iname, iinfo); + } + + if (i.Unique != iinfo.Unique) + throw new Exception ("All the columns in an index must have the same value for their Unique property"); + + iinfo.Columns.Add (new IndexedColumn { + Order = i.Order, + ColumnName = c.Name + }); + } + } + + foreach (var indexName in indexes.Keys) { + var index = indexes[indexName]; + var columns = index.Columns.OrderBy(i => i.Order).Select(i => i.ColumnName).ToArray(); + count += CreateIndex(indexName, index.TableName, columns, index.Unique); + } + + return count; + } + + /// + /// Creates an index for the specified table and columns. + /// + /// Name of the index to create + /// Name of the database table + /// An array of column names to index + /// Whether the index should be unique + public int CreateIndex(string indexName, string tableName, string[] columnNames, bool unique = false) + { + const string sqlFormat = "create {2} index if not exists \"{3}\" on \"{0}\"(\"{1}\")"; + var sql = String.Format(sqlFormat, tableName, string.Join ("\", \"", columnNames), unique ? "unique" : "", indexName); + return Execute(sql); + } + + /// + /// Creates an index for the specified table and column. + /// + /// Name of the index to create + /// Name of the database table + /// Name of the column to index + /// Whether the index should be unique + public int CreateIndex(string indexName, string tableName, string columnName, bool unique = false) + { + return CreateIndex(indexName, tableName, new string[] { columnName }, unique); + } + + /// + /// Creates an index for the specified table and column. + /// + /// Name of the database table + /// Name of the column to index + /// Whether the index should be unique + public int CreateIndex(string tableName, string columnName, bool unique = false) + { + return CreateIndex(tableName + "_" + columnName, tableName, columnName, unique); + } + + /// + /// Creates an index for the specified table and columns. + /// + /// Name of the database table + /// An array of column names to index + /// Whether the index should be unique + public int CreateIndex(string tableName, string[] columnNames, bool unique = false) + { + return CreateIndex(tableName + "_" + string.Join ("_", columnNames), tableName, columnNames, unique); + } + + /// + /// Creates an index for the specified object property. + /// e.g. CreateIndex(c => c.Name); + /// + /// Type to reflect to a database table. + /// Property to index + /// Whether the index should be unique + public void CreateIndex(Expression> property, bool unique = false) + { + MemberExpression mx; + if (property.Body.NodeType == ExpressionType.Convert) + { + mx = ((UnaryExpression)property.Body).Operand as MemberExpression; + } + else + { + mx= (property.Body as MemberExpression); + } + var propertyInfo = mx.Member as PropertyInfo; + if (propertyInfo == null) + { + throw new ArgumentException("The lambda expression 'property' should point to a valid Property"); + } + + var propName = propertyInfo.Name; + + var map = GetMapping(); + var colName = map.FindColumnWithPropertyName(propName).Name; + + CreateIndex(map.TableName, colName, unique); + } + + public class ColumnInfo + { +// public int cid { get; set; } + + [Column ("name")] + public string Name { get; set; } + +// [Column ("type")] +// public string ColumnType { get; set; } + + public int notnull { get; set; } + +// public string dflt_value { get; set; } + +// public int pk { get; set; } + + public override string ToString () + { + return Name; + } + } + + public List GetTableInfo (string tableName) + { + var query = "pragma table_info(\"" + tableName + "\")"; + return Query (query); + } + + void MigrateTable (TableMapping map) + { + var existingCols = GetTableInfo (map.TableName); + + var toBeAdded = new List (); + + foreach (var p in map.Columns) { + var found = false; + foreach (var c in existingCols) { + found = (string.Compare (p.Name, c.Name, StringComparison.OrdinalIgnoreCase) == 0); + if (found) + break; + } + if (!found) { + toBeAdded.Add (p); + } + } + + foreach (var p in toBeAdded) { + var addCol = "alter table \"" + map.TableName + "\" add column " + Orm.SqlDecl (p, StoreDateTimeAsTicks); + Execute (addCol); + } + } + + /// + /// Creates a new SQLiteCommand. Can be overridden to provide a sub-class. + /// + /// + protected virtual SQLiteCommand NewCommand () + { + return new SQLiteCommand (this); + } + + /// + /// Creates a new SQLiteCommand given the command text with arguments. Place a '?' + /// in the command text for each of the arguments. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the command text. + /// + /// + /// A + /// + public SQLiteCommand CreateCommand (string cmdText, params object[] ps) + { + if (!_open) + throw SQLiteException.New (SQLite3.Result.Error, "Cannot create commands from unopened database"); + + var cmd = NewCommand (); + cmd.CommandText = cmdText; + foreach (var o in ps) { + cmd.Bind (o); + } + return cmd; + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// Use this method instead of Query when you don't expect rows back. Such cases include + /// INSERTs, UPDATEs, and DELETEs. + /// You can set the Trace or TimeExecution properties of the connection + /// to profile execution. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// The number of rows modified in the database as a result of this execution. + /// + public int Execute (string query, params object[] args) + { + var cmd = CreateCommand (query, args); + + if (TimeExecution) { + if (_sw == null) { + _sw = new Stopwatch (); + } + _sw.Reset (); + _sw.Start (); + } + + var r = cmd.ExecuteNonQuery (); + + if (TimeExecution) { + _sw.Stop (); + _elapsedMilliseconds += _sw.ElapsedMilliseconds; + Debug.WriteLine (string.Format ("Finished in {0} ms ({1:0.0} s total)", _sw.ElapsedMilliseconds, _elapsedMilliseconds / 1000.0)); + } + + return r; + } + + public T ExecuteScalar (string query, params object[] args) + { + var cmd = CreateCommand (query, args); + + if (TimeExecution) { + if (_sw == null) { + _sw = new Stopwatch (); + } + _sw.Reset (); + _sw.Start (); + } + + var r = cmd.ExecuteScalar (); + + if (TimeExecution) { + _sw.Stop (); + _elapsedMilliseconds += _sw.ElapsedMilliseconds; + Debug.WriteLine (string.Format ("Finished in {0} ms ({1:0.0} s total)", _sw.ElapsedMilliseconds, _elapsedMilliseconds / 1000.0)); + } + + return r; + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// It returns each row of the result using the mapping automatically generated for + /// the given type. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// An enumerable with one result for each row returned by the query. + /// + public List Query (string query, params object[] args) where T : new() + { + var cmd = CreateCommand (query, args); + return cmd.ExecuteQuery (); + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// It returns each row of the result using the mapping automatically generated for + /// the given type. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// An enumerable with one result for each row returned by the query. + /// The enumerator will call sqlite3_step on each call to MoveNext, so the database + /// connection must remain open for the lifetime of the enumerator. + /// + public IEnumerable DeferredQuery(string query, params object[] args) where T : new() + { + var cmd = CreateCommand(query, args); + return cmd.ExecuteDeferredQuery(); + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// It returns each row of the result using the specified mapping. This function is + /// only used by libraries in order to query the database via introspection. It is + /// normally not used. + /// + /// + /// A to use to convert the resulting rows + /// into objects. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// An enumerable with one result for each row returned by the query. + /// + public List Query (TableMapping map, string query, params object[] args) + { + var cmd = CreateCommand (query, args); + return cmd.ExecuteQuery (map); + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// It returns each row of the result using the specified mapping. This function is + /// only used by libraries in order to query the database via introspection. It is + /// normally not used. + /// + /// + /// A to use to convert the resulting rows + /// into objects. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// An enumerable with one result for each row returned by the query. + /// The enumerator will call sqlite3_step on each call to MoveNext, so the database + /// connection must remain open for the lifetime of the enumerator. + /// + public IEnumerable DeferredQuery(TableMapping map, string query, params object[] args) + { + var cmd = CreateCommand(query, args); + return cmd.ExecuteDeferredQuery(map); + } + + /// + /// Returns a queryable interface to the table represented by the given type. + /// + /// + /// A queryable object that is able to translate Where, OrderBy, and Take + /// queries into native SQL. + /// + public TableQuery Table () where T : new() + { + return new TableQuery (this); + } + + /// + /// Attempts to retrieve an object with the given primary key from the table + /// associated with the specified type. Use of this method requires that + /// the given type have a designated PrimaryKey (using the PrimaryKeyAttribute). + /// + /// + /// The primary key. + /// + /// + /// The object with the given primary key. Throws a not found exception + /// if the object is not found. + /// + public T Get (object pk) where T : new() + { + var map = GetMapping (typeof(T)); + return Query (map.GetByPrimaryKeySql, pk).First (); + } + + /// + /// Attempts to retrieve the first object that matches the predicate from the table + /// associated with the specified type. + /// + /// + /// A predicate for which object to find. + /// + /// + /// The object that matches the given predicate. Throws a not found exception + /// if the object is not found. + /// + public T Get (Expression> predicate) where T : new() + { + return Table ().Where (predicate).First (); + } + + /// + /// Attempts to retrieve an object with the given primary key from the table + /// associated with the specified type. Use of this method requires that + /// the given type have a designated PrimaryKey (using the PrimaryKeyAttribute). + /// + /// + /// The primary key. + /// + /// + /// The object with the given primary key or null + /// if the object is not found. + /// + public T Find (object pk) where T : new () + { + var map = GetMapping (typeof (T)); + return Query (map.GetByPrimaryKeySql, pk).FirstOrDefault (); + } + + /// + /// Attempts to retrieve an object with the given primary key from the table + /// associated with the specified type. Use of this method requires that + /// the given type have a designated PrimaryKey (using the PrimaryKeyAttribute). + /// + /// + /// The primary key. + /// + /// + /// The TableMapping used to identify the object type. + /// + /// + /// The object with the given primary key or null + /// if the object is not found. + /// + public object Find (object pk, TableMapping map) + { + return Query (map, map.GetByPrimaryKeySql, pk).FirstOrDefault (); + } + + /// + /// Attempts to retrieve the first object that matches the predicate from the table + /// associated with the specified type. + /// + /// + /// A predicate for which object to find. + /// + /// + /// The object that matches the given predicate or null + /// if the object is not found. + /// + public T Find (Expression> predicate) where T : new() + { + return Table ().Where (predicate).FirstOrDefault (); + } + + /// + /// Whether has been called and the database is waiting for a . + /// + public bool IsInTransaction { + get { return _transactionDepth > 0; } + } + + /// + /// Begins a new transaction. Call to end the transaction. + /// + /// Throws if a transaction has already begun. + public void BeginTransaction () + { + // The BEGIN command only works if the transaction stack is empty, + // or in other words if there are no pending transactions. + // If the transaction stack is not empty when the BEGIN command is invoked, + // then the command fails with an error. + // Rather than crash with an error, we will just ignore calls to BeginTransaction + // that would result in an error. + if (Interlocked.CompareExchange (ref _transactionDepth, 1, 0) == 0) { + try { + Execute ("begin transaction"); + } catch (Exception ex) { + var sqlExp = ex as SQLiteException; + if (sqlExp != null) { + // It is recommended that applications respond to the errors listed below + // by explicitly issuing a ROLLBACK command. + // TODO: This rollback failsafe should be localized to all throw sites. + switch (sqlExp.Result) { + case SQLite3.Result.IOError: + case SQLite3.Result.Full: + case SQLite3.Result.Busy: + case SQLite3.Result.NoMem: + case SQLite3.Result.Interrupt: + RollbackTo (null, true); + break; + } + } else { + // Call decrement and not VolatileWrite in case we've already + // created a transaction point in SaveTransactionPoint since the catch. + Interlocked.Decrement (ref _transactionDepth); + } + + throw; + } + } else { + // Calling BeginTransaction on an already open transaction is invalid + throw new InvalidOperationException ("Cannot begin a transaction while already in a transaction."); + } + } + + /// + /// Creates a savepoint in the database at the current point in the transaction timeline. + /// Begins a new transaction if one is not in progress. + /// + /// Call to undo transactions since the returned savepoint. + /// Call to commit transactions after the savepoint returned here. + /// Call to end the transaction, committing all changes. + /// + /// A string naming the savepoint. + public string SaveTransactionPoint () + { + int depth = Interlocked.Increment (ref _transactionDepth) - 1; + string retVal = "S" + _rand.Next (short.MaxValue) + "D" + depth; + + try { + Execute ("savepoint " + retVal); + } catch (Exception ex) { + var sqlExp = ex as SQLiteException; + if (sqlExp != null) { + // It is recommended that applications respond to the errors listed below + // by explicitly issuing a ROLLBACK command. + // TODO: This rollback failsafe should be localized to all throw sites. + switch (sqlExp.Result) { + case SQLite3.Result.IOError: + case SQLite3.Result.Full: + case SQLite3.Result.Busy: + case SQLite3.Result.NoMem: + case SQLite3.Result.Interrupt: + RollbackTo (null, true); + break; + } + } else { + Interlocked.Decrement (ref _transactionDepth); + } + + throw; + } + + return retVal; + } + + /// + /// Rolls back the transaction that was begun by or . + /// + public void Rollback () + { + RollbackTo (null, false); + } + + /// + /// Rolls back the savepoint created by or SaveTransactionPoint. + /// + /// The name of the savepoint to roll back to, as returned by . If savepoint is null or empty, this method is equivalent to a call to + public void RollbackTo (string savepoint) + { + RollbackTo (savepoint, false); + } + + /// + /// Rolls back the transaction that was begun by . + /// + /// true to avoid throwing exceptions, false otherwise + void RollbackTo (string savepoint, bool noThrow) + { + // Rolling back without a TO clause rolls backs all transactions + // and leaves the transaction stack empty. + try { + if (String.IsNullOrEmpty (savepoint)) { + if (Interlocked.Exchange (ref _transactionDepth, 0) > 0) { + Execute ("rollback"); + } + } else { + DoSavePointExecute (savepoint, "rollback to "); + } + } catch (SQLiteException) { + if (!noThrow) + throw; + + } + // No need to rollback if there are no transactions open. + } + + /// + /// Releases a savepoint returned from . Releasing a savepoint + /// makes changes since that savepoint permanent if the savepoint began the transaction, + /// or otherwise the changes are permanent pending a call to . + /// + /// The RELEASE command is like a COMMIT for a SAVEPOINT. + /// + /// The name of the savepoint to release. The string should be the result of a call to + public void Release (string savepoint) + { + DoSavePointExecute (savepoint, "release "); + } + + void DoSavePointExecute (string savepoint, string cmd) + { + // Validate the savepoint + int firstLen = savepoint.IndexOf ('D'); + if (firstLen >= 2 && savepoint.Length > firstLen + 1) { + int depth; + if (Int32.TryParse (savepoint.Substring (firstLen + 1), out depth)) { + // TODO: Mild race here, but inescapable without locking almost everywhere. + if (0 <= depth && depth < _transactionDepth) { +#if NETFX_CORE + Volatile.Write (ref _transactionDepth, depth); +#elif SILVERLIGHT + _transactionDepth = depth; +#else + Thread.VolatileWrite (ref _transactionDepth, depth); +#endif + Execute (cmd + savepoint); + return; + } + } + } + + throw new ArgumentException ("savePoint is not valid, and should be the result of a call to SaveTransactionPoint.", "savePoint"); + } + + /// + /// Commits the transaction that was begun by . + /// + public void Commit () + { + if (Interlocked.Exchange (ref _transactionDepth, 0) != 0) { + Execute ("commit"); + } + // Do nothing on a commit with no open transaction + } + + /// + /// Executes within a (possibly nested) transaction by wrapping it in a SAVEPOINT. If an + /// exception occurs the whole transaction is rolled back, not just the current savepoint. The exception + /// is rethrown. + /// + /// + /// The to perform within a transaction. can contain any number + /// of operations on the connection but should never call or + /// . + /// + public void RunInTransaction (Action action) + { + try { + var savePoint = SaveTransactionPoint (); + action (); + Release (savePoint); + } catch (Exception) { + Rollback (); + throw; + } + } + + /// + /// Inserts all specified objects. + /// + /// + /// An of the objects to insert. + /// + /// + /// The number of rows added to the table. + /// + public int InsertAll (System.Collections.IEnumerable objects) + { + var c = 0; + RunInTransaction(() => { + foreach (var r in objects) { + c += Insert (r); + } + }); + return c; + } + + /// + /// Inserts all specified objects. + /// + /// + /// An of the objects to insert. + /// + /// + /// Literal SQL code that gets placed into the command. INSERT {extra} INTO ... + /// + /// + /// The number of rows added to the table. + /// + public int InsertAll (System.Collections.IEnumerable objects, string extra) + { + var c = 0; + RunInTransaction (() => { + foreach (var r in objects) { + c += Insert (r, extra); + } + }); + return c; + } + + /// + /// Inserts all specified objects. + /// + /// + /// An of the objects to insert. + /// + /// + /// The type of object to insert. + /// + /// + /// The number of rows added to the table. + /// + public int InsertAll (System.Collections.IEnumerable objects, Type objType) + { + var c = 0; + RunInTransaction (() => { + foreach (var r in objects) { + c += Insert (r, objType); + } + }); + return c; + } + + /// + /// Inserts the given object and retrieves its + /// auto incremented primary key if it has one. + /// + /// + /// The object to insert. + /// + /// + /// The number of rows added to the table. + /// + public int Insert (object obj) + { + if (obj == null) { + return 0; + } + return Insert (obj, "", obj.GetType ()); + } + + /// + /// Inserts the given object and retrieves its + /// auto incremented primary key if it has one. + /// If a UNIQUE constraint violation occurs with + /// some pre-existing object, this function deletes + /// the old object. + /// + /// + /// The object to insert. + /// + /// + /// The number of rows modified. + /// + public int InsertOrReplace (object obj) + { + if (obj == null) { + return 0; + } + return Insert (obj, "OR REPLACE", obj.GetType ()); + } + + /// + /// Inserts the given object and retrieves its + /// auto incremented primary key if it has one. + /// + /// + /// The object to insert. + /// + /// + /// The type of object to insert. + /// + /// + /// The number of rows added to the table. + /// + public int Insert (object obj, Type objType) + { + return Insert (obj, "", objType); + } + + /// + /// Inserts the given object and retrieves its + /// auto incremented primary key if it has one. + /// If a UNIQUE constraint violation occurs with + /// some pre-existing object, this function deletes + /// the old object. + /// + /// + /// The object to insert. + /// + /// + /// The type of object to insert. + /// + /// + /// The number of rows modified. + /// + public int InsertOrReplace (object obj, Type objType) + { + return Insert (obj, "OR REPLACE", objType); + } + + /// + /// Inserts the given object and retrieves its + /// auto incremented primary key if it has one. + /// + /// + /// The object to insert. + /// + /// + /// Literal SQL code that gets placed into the command. INSERT {extra} INTO ... + /// + /// + /// The number of rows added to the table. + /// + public int Insert (object obj, string extra) + { + if (obj == null) { + return 0; + } + return Insert (obj, extra, obj.GetType ()); + } + + /// + /// Inserts the given object and retrieves its + /// auto incremented primary key if it has one. + /// + /// + /// The object to insert. + /// + /// + /// Literal SQL code that gets placed into the command. INSERT {extra} INTO ... + /// + /// + /// The type of object to insert. + /// + /// + /// The number of rows added to the table. + /// + public int Insert (object obj, string extra, Type objType) + { + if (obj == null || objType == null) { + return 0; + } + + + var map = GetMapping (objType); + +#if NETFX_CORE + if (map.PK != null && map.PK.IsAutoGuid) + { + // no GetProperty so search our way up the inheritance chain till we find it + PropertyInfo prop; + while (objType != null) + { + var info = objType.GetTypeInfo(); + prop = info.GetDeclaredProperty(map.PK.PropertyName); + if (prop != null) + { + if (prop.GetValue(obj, null).Equals(Guid.Empty)) + { + prop.SetValue(obj, Guid.NewGuid(), null); + } + break; + } + + objType = info.BaseType; + } + } +#else + if (map.PK != null && map.PK.IsAutoGuid) { + var prop = objType.GetProperty(map.PK.PropertyName); + if (prop != null) { + if (prop.GetValue(obj, null).Equals(Guid.Empty)) { + prop.SetValue(obj, Guid.NewGuid(), null); + } + } + } +#endif + + + var replacing = string.Compare (extra, "OR REPLACE", StringComparison.OrdinalIgnoreCase) == 0; + + var cols = replacing ? map.InsertOrReplaceColumns : map.InsertColumns; + var vals = new object[cols.Length]; + for (var i = 0; i < vals.Length; i++) { + vals [i] = cols [i].GetValue (obj); + } + + var insertCmd = map.GetInsertCommand (this, extra); + int count; + + try { + count = insertCmd.ExecuteNonQuery (vals); + } + catch (SQLiteException ex) { + + if (SQLite3.ExtendedErrCode (this.Handle) == SQLite3.ExtendedResult.ConstraintNotNull) { + throw NotNullConstraintViolationException.New (ex.Result, ex.Message, map, obj); + } + throw; + } + + if (map.HasAutoIncPK) + { + var id = SQLite3.LastInsertRowid (Handle); + map.SetAutoIncPK (obj, id); + } + + return count; + } + + /// + /// Updates all of the columns of a table using the specified object + /// except for its primary key. + /// The object is required to have a primary key. + /// + /// + /// The object to update. It must have a primary key designated using the PrimaryKeyAttribute. + /// + /// + /// The number of rows updated. + /// + public int Update (object obj) + { + if (obj == null) { + return 0; + } + return Update (obj, obj.GetType ()); + } + + /// + /// Updates all of the columns of a table using the specified object + /// except for its primary key. + /// The object is required to have a primary key. + /// + /// + /// The object to update. It must have a primary key designated using the PrimaryKeyAttribute. + /// + /// + /// The type of object to insert. + /// + /// + /// The number of rows updated. + /// + public int Update (object obj, Type objType) + { + int rowsAffected = 0; + if (obj == null || objType == null) { + return 0; + } + + var map = GetMapping (objType); + + var pk = map.PK; + + if (pk == null) { + throw new NotSupportedException ("Cannot update " + map.TableName + ": it has no PK"); + } + + var cols = from p in map.Columns + where p != pk + select p; + var vals = from c in cols + select c.GetValue (obj); + var ps = new List (vals); + ps.Add (pk.GetValue (obj)); + var q = string.Format ("update \"{0}\" set {1} where {2} = ? ", map.TableName, string.Join (",", (from c in cols + select "\"" + c.Name + "\" = ? ").ToArray ()), pk.Name); + + try { + rowsAffected = Execute (q, ps.ToArray ()); + } + catch (SQLiteException ex) { + + if (ex.Result == SQLite3.Result.Constraint && SQLite3.ExtendedErrCode (this.Handle) == SQLite3.ExtendedResult.ConstraintNotNull) { + throw NotNullConstraintViolationException.New (ex, map, obj); + } + + throw ex; + } + + return rowsAffected; + } + + /// + /// Updates all specified objects. + /// + /// + /// An of the objects to insert. + /// + /// + /// The number of rows modified. + /// + public int UpdateAll (System.Collections.IEnumerable objects) + { + var c = 0; + RunInTransaction (() => { + foreach (var r in objects) { + c += Update (r); + } + }); + return c; + } + + /// + /// Deletes the given object from the database using its primary key. + /// + /// + /// The object to delete. It must have a primary key designated using the PrimaryKeyAttribute. + /// + /// + /// The number of rows deleted. + /// + public int Delete (object objectToDelete) + { + var map = GetMapping (objectToDelete.GetType ()); + var pk = map.PK; + if (pk == null) { + throw new NotSupportedException ("Cannot delete " + map.TableName + ": it has no PK"); + } + var q = string.Format ("delete from \"{0}\" where \"{1}\" = ?", map.TableName, pk.Name); + return Execute (q, pk.GetValue (objectToDelete)); + } + + /// + /// Deletes the object with the specified primary key. + /// + /// + /// The primary key of the object to delete. + /// + /// + /// The number of objects deleted. + /// + /// + /// The type of object. + /// + public int Delete (object primaryKey) + { + var map = GetMapping (typeof (T)); + var pk = map.PK; + if (pk == null) { + throw new NotSupportedException ("Cannot delete " + map.TableName + ": it has no PK"); + } + var q = string.Format ("delete from \"{0}\" where \"{1}\" = ?", map.TableName, pk.Name); + return Execute (q, primaryKey); + } + + /// + /// Deletes all the objects from the specified table. + /// WARNING WARNING: Let me repeat. It deletes ALL the objects from the + /// specified table. Do you really want to do that? + /// + /// + /// The number of objects deleted. + /// + /// + /// The type of objects to delete. + /// + public int DeleteAll () + { + var map = GetMapping (typeof (T)); + var query = string.Format("delete from \"{0}\"", map.TableName); + return Execute (query); + } + + ~SQLiteConnection () + { + Dispose (false); + } + + public void Dispose () + { + Dispose (true); + GC.SuppressFinalize (this); + } + + protected virtual void Dispose (bool disposing) + { + Close (); + } + + public void Close () + { + if (_open && Handle != NullHandle) { + try { + if (_mappings != null) { + foreach (var sqlInsertCommand in _mappings.Values) { + sqlInsertCommand.Dispose(); + } + } + var r = SQLite3.Close (Handle); + if (r != SQLite3.Result.OK) { + string msg = SQLite3.GetErrmsg (Handle); + throw SQLiteException.New (r, msg); + } + } + finally { + Handle = NullHandle; + _open = false; + } + } + } + } + + /// + /// Represents a parsed connection string. + /// + class SQLiteConnectionString + { + public string ConnectionString { get; private set; } + public string DatabasePath { get; private set; } + public bool StoreDateTimeAsTicks { get; private set; } + +#if NETFX_CORE + static readonly string MetroStyleDataPath = Windows.Storage.ApplicationData.Current.LocalFolder.Path; +#endif + + public SQLiteConnectionString (string databasePath, bool storeDateTimeAsTicks) + { + ConnectionString = databasePath; + StoreDateTimeAsTicks = storeDateTimeAsTicks; + +#if NETFX_CORE + DatabasePath = System.IO.Path.Combine (MetroStyleDataPath, databasePath); +#else + DatabasePath = databasePath; +#endif + } + } + + [AttributeUsage (AttributeTargets.Class)] + public class TableAttribute : Attribute + { + public string Name { get; set; } + + public TableAttribute (string name) + { + Name = name; + } + } + + [AttributeUsage (AttributeTargets.Property)] + public class ColumnAttribute : Attribute + { + public string Name { get; set; } + + public ColumnAttribute (string name) + { + Name = name; + } + } + + [AttributeUsage (AttributeTargets.Property)] + public class PrimaryKeyAttribute : Attribute + { + } + + [AttributeUsage (AttributeTargets.Property)] + public class AutoIncrementAttribute : Attribute + { + } + + [AttributeUsage (AttributeTargets.Property)] + public class IndexedAttribute : Attribute + { + public string Name { get; set; } + public int Order { get; set; } + public virtual bool Unique { get; set; } + + public IndexedAttribute() + { + } + + public IndexedAttribute(string name, int order) + { + Name = name; + Order = order; + } + } + + [AttributeUsage (AttributeTargets.Property)] + public class IgnoreAttribute : Attribute + { + } + + [AttributeUsage (AttributeTargets.Property)] + public class UniqueAttribute : IndexedAttribute + { + public override bool Unique { + get { return true; } + set { /* throw? */ } + } + } + + [AttributeUsage (AttributeTargets.Property)] + public class MaxLengthAttribute : Attribute + { + public int Value { get; private set; } + + public MaxLengthAttribute (int length) + { + Value = length; + } + } + + [AttributeUsage (AttributeTargets.Property)] + public class CollationAttribute: Attribute + { + public string Value { get; private set; } + + public CollationAttribute (string collation) + { + Value = collation; + } + } + + [AttributeUsage (AttributeTargets.Property)] + public class NotNullAttribute : Attribute + { + } + + public class TableMapping + { + public Type MappedType { get; private set; } + + public string TableName { get; private set; } + + public Column[] Columns { get; private set; } + + public Column PK { get; private set; } + + public string GetByPrimaryKeySql { get; private set; } + + Column _autoPk; + Column[] _insertColumns; + Column[] _insertOrReplaceColumns; + + public TableMapping(Type type, CreateFlags createFlags = CreateFlags.None) + { + MappedType = type; + +#if NETFX_CORE + var tableAttr = (TableAttribute)System.Reflection.CustomAttributeExtensions + .GetCustomAttribute(type.GetTypeInfo(), typeof(TableAttribute), true); +#else + var tableAttr = (TableAttribute)type.GetCustomAttributes (typeof (TableAttribute), true).FirstOrDefault (); +#endif + + TableName = tableAttr != null ? tableAttr.Name : MappedType.Name; + +#if !NETFX_CORE + var props = MappedType.GetProperties (BindingFlags.Public | BindingFlags.Instance | BindingFlags.SetProperty); +#else + var props = from p in MappedType.GetRuntimeProperties() + where ((p.GetMethod != null && p.GetMethod.IsPublic) || (p.SetMethod != null && p.SetMethod.IsPublic) || (p.GetMethod != null && p.GetMethod.IsStatic) || (p.SetMethod != null && p.SetMethod.IsStatic)) + select p; +#endif + var cols = new List (); + foreach (var p in props) { +#if !NETFX_CORE + var ignore = p.GetCustomAttributes (typeof(IgnoreAttribute), true).Length > 0; +#else + var ignore = p.GetCustomAttributes (typeof(IgnoreAttribute), true).Count() > 0; +#endif + if (p.CanWrite && !ignore) { + cols.Add (new Column (p, createFlags)); + } + } + Columns = cols.ToArray (); + foreach (var c in Columns) { + if (c.IsAutoInc && c.IsPK) { + _autoPk = c; + } + if (c.IsPK) { + PK = c; + } + } + + HasAutoIncPK = _autoPk != null; + + if (PK != null) { + GetByPrimaryKeySql = string.Format ("select * from \"{0}\" where \"{1}\" = ?", TableName, PK.Name); + } + else { + // People should not be calling Get/Find without a PK + GetByPrimaryKeySql = string.Format ("select * from \"{0}\" limit 1", TableName); + } + } + + public bool HasAutoIncPK { get; private set; } + + public void SetAutoIncPK (object obj, long id) + { + if (_autoPk != null) { + _autoPk.SetValue (obj, Convert.ChangeType (id, _autoPk.ColumnType, null)); + } + } + + public Column[] InsertColumns { + get { + if (_insertColumns == null) { + _insertColumns = Columns.Where (c => !c.IsAutoInc).ToArray (); + } + return _insertColumns; + } + } + + public Column[] InsertOrReplaceColumns { + get { + if (_insertOrReplaceColumns == null) { + _insertOrReplaceColumns = Columns.ToArray (); + } + return _insertOrReplaceColumns; + } + } + + public Column FindColumnWithPropertyName (string propertyName) + { + var exact = Columns.FirstOrDefault (c => c.PropertyName == propertyName); + return exact; + } + + public Column FindColumn (string columnName) + { + var exact = Columns.FirstOrDefault (c => c.Name == columnName); + return exact; + } + + PreparedSqlLiteInsertCommand _insertCommand; + string _insertCommandExtra; + + public PreparedSqlLiteInsertCommand GetInsertCommand(SQLiteConnection conn, string extra) + { + if (_insertCommand == null) { + _insertCommand = CreateInsertCommand(conn, extra); + _insertCommandExtra = extra; + } + else if (_insertCommandExtra != extra) { + _insertCommand.Dispose(); + _insertCommand = CreateInsertCommand(conn, extra); + _insertCommandExtra = extra; + } + return _insertCommand; + } + + PreparedSqlLiteInsertCommand CreateInsertCommand(SQLiteConnection conn, string extra) + { + var cols = InsertColumns; + string insertSql; + if (!cols.Any() && Columns.Count() == 1 && Columns[0].IsAutoInc) + { + insertSql = string.Format("insert {1} into \"{0}\" default values", TableName, extra); + } + else + { + var replacing = string.Compare (extra, "OR REPLACE", StringComparison.OrdinalIgnoreCase) == 0; + + if (replacing) { + cols = InsertOrReplaceColumns; + } + + insertSql = string.Format("insert {3} into \"{0}\"({1}) values ({2})", TableName, + string.Join(",", (from c in cols + select "\"" + c.Name + "\"").ToArray()), + string.Join(",", (from c in cols + select "?").ToArray()), extra); + + } + + var insertCommand = new PreparedSqlLiteInsertCommand(conn); + insertCommand.CommandText = insertSql; + return insertCommand; + } + + protected internal void Dispose() + { + if (_insertCommand != null) { + _insertCommand.Dispose(); + _insertCommand = null; + } + } + + public class Column + { + PropertyInfo _prop; + + public string Name { get; private set; } + + public string PropertyName { get { return _prop.Name; } } + + public Type ColumnType { get; private set; } + + public string Collation { get; private set; } + + public bool IsAutoInc { get; private set; } + public bool IsAutoGuid { get; private set; } + + public bool IsPK { get; private set; } + + public IEnumerable Indices { get; set; } + + public bool IsNullable { get; private set; } + + public int? MaxStringLength { get; private set; } + + public Column(PropertyInfo prop, CreateFlags createFlags = CreateFlags.None) + { + var colAttr = (ColumnAttribute)prop.GetCustomAttributes(typeof(ColumnAttribute), true).FirstOrDefault(); + + _prop = prop; + Name = colAttr == null ? prop.Name : colAttr.Name; + //If this type is Nullable then Nullable.GetUnderlyingType returns the T, otherwise it returns null, so get the actual type instead + ColumnType = Nullable.GetUnderlyingType(prop.PropertyType) ?? prop.PropertyType; + Collation = Orm.Collation(prop); + + IsPK = Orm.IsPK(prop) || + (((createFlags & CreateFlags.ImplicitPK) == CreateFlags.ImplicitPK) && + string.Compare (prop.Name, Orm.ImplicitPkName, StringComparison.OrdinalIgnoreCase) == 0); + + var isAuto = Orm.IsAutoInc(prop) || (IsPK && ((createFlags & CreateFlags.AutoIncPK) == CreateFlags.AutoIncPK)); + IsAutoGuid = isAuto && ColumnType == typeof(Guid); + IsAutoInc = isAuto && !IsAutoGuid; + + Indices = Orm.GetIndices(prop); + if (!Indices.Any() + && !IsPK + && ((createFlags & CreateFlags.ImplicitIndex) == CreateFlags.ImplicitIndex) + && Name.EndsWith (Orm.ImplicitIndexSuffix, StringComparison.OrdinalIgnoreCase) + ) + { + Indices = new IndexedAttribute[] { new IndexedAttribute() }; + } + IsNullable = !(IsPK || Orm.IsMarkedNotNull(prop)); + MaxStringLength = Orm.MaxStringLength(prop); + } + + public void SetValue (object obj, object val) + { + _prop.SetValue (obj, val, null); + } + + public object GetValue (object obj) + { + return _prop.GetValue (obj, null); + } + } + } + + public static class Orm + { + public const int DefaultMaxStringLength = 140; + public const string ImplicitPkName = "Id"; + public const string ImplicitIndexSuffix = "Id"; + + public static string SqlDecl (TableMapping.Column p, bool storeDateTimeAsTicks) + { + string decl = "\"" + p.Name + "\" " + SqlType (p, storeDateTimeAsTicks) + " "; + + if (p.IsPK) { + decl += "primary key "; + } + if (p.IsAutoInc) { + decl += "autoincrement "; + } + if (!p.IsNullable) { + decl += "not null "; + } + if (!string.IsNullOrEmpty (p.Collation)) { + decl += "collate " + p.Collation + " "; + } + + return decl; + } + + public static string SqlType (TableMapping.Column p, bool storeDateTimeAsTicks) + { + var clrType = p.ColumnType; + if (clrType == typeof(Boolean) || clrType == typeof(Byte) || clrType == typeof(UInt16) || clrType == typeof(SByte) || clrType == typeof(Int16) || clrType == typeof(Int32)) { + return "integer"; + } else if (clrType == typeof(UInt32) || clrType == typeof(Int64)) { + return "bigint"; + } else if (clrType == typeof(Single) || clrType == typeof(Double) || clrType == typeof(Decimal)) { + return "float"; + } else if (clrType == typeof(String)) { + int? len = p.MaxStringLength; + + if (len.HasValue) + return "varchar(" + len.Value + ")"; + + return "varchar"; + } else if (clrType == typeof(TimeSpan)) { + return "bigint"; + } else if (clrType == typeof(DateTime)) { + return storeDateTimeAsTicks ? "bigint" : "datetime"; + } else if (clrType == typeof(DateTimeOffset)) { + return "bigint"; +#if !NETFX_CORE + } else if (clrType.IsEnum) { +#else + } else if (clrType.GetTypeInfo().IsEnum) { +#endif + return "integer"; + } else if (clrType == typeof(byte[])) { + return "blob"; + } else if (clrType == typeof(Guid)) { + return "varchar(36)"; + } else { + throw new NotSupportedException ("Don't know about " + clrType); + } + } + + public static bool IsPK (MemberInfo p) + { + var attrs = p.GetCustomAttributes (typeof(PrimaryKeyAttribute), true); +#if !NETFX_CORE + return attrs.Length > 0; +#else + return attrs.Count() > 0; +#endif + } + + public static string Collation (MemberInfo p) + { + var attrs = p.GetCustomAttributes (typeof(CollationAttribute), true); +#if !NETFX_CORE + if (attrs.Length > 0) { + return ((CollationAttribute)attrs [0]).Value; +#else + if (attrs.Count() > 0) { + return ((CollationAttribute)attrs.First()).Value; +#endif + } else { + return string.Empty; + } + } + + public static bool IsAutoInc (MemberInfo p) + { + var attrs = p.GetCustomAttributes (typeof(AutoIncrementAttribute), true); +#if !NETFX_CORE + return attrs.Length > 0; +#else + return attrs.Count() > 0; +#endif + } + + public static IEnumerable GetIndices(MemberInfo p) + { + var attrs = p.GetCustomAttributes(typeof(IndexedAttribute), true); + return attrs.Cast(); + } + + public static int? MaxStringLength(PropertyInfo p) + { + var attrs = p.GetCustomAttributes (typeof(MaxLengthAttribute), true); +#if !NETFX_CORE + if (attrs.Length > 0) + return ((MaxLengthAttribute)attrs [0]).Value; +#else + if (attrs.Count() > 0) + return ((MaxLengthAttribute)attrs.First()).Value; +#endif + + return null; + } + + public static bool IsMarkedNotNull(MemberInfo p) + { + var attrs = p.GetCustomAttributes (typeof (NotNullAttribute), true); +#if !NETFX_CORE + return attrs.Length > 0; +#else + return attrs.Count() > 0; +#endif + } + } + + public partial class SQLiteCommand + { + SQLiteConnection _conn; + private List _bindings; + + public string CommandText { get; set; } + + internal SQLiteCommand (SQLiteConnection conn) + { + _conn = conn; + _bindings = new List (); + CommandText = ""; + } + + public int ExecuteNonQuery () + { + if (_conn.Trace) { + Debug.WriteLine ("Executing: " + this); + } + + var r = SQLite3.Result.OK; + var stmt = Prepare (); + r = SQLite3.Step (stmt); + Finalize (stmt); + if (r == SQLite3.Result.Done) { + int rowsAffected = SQLite3.Changes (_conn.Handle); + return rowsAffected; + } else if (r == SQLite3.Result.Error) { + string msg = SQLite3.GetErrmsg (_conn.Handle); + throw SQLiteException.New (r, msg); + } + else if (r == SQLite3.Result.Constraint) { + if (SQLite3.ExtendedErrCode (_conn.Handle) == SQLite3.ExtendedResult.ConstraintNotNull) { + throw NotNullConstraintViolationException.New (r, SQLite3.GetErrmsg (_conn.Handle)); + } + } + + throw SQLiteException.New(r, r.ToString()); + } + + public IEnumerable ExecuteDeferredQuery () + { + return ExecuteDeferredQuery(_conn.GetMapping(typeof(T))); + } + + public List ExecuteQuery () + { + return ExecuteDeferredQuery(_conn.GetMapping(typeof(T))).ToList(); + } + + public List ExecuteQuery (TableMapping map) + { + return ExecuteDeferredQuery(map).ToList(); + } + + /// + /// Invoked every time an instance is loaded from the database. + /// + /// + /// The newly created object. + /// + /// + /// This can be overridden in combination with the + /// method to hook into the life-cycle of objects. + /// + /// Type safety is not possible because MonoTouch does not support virtual generic methods. + /// + protected virtual void OnInstanceCreated (object obj) + { + // Can be overridden. + } + + public IEnumerable ExecuteDeferredQuery (TableMapping map) + { + if (_conn.Trace) { + Debug.WriteLine ("Executing Query: " + this); + } + + var stmt = Prepare (); + try + { + var cols = new TableMapping.Column[SQLite3.ColumnCount (stmt)]; + + for (int i = 0; i < cols.Length; i++) { + var name = SQLite3.ColumnName16 (stmt, i); + cols [i] = map.FindColumn (name); + } + + while (SQLite3.Step (stmt) == SQLite3.Result.Row) { + var obj = Activator.CreateInstance(map.MappedType); + for (int i = 0; i < cols.Length; i++) { + if (cols [i] == null) + continue; + var colType = SQLite3.ColumnType (stmt, i); + var val = ReadCol (stmt, i, colType, cols [i].ColumnType); + cols [i].SetValue (obj, val); + } + OnInstanceCreated (obj); + yield return (T)obj; + } + } + finally + { + SQLite3.Finalize(stmt); + } + } + + public T ExecuteScalar () + { + if (_conn.Trace) { + Debug.WriteLine ("Executing Query: " + this); + } + + T val = default(T); + + var stmt = Prepare (); + + try + { + var r = SQLite3.Step (stmt); + if (r == SQLite3.Result.Row) { + var colType = SQLite3.ColumnType (stmt, 0); + val = (T)ReadCol (stmt, 0, colType, typeof(T)); + } + else if (r == SQLite3.Result.Done) { + } + else + { + throw SQLiteException.New (r, SQLite3.GetErrmsg (_conn.Handle)); + } + } + finally + { + Finalize (stmt); + } + + return val; + } + + public void Bind (string name, object val) + { + _bindings.Add (new Binding { + Name = name, + Value = val + }); + } + + public void Bind (object val) + { + Bind (null, val); + } + + public override string ToString () + { + var parts = new string[1 + _bindings.Count]; + parts [0] = CommandText; + var i = 1; + foreach (var b in _bindings) { + parts [i] = string.Format (" {0}: {1}", i - 1, b.Value); + i++; + } + return string.Join (Environment.NewLine, parts); + } + + Sqlite3Statement Prepare() + { + var stmt = SQLite3.Prepare2 (_conn.Handle, CommandText); + BindAll (stmt); + return stmt; + } + + void Finalize (Sqlite3Statement stmt) + { + SQLite3.Finalize (stmt); + } + + void BindAll (Sqlite3Statement stmt) + { + int nextIdx = 1; + foreach (var b in _bindings) { + if (b.Name != null) { + b.Index = SQLite3.BindParameterIndex (stmt, b.Name); + } else { + b.Index = nextIdx++; + } + + BindParameter (stmt, b.Index, b.Value, _conn.StoreDateTimeAsTicks); + } + } + + internal static IntPtr NegativePointer = new IntPtr (-1); + + internal static void BindParameter (Sqlite3Statement stmt, int index, object value, bool storeDateTimeAsTicks) + { + if (value == null) { + SQLite3.BindNull (stmt, index); + } else { + if (value is Int32) { + SQLite3.BindInt (stmt, index, (int)value); + } else if (value is String) { + SQLite3.BindText (stmt, index, (string)value, -1, NegativePointer); + } else if (value is Byte || value is UInt16 || value is SByte || value is Int16) { + SQLite3.BindInt (stmt, index, Convert.ToInt32 (value)); + } else if (value is Boolean) { + SQLite3.BindInt (stmt, index, (bool)value ? 1 : 0); + } else if (value is UInt32 || value is Int64) { + SQLite3.BindInt64 (stmt, index, Convert.ToInt64 (value)); + } else if (value is Single || value is Double || value is Decimal) { + SQLite3.BindDouble (stmt, index, Convert.ToDouble (value)); + } else if (value is TimeSpan) { + SQLite3.BindInt64(stmt, index, ((TimeSpan)value).Ticks); + } else if (value is DateTime) { + if (storeDateTimeAsTicks) { + SQLite3.BindInt64 (stmt, index, ((DateTime)value).Ticks); + } + else { + SQLite3.BindText (stmt, index, ((DateTime)value).ToString ("yyyy-MM-dd HH:mm:ss"), -1, NegativePointer); + } + } else if (value is DateTimeOffset) { + SQLite3.BindInt64 (stmt, index, ((DateTimeOffset)value).UtcTicks); +#if !NETFX_CORE + } else if (value.GetType().IsEnum) { +#else + } else if (value.GetType().GetTypeInfo().IsEnum) { +#endif + SQLite3.BindInt (stmt, index, Convert.ToInt32 (value)); + } else if (value is byte[]){ + SQLite3.BindBlob(stmt, index, (byte[]) value, ((byte[]) value).Length, NegativePointer); + } else if (value is Guid) { + SQLite3.BindText(stmt, index, ((Guid)value).ToString(), 72, NegativePointer); + } else { + throw new NotSupportedException("Cannot store type: " + value.GetType()); + } + } + } + + class Binding + { + public string Name { get; set; } + + public object Value { get; set; } + + public int Index { get; set; } + } + + object ReadCol (Sqlite3Statement stmt, int index, SQLite3.ColType type, Type clrType) + { + if (type == SQLite3.ColType.Null) { + return null; + } else { + if (clrType == typeof(String)) { + return SQLite3.ColumnString (stmt, index); + } else if (clrType == typeof(Int32)) { + return (int)SQLite3.ColumnInt (stmt, index); + } else if (clrType == typeof(Boolean)) { + return SQLite3.ColumnInt (stmt, index) == 1; + } else if (clrType == typeof(double)) { + return SQLite3.ColumnDouble (stmt, index); + } else if (clrType == typeof(float)) { + return (float)SQLite3.ColumnDouble (stmt, index); + } else if (clrType == typeof(TimeSpan)) { + return new TimeSpan(SQLite3.ColumnInt64(stmt, index)); + } else if (clrType == typeof(DateTime)) { + if (_conn.StoreDateTimeAsTicks) { + return new DateTime (SQLite3.ColumnInt64 (stmt, index)); + } + else { + var text = SQLite3.ColumnString (stmt, index); + return DateTime.Parse (text); + } + } else if (clrType == typeof(DateTimeOffset)) { + return new DateTimeOffset(SQLite3.ColumnInt64 (stmt, index),TimeSpan.Zero); +#if !NETFX_CORE + } else if (clrType.IsEnum) { +#else + } else if (clrType.GetTypeInfo().IsEnum) { +#endif + return SQLite3.ColumnInt (stmt, index); + } else if (clrType == typeof(Int64)) { + return SQLite3.ColumnInt64 (stmt, index); + } else if (clrType == typeof(UInt32)) { + return (uint)SQLite3.ColumnInt64 (stmt, index); + } else if (clrType == typeof(decimal)) { + return (decimal)SQLite3.ColumnDouble (stmt, index); + } else if (clrType == typeof(Byte)) { + return (byte)SQLite3.ColumnInt (stmt, index); + } else if (clrType == typeof(UInt16)) { + return (ushort)SQLite3.ColumnInt (stmt, index); + } else if (clrType == typeof(Int16)) { + return (short)SQLite3.ColumnInt (stmt, index); + } else if (clrType == typeof(sbyte)) { + return (sbyte)SQLite3.ColumnInt (stmt, index); + } else if (clrType == typeof(byte[])) { + return SQLite3.ColumnByteArray (stmt, index); + } else if (clrType == typeof(Guid)) { + var text = SQLite3.ColumnString(stmt, index); + return new Guid(text); + } else{ + throw new NotSupportedException ("Don't know how to read " + clrType); + } + } + } + } + + /// + /// Since the insert never changed, we only need to prepare once. + /// + public class PreparedSqlLiteInsertCommand : IDisposable + { + public bool Initialized { get; set; } + + protected SQLiteConnection Connection { get; set; } + + public string CommandText { get; set; } + + protected Sqlite3Statement Statement { get; set; } + internal static readonly Sqlite3Statement NullStatement = default(Sqlite3Statement); + + internal PreparedSqlLiteInsertCommand (SQLiteConnection conn) + { + Connection = conn; + } + + public int ExecuteNonQuery (object[] source) + { + if (Connection.Trace) { + Debug.WriteLine ("Executing: " + CommandText); + } + + var r = SQLite3.Result.OK; + + if (!Initialized) { + Statement = Prepare (); + Initialized = true; + } + + //bind the values. + if (source != null) { + for (int i = 0; i < source.Length; i++) { + SQLiteCommand.BindParameter (Statement, i + 1, source [i], Connection.StoreDateTimeAsTicks); + } + } + r = SQLite3.Step (Statement); + + if (r == SQLite3.Result.Done) { + int rowsAffected = SQLite3.Changes (Connection.Handle); + SQLite3.Reset (Statement); + return rowsAffected; + } else if (r == SQLite3.Result.Error) { + string msg = SQLite3.GetErrmsg (Connection.Handle); + SQLite3.Reset (Statement); + throw SQLiteException.New (r, msg); + } else if (r == SQLite3.Result.Constraint && SQLite3.ExtendedErrCode (Connection.Handle) == SQLite3.ExtendedResult.ConstraintNotNull) { + SQLite3.Reset (Statement); + throw NotNullConstraintViolationException.New (r, SQLite3.GetErrmsg (Connection.Handle)); + } else { + SQLite3.Reset (Statement); + throw SQLiteException.New (r, r.ToString ()); + } + } + + protected virtual Sqlite3Statement Prepare () + { + var stmt = SQLite3.Prepare2 (Connection.Handle, CommandText); + return stmt; + } + + public void Dispose () + { + Dispose (true); + GC.SuppressFinalize (this); + } + + private void Dispose (bool disposing) + { + if (Statement != NullStatement) { + try { + SQLite3.Finalize (Statement); + } finally { + Statement = NullStatement; + Connection = null; + } + } + } + + ~PreparedSqlLiteInsertCommand () + { + Dispose (false); + } + } + + public abstract class BaseTableQuery + { + protected class Ordering + { + public string ColumnName { get; set; } + public bool Ascending { get; set; } + } + } + + public class TableQuery : BaseTableQuery, IEnumerable + { + public SQLiteConnection Connection { get; private set; } + + public TableMapping Table { get; private set; } + + Expression _where; + List _orderBys; + int? _limit; + int? _offset; + + BaseTableQuery _joinInner; + Expression _joinInnerKeySelector; + BaseTableQuery _joinOuter; + Expression _joinOuterKeySelector; + Expression _joinSelector; + + Expression _selector; + + TableQuery (SQLiteConnection conn, TableMapping table) + { + Connection = conn; + Table = table; + } + + public TableQuery (SQLiteConnection conn) + { + Connection = conn; + Table = Connection.GetMapping (typeof(T)); + } + + public TableQuery Clone () + { + var q = new TableQuery (Connection, Table); + q._where = _where; + q._deferred = _deferred; + if (_orderBys != null) { + q._orderBys = new List (_orderBys); + } + q._limit = _limit; + q._offset = _offset; + q._joinInner = _joinInner; + q._joinInnerKeySelector = _joinInnerKeySelector; + q._joinOuter = _joinOuter; + q._joinOuterKeySelector = _joinOuterKeySelector; + q._joinSelector = _joinSelector; + q._selector = _selector; + return q; + } + + public TableQuery Where (Expression> predExpr) + { + if (predExpr.NodeType == ExpressionType.Lambda) { + var lambda = (LambdaExpression)predExpr; + var pred = lambda.Body; + var q = Clone (); + q.AddWhere (pred); + return q; + } else { + throw new NotSupportedException ("Must be a predicate"); + } + } + + public TableQuery Take (int n) + { + var q = Clone (); + q._limit = n; + return q; + } + + public TableQuery Skip (int n) + { + var q = Clone (); + q._offset = n; + return q; + } + + public T ElementAt (int index) + { + return Skip (index).Take (1).First (); + } + + bool _deferred; + public TableQuery Deferred () + { + var q = Clone (); + q._deferred = true; + return q; + } + + public TableQuery OrderBy (Expression> orderExpr) + { + return AddOrderBy (orderExpr, true); + } + + public TableQuery OrderByDescending (Expression> orderExpr) + { + return AddOrderBy (orderExpr, false); + } + + public TableQuery ThenBy(Expression> orderExpr) + { + return AddOrderBy(orderExpr, true); + } + + public TableQuery ThenByDescending(Expression> orderExpr) + { + return AddOrderBy(orderExpr, false); + } + + private TableQuery AddOrderBy (Expression> orderExpr, bool asc) + { + if (orderExpr.NodeType == ExpressionType.Lambda) { + var lambda = (LambdaExpression)orderExpr; + + MemberExpression mem = null; + + var unary = lambda.Body as UnaryExpression; + if (unary != null && unary.NodeType == ExpressionType.Convert) { + mem = unary.Operand as MemberExpression; + } + else { + mem = lambda.Body as MemberExpression; + } + + if (mem != null && (mem.Expression.NodeType == ExpressionType.Parameter)) { + var q = Clone (); + if (q._orderBys == null) { + q._orderBys = new List (); + } + q._orderBys.Add (new Ordering { + ColumnName = Table.FindColumnWithPropertyName(mem.Member.Name).Name, + Ascending = asc + }); + return q; + } else { + throw new NotSupportedException ("Order By does not support: " + orderExpr); + } + } else { + throw new NotSupportedException ("Must be a predicate"); + } + } + + private void AddWhere (Expression pred) + { + if (_where == null) { + _where = pred; + } else { + _where = Expression.AndAlso (_where, pred); + } + } + + public TableQuery Join ( + TableQuery inner, + Expression> outerKeySelector, + Expression> innerKeySelector, + Expression> resultSelector) + { + var q = new TableQuery (Connection, Connection.GetMapping (typeof (TResult))) { + _joinOuter = this, + _joinOuterKeySelector = outerKeySelector, + _joinInner = inner, + _joinInnerKeySelector = innerKeySelector, + _joinSelector = resultSelector, + }; + return q; + } + + public TableQuery Select (Expression> selector) + { + var q = Clone (); + q._selector = selector; + return q; + } + + private SQLiteCommand GenerateCommand (string selectionList) + { + if (_joinInner != null && _joinOuter != null) { + throw new NotSupportedException ("Joins are not supported."); + } + else { + var cmdText = "select " + selectionList + " from \"" + Table.TableName + "\""; + var args = new List (); + if (_where != null) { + var w = CompileExpr (_where, args); + cmdText += " where " + w.CommandText; + } + if ((_orderBys != null) && (_orderBys.Count > 0)) { + var t = string.Join (", ", _orderBys.Select (o => "\"" + o.ColumnName + "\"" + (o.Ascending ? "" : " desc")).ToArray ()); + cmdText += " order by " + t; + } + if (_limit.HasValue) { + cmdText += " limit " + _limit.Value; + } + if (_offset.HasValue) { + if (!_limit.HasValue) { + cmdText += " limit -1 "; + } + cmdText += " offset " + _offset.Value; + } + return Connection.CreateCommand (cmdText, args.ToArray ()); + } + } + + class CompileResult + { + public string CommandText { get; set; } + + public object Value { get; set; } + } + + private CompileResult CompileExpr (Expression expr, List queryArgs) + { + if (expr == null) { + throw new NotSupportedException ("Expression is NULL"); + } else if (expr is BinaryExpression) { + var bin = (BinaryExpression)expr; + + var leftr = CompileExpr (bin.Left, queryArgs); + var rightr = CompileExpr (bin.Right, queryArgs); + + //If either side is a parameter and is null, then handle the other side specially (for "is null"/"is not null") + string text; + if (leftr.CommandText == "?" && leftr.Value == null) + text = CompileNullBinaryExpression(bin, rightr); + else if (rightr.CommandText == "?" && rightr.Value == null) + text = CompileNullBinaryExpression(bin, leftr); + else + text = "(" + leftr.CommandText + " " + GetSqlName(bin) + " " + rightr.CommandText + ")"; + return new CompileResult { CommandText = text }; + } else if (expr.NodeType == ExpressionType.Call) { + + var call = (MethodCallExpression)expr; + var args = new CompileResult[call.Arguments.Count]; + var obj = call.Object != null ? CompileExpr (call.Object, queryArgs) : null; + + for (var i = 0; i < args.Length; i++) { + args [i] = CompileExpr (call.Arguments [i], queryArgs); + } + + var sqlCall = ""; + + if (call.Method.Name == "Like" && args.Length == 2) { + sqlCall = "(" + args [0].CommandText + " like " + args [1].CommandText + ")"; + } + else if (call.Method.Name == "Contains" && args.Length == 2) { + sqlCall = "(" + args [1].CommandText + " in " + args [0].CommandText + ")"; + } + else if (call.Method.Name == "Contains" && args.Length == 1) { + if (call.Object != null && call.Object.Type == typeof(string)) { + sqlCall = "(" + obj.CommandText + " like ('%' || " + args [0].CommandText + " || '%'))"; + } + else { + sqlCall = "(" + args [0].CommandText + " in " + obj.CommandText + ")"; + } + } + else if (call.Method.Name == "StartsWith" && args.Length == 1) { + sqlCall = "(" + obj.CommandText + " like (" + args [0].CommandText + " || '%'))"; + } + else if (call.Method.Name == "EndsWith" && args.Length == 1) { + sqlCall = "(" + obj.CommandText + " like ('%' || " + args [0].CommandText + "))"; + } + else if (call.Method.Name == "Equals" && args.Length == 1) { + sqlCall = "(" + obj.CommandText + " = (" + args[0].CommandText + "))"; + } else if (call.Method.Name == "ToLower") { + sqlCall = "(lower(" + obj.CommandText + "))"; + } else if (call.Method.Name == "ToUpper") { + sqlCall = "(upper(" + obj.CommandText + "))"; + } else { + sqlCall = call.Method.Name.ToLower () + "(" + string.Join (",", args.Select (a => a.CommandText).ToArray ()) + ")"; + } + return new CompileResult { CommandText = sqlCall }; + + } else if (expr.NodeType == ExpressionType.Constant) { + var c = (ConstantExpression)expr; + queryArgs.Add (c.Value); + return new CompileResult { + CommandText = "?", + Value = c.Value + }; + } else if (expr.NodeType == ExpressionType.Convert) { + var u = (UnaryExpression)expr; + var ty = u.Type; + var valr = CompileExpr (u.Operand, queryArgs); + return new CompileResult { + CommandText = valr.CommandText, + Value = valr.Value != null ? ConvertTo (valr.Value, ty) : null + }; + } else if (expr.NodeType == ExpressionType.MemberAccess) { + var mem = (MemberExpression)expr; + + if (mem.Expression!=null && mem.Expression.NodeType == ExpressionType.Parameter) { + // + // This is a column of our table, output just the column name + // Need to translate it if that column name is mapped + // + var columnName = Table.FindColumnWithPropertyName (mem.Member.Name).Name; + return new CompileResult { CommandText = "\"" + columnName + "\"" }; + } else { + object obj = null; + if (mem.Expression != null) { + var r = CompileExpr (mem.Expression, queryArgs); + if (r.Value == null) { + throw new NotSupportedException ("Member access failed to compile expression"); + } + if (r.CommandText == "?") { + queryArgs.RemoveAt (queryArgs.Count - 1); + } + obj = r.Value; + } + + // + // Get the member value + // + object val = null; + +#if !NETFX_CORE + if (mem.Member.MemberType == MemberTypes.Property) { +#else + if (mem.Member is PropertyInfo) { +#endif + var m = (PropertyInfo)mem.Member; + val = m.GetValue (obj, null); +#if !NETFX_CORE + } else if (mem.Member.MemberType == MemberTypes.Field) { +#else + } else if (mem.Member is FieldInfo) { +#endif +#if SILVERLIGHT + val = Expression.Lambda (expr).Compile ().DynamicInvoke (); +#else + var m = (FieldInfo)mem.Member; + val = m.GetValue (obj); +#endif + } else { +#if !NETFX_CORE + throw new NotSupportedException ("MemberExpr: " + mem.Member.MemberType); +#else + throw new NotSupportedException ("MemberExpr: " + mem.Member.DeclaringType); +#endif + } + + // + // Work special magic for enumerables + // + if (val != null && val is System.Collections.IEnumerable && !(val is string) && !(val is System.Collections.Generic.IEnumerable)) { + var sb = new System.Text.StringBuilder(); + sb.Append("("); + var head = ""; + foreach (var a in (System.Collections.IEnumerable)val) { + queryArgs.Add(a); + sb.Append(head); + sb.Append("?"); + head = ","; + } + sb.Append(")"); + return new CompileResult { + CommandText = sb.ToString(), + Value = val + }; + } + else { + queryArgs.Add (val); + return new CompileResult { + CommandText = "?", + Value = val + }; + } + } + } + throw new NotSupportedException ("Cannot compile: " + expr.NodeType.ToString ()); + } + + static object ConvertTo (object obj, Type t) + { + Type nut = Nullable.GetUnderlyingType(t); + + if (nut != null) { + if (obj == null) return null; + return Convert.ChangeType (obj, nut); + } else { + return Convert.ChangeType (obj, t); + } + } + + /// + /// Compiles a BinaryExpression where one of the parameters is null. + /// + /// The non-null parameter + private string CompileNullBinaryExpression(BinaryExpression expression, CompileResult parameter) + { + if (expression.NodeType == ExpressionType.Equal) + return "(" + parameter.CommandText + " is ?)"; + else if (expression.NodeType == ExpressionType.NotEqual) + return "(" + parameter.CommandText + " is not ?)"; + else + throw new NotSupportedException("Cannot compile Null-BinaryExpression with type " + expression.NodeType.ToString()); + } + + string GetSqlName (Expression expr) + { + var n = expr.NodeType; + if (n == ExpressionType.GreaterThan) + return ">"; else if (n == ExpressionType.GreaterThanOrEqual) { + return ">="; + } else if (n == ExpressionType.LessThan) { + return "<"; + } else if (n == ExpressionType.LessThanOrEqual) { + return "<="; + } else if (n == ExpressionType.And) { + return "&"; + } else if (n == ExpressionType.AndAlso) { + return "and"; + } else if (n == ExpressionType.Or) { + return "|"; + } else if (n == ExpressionType.OrElse) { + return "or"; + } else if (n == ExpressionType.Equal) { + return "="; + } else if (n == ExpressionType.NotEqual) { + return "!="; + } else { + throw new NotSupportedException ("Cannot get SQL for: " + n); + } + } + + public int Count () + { + return GenerateCommand("count(*)").ExecuteScalar (); + } + + public int Count (Expression> predExpr) + { + return Where (predExpr).Count (); + } + + public IEnumerator GetEnumerator () + { + if (!_deferred) + return GenerateCommand("*").ExecuteQuery().GetEnumerator(); + + return GenerateCommand("*").ExecuteDeferredQuery().GetEnumerator(); + } + + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator () + { + return GetEnumerator (); + } + + public T First () + { + var query = Take (1); + return query.ToList().First (); + } + + public T FirstOrDefault () + { + var query = Take (1); + return query.ToList().FirstOrDefault (); + } + } + + public static class SQLite3 + { + public enum Result : int + { + OK = 0, + Error = 1, + Internal = 2, + Perm = 3, + Abort = 4, + Busy = 5, + Locked = 6, + NoMem = 7, + ReadOnly = 8, + Interrupt = 9, + IOError = 10, + Corrupt = 11, + NotFound = 12, + Full = 13, + CannotOpen = 14, + LockErr = 15, + Empty = 16, + SchemaChngd = 17, + TooBig = 18, + Constraint = 19, + Mismatch = 20, + Misuse = 21, + NotImplementedLFS = 22, + AccessDenied = 23, + Format = 24, + Range = 25, + NonDBFile = 26, + Notice = 27, + Warning = 28, + Row = 100, + Done = 101 + } + + public enum ExtendedResult : int + { + IOErrorRead = (Result.IOError | (1 << 8)), + IOErrorShortRead = (Result.IOError | (2 << 8)), + IOErrorWrite = (Result.IOError | (3 << 8)), + IOErrorFsync = (Result.IOError | (4 << 8)), + IOErrorDirFSync = (Result.IOError | (5 << 8)), + IOErrorTruncate = (Result.IOError | (6 << 8)), + IOErrorFStat = (Result.IOError | (7 << 8)), + IOErrorUnlock = (Result.IOError | (8 << 8)), + IOErrorRdlock = (Result.IOError | (9 << 8)), + IOErrorDelete = (Result.IOError | (10 << 8)), + IOErrorBlocked = (Result.IOError | (11 << 8)), + IOErrorNoMem = (Result.IOError | (12 << 8)), + IOErrorAccess = (Result.IOError | (13 << 8)), + IOErrorCheckReservedLock = (Result.IOError | (14 << 8)), + IOErrorLock = (Result.IOError | (15 << 8)), + IOErrorClose = (Result.IOError | (16 << 8)), + IOErrorDirClose = (Result.IOError | (17 << 8)), + IOErrorSHMOpen = (Result.IOError | (18 << 8)), + IOErrorSHMSize = (Result.IOError | (19 << 8)), + IOErrorSHMLock = (Result.IOError | (20 << 8)), + IOErrorSHMMap = (Result.IOError | (21 << 8)), + IOErrorSeek = (Result.IOError | (22 << 8)), + IOErrorDeleteNoEnt = (Result.IOError | (23 << 8)), + IOErrorMMap = (Result.IOError | (24 << 8)), + LockedSharedcache = (Result.Locked | (1 << 8)), + BusyRecovery = (Result.Busy | (1 << 8)), + CannottOpenNoTempDir = (Result.CannotOpen | (1 << 8)), + CannotOpenIsDir = (Result.CannotOpen | (2 << 8)), + CannotOpenFullPath = (Result.CannotOpen | (3 << 8)), + CorruptVTab = (Result.Corrupt | (1 << 8)), + ReadonlyRecovery = (Result.ReadOnly | (1 << 8)), + ReadonlyCannotLock = (Result.ReadOnly | (2 << 8)), + ReadonlyRollback = (Result.ReadOnly | (3 << 8)), + AbortRollback = (Result.Abort | (2 << 8)), + ConstraintCheck = (Result.Constraint | (1 << 8)), + ConstraintCommitHook = (Result.Constraint | (2 << 8)), + ConstraintForeignKey = (Result.Constraint | (3 << 8)), + ConstraintFunction = (Result.Constraint | (4 << 8)), + ConstraintNotNull = (Result.Constraint | (5 << 8)), + ConstraintPrimaryKey = (Result.Constraint | (6 << 8)), + ConstraintTrigger = (Result.Constraint | (7 << 8)), + ConstraintUnique = (Result.Constraint | (8 << 8)), + ConstraintVTab = (Result.Constraint | (9 << 8)), + NoticeRecoverWAL = (Result.Notice | (1 << 8)), + NoticeRecoverRollback = (Result.Notice | (2 << 8)) + } + + + public enum ConfigOption : int + { + SingleThread = 1, + MultiThread = 2, + Serialized = 3 + } + +#if !USE_CSHARP_SQLITE && !USE_WP8_NATIVE_SQLITE + [DllImport("sqlite3", EntryPoint = "sqlite3_open", CallingConvention=CallingConvention.Cdecl)] + public static extern Result Open ([MarshalAs(UnmanagedType.LPStr)] string filename, out IntPtr db); + + [DllImport("sqlite3", EntryPoint = "sqlite3_open_v2", CallingConvention=CallingConvention.Cdecl)] + public static extern Result Open ([MarshalAs(UnmanagedType.LPStr)] string filename, out IntPtr db, int flags, IntPtr zvfs); + + [DllImport("sqlite3", EntryPoint = "sqlite3_open_v2", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Open(byte[] filename, out IntPtr db, int flags, IntPtr zvfs); + + [DllImport("sqlite3", EntryPoint = "sqlite3_open16", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Open16([MarshalAs(UnmanagedType.LPWStr)] string filename, out IntPtr db); + + [DllImport("sqlite3", EntryPoint = "sqlite3_enable_load_extension", CallingConvention=CallingConvention.Cdecl)] + public static extern Result EnableLoadExtension (IntPtr db, int onoff); + + [DllImport("sqlite3", EntryPoint = "sqlite3_close", CallingConvention=CallingConvention.Cdecl)] + public static extern Result Close (IntPtr db); + + [DllImport("sqlite3", EntryPoint = "sqlite3_initialize", CallingConvention=CallingConvention.Cdecl)] + public static extern Result Initialize(); + + [DllImport("sqlite3", EntryPoint = "sqlite3_shutdown", CallingConvention=CallingConvention.Cdecl)] + public static extern Result Shutdown(); + + [DllImport("sqlite3", EntryPoint = "sqlite3_config", CallingConvention=CallingConvention.Cdecl)] + public static extern Result Config (ConfigOption option); + + [DllImport("sqlite3", EntryPoint = "sqlite3_win32_set_directory", CallingConvention=CallingConvention.Cdecl, CharSet=CharSet.Unicode)] + public static extern int SetDirectory (uint directoryType, string directoryPath); + + [DllImport("sqlite3", EntryPoint = "sqlite3_busy_timeout", CallingConvention=CallingConvention.Cdecl)] + public static extern Result BusyTimeout (IntPtr db, int milliseconds); + + [DllImport("sqlite3", EntryPoint = "sqlite3_changes", CallingConvention=CallingConvention.Cdecl)] + public static extern int Changes (IntPtr db); + + [DllImport("sqlite3", EntryPoint = "sqlite3_prepare_v2", CallingConvention=CallingConvention.Cdecl)] + public static extern Result Prepare2 (IntPtr db, [MarshalAs(UnmanagedType.LPStr)] string sql, int numBytes, out IntPtr stmt, IntPtr pzTail); + +#if NETFX_CORE + [DllImport ("sqlite3", EntryPoint = "sqlite3_prepare_v2", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Prepare2 (IntPtr db, byte[] queryBytes, int numBytes, out IntPtr stmt, IntPtr pzTail); +#endif + + public static IntPtr Prepare2 (IntPtr db, string query) + { + IntPtr stmt; +#if NETFX_CORE + byte[] queryBytes = System.Text.UTF8Encoding.UTF8.GetBytes (query); + var r = Prepare2 (db, queryBytes, queryBytes.Length, out stmt, IntPtr.Zero); +#else + var r = Prepare2 (db, query, System.Text.UTF8Encoding.UTF8.GetByteCount (query), out stmt, IntPtr.Zero); +#endif + if (r != Result.OK) { + throw SQLiteException.New (r, GetErrmsg (db)); + } + return stmt; + } + + [DllImport("sqlite3", EntryPoint = "sqlite3_step", CallingConvention=CallingConvention.Cdecl)] + public static extern Result Step (IntPtr stmt); + + [DllImport("sqlite3", EntryPoint = "sqlite3_reset", CallingConvention=CallingConvention.Cdecl)] + public static extern Result Reset (IntPtr stmt); + + [DllImport("sqlite3", EntryPoint = "sqlite3_finalize", CallingConvention=CallingConvention.Cdecl)] + public static extern Result Finalize (IntPtr stmt); + + [DllImport("sqlite3", EntryPoint = "sqlite3_last_insert_rowid", CallingConvention=CallingConvention.Cdecl)] + public static extern long LastInsertRowid (IntPtr db); + + [DllImport("sqlite3", EntryPoint = "sqlite3_errmsg16", CallingConvention=CallingConvention.Cdecl)] + public static extern IntPtr Errmsg (IntPtr db); + + public static string GetErrmsg (IntPtr db) + { + return Marshal.PtrToStringUni (Errmsg (db)); + } + + [DllImport("sqlite3", EntryPoint = "sqlite3_bind_parameter_index", CallingConvention=CallingConvention.Cdecl)] + public static extern int BindParameterIndex (IntPtr stmt, [MarshalAs(UnmanagedType.LPStr)] string name); + + [DllImport("sqlite3", EntryPoint = "sqlite3_bind_null", CallingConvention=CallingConvention.Cdecl)] + public static extern int BindNull (IntPtr stmt, int index); + + [DllImport("sqlite3", EntryPoint = "sqlite3_bind_int", CallingConvention=CallingConvention.Cdecl)] + public static extern int BindInt (IntPtr stmt, int index, int val); + + [DllImport("sqlite3", EntryPoint = "sqlite3_bind_int64", CallingConvention=CallingConvention.Cdecl)] + public static extern int BindInt64 (IntPtr stmt, int index, long val); + + [DllImport("sqlite3", EntryPoint = "sqlite3_bind_double", CallingConvention=CallingConvention.Cdecl)] + public static extern int BindDouble (IntPtr stmt, int index, double val); + + [DllImport("sqlite3", EntryPoint = "sqlite3_bind_text16", CallingConvention=CallingConvention.Cdecl, CharSet = CharSet.Unicode)] + public static extern int BindText (IntPtr stmt, int index, [MarshalAs(UnmanagedType.LPWStr)] string val, int n, IntPtr free); + + [DllImport("sqlite3", EntryPoint = "sqlite3_bind_blob", CallingConvention=CallingConvention.Cdecl)] + public static extern int BindBlob (IntPtr stmt, int index, byte[] val, int n, IntPtr free); + + [DllImport("sqlite3", EntryPoint = "sqlite3_column_count", CallingConvention=CallingConvention.Cdecl)] + public static extern int ColumnCount (IntPtr stmt); + + [DllImport("sqlite3", EntryPoint = "sqlite3_column_name", CallingConvention=CallingConvention.Cdecl)] + public static extern IntPtr ColumnName (IntPtr stmt, int index); + + [DllImport("sqlite3", EntryPoint = "sqlite3_column_name16", CallingConvention=CallingConvention.Cdecl)] + static extern IntPtr ColumnName16Internal (IntPtr stmt, int index); + public static string ColumnName16(IntPtr stmt, int index) + { + return Marshal.PtrToStringUni(ColumnName16Internal(stmt, index)); + } + + [DllImport("sqlite3", EntryPoint = "sqlite3_column_type", CallingConvention=CallingConvention.Cdecl)] + public static extern ColType ColumnType (IntPtr stmt, int index); + + [DllImport("sqlite3", EntryPoint = "sqlite3_column_int", CallingConvention=CallingConvention.Cdecl)] + public static extern int ColumnInt (IntPtr stmt, int index); + + [DllImport("sqlite3", EntryPoint = "sqlite3_column_int64", CallingConvention=CallingConvention.Cdecl)] + public static extern long ColumnInt64 (IntPtr stmt, int index); + + [DllImport("sqlite3", EntryPoint = "sqlite3_column_double", CallingConvention=CallingConvention.Cdecl)] + public static extern double ColumnDouble (IntPtr stmt, int index); + + [DllImport("sqlite3", EntryPoint = "sqlite3_column_text", CallingConvention=CallingConvention.Cdecl)] + public static extern IntPtr ColumnText (IntPtr stmt, int index); + + [DllImport("sqlite3", EntryPoint = "sqlite3_column_text16", CallingConvention=CallingConvention.Cdecl)] + public static extern IntPtr ColumnText16 (IntPtr stmt, int index); + + [DllImport("sqlite3", EntryPoint = "sqlite3_column_blob", CallingConvention=CallingConvention.Cdecl)] + public static extern IntPtr ColumnBlob (IntPtr stmt, int index); + + [DllImport("sqlite3", EntryPoint = "sqlite3_column_bytes", CallingConvention=CallingConvention.Cdecl)] + public static extern int ColumnBytes (IntPtr stmt, int index); + + public static string ColumnString (IntPtr stmt, int index) + { + return Marshal.PtrToStringUni (SQLite3.ColumnText16 (stmt, index)); + } + + public static byte[] ColumnByteArray (IntPtr stmt, int index) + { + int length = ColumnBytes (stmt, index); + var result = new byte[length]; + if (length > 0) + Marshal.Copy (ColumnBlob (stmt, index), result, 0, length); + return result; + } + + [DllImport ("sqlite3", EntryPoint = "sqlite3_extended_errcode", CallingConvention = CallingConvention.Cdecl)] + public static extern ExtendedResult ExtendedErrCode (IntPtr db); + + [DllImport ("sqlite3", EntryPoint = "sqlite3_libversion_number", CallingConvention = CallingConvention.Cdecl)] + public static extern int LibVersionNumber (); +#else + public static Result Open(string filename, out Sqlite3DatabaseHandle db) + { + return (Result) Sqlite3.sqlite3_open(filename, out db); + } + + public static Result Open(string filename, out Sqlite3DatabaseHandle db, int flags, IntPtr zVfs) + { +#if USE_WP8_NATIVE_SQLITE + return (Result)Sqlite3.sqlite3_open_v2(filename, out db, flags, ""); +#else + return (Result)Sqlite3.sqlite3_open_v2(filename, out db, flags, null); +#endif + } + + public static Result Close(Sqlite3DatabaseHandle db) + { + return (Result)Sqlite3.sqlite3_close(db); + } + + public static Result BusyTimeout(Sqlite3DatabaseHandle db, int milliseconds) + { + return (Result)Sqlite3.sqlite3_busy_timeout(db, milliseconds); + } + + public static int Changes(Sqlite3DatabaseHandle db) + { + return Sqlite3.sqlite3_changes(db); + } + + public static Sqlite3Statement Prepare2(Sqlite3DatabaseHandle db, string query) + { + Sqlite3Statement stmt = default(Sqlite3Statement); +#if USE_WP8_NATIVE_SQLITE + var r = Sqlite3.sqlite3_prepare_v2(db, query, out stmt); +#else + stmt = new Sqlite3Statement(); + var r = Sqlite3.sqlite3_prepare_v2(db, query, -1, ref stmt, 0); +#endif + if (r != 0) + { + throw SQLiteException.New((Result)r, GetErrmsg(db)); + } + return stmt; + } + + public static Result Step(Sqlite3Statement stmt) + { + return (Result)Sqlite3.sqlite3_step(stmt); + } + + public static Result Reset(Sqlite3Statement stmt) + { + return (Result)Sqlite3.sqlite3_reset(stmt); + } + + public static Result Finalize(Sqlite3Statement stmt) + { + return (Result)Sqlite3.sqlite3_finalize(stmt); + } + + public static long LastInsertRowid(Sqlite3DatabaseHandle db) + { + return Sqlite3.sqlite3_last_insert_rowid(db); + } + + public static string GetErrmsg(Sqlite3DatabaseHandle db) + { + return Sqlite3.sqlite3_errmsg(db); + } + + public static int BindParameterIndex(Sqlite3Statement stmt, string name) + { + return Sqlite3.sqlite3_bind_parameter_index(stmt, name); + } + + public static int BindNull(Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_bind_null(stmt, index); + } + + public static int BindInt(Sqlite3Statement stmt, int index, int val) + { + return Sqlite3.sqlite3_bind_int(stmt, index, val); + } + + public static int BindInt64(Sqlite3Statement stmt, int index, long val) + { + return Sqlite3.sqlite3_bind_int64(stmt, index, val); + } + + public static int BindDouble(Sqlite3Statement stmt, int index, double val) + { + return Sqlite3.sqlite3_bind_double(stmt, index, val); + } + + public static int BindText(Sqlite3Statement stmt, int index, string val, int n, IntPtr free) + { +#if USE_WP8_NATIVE_SQLITE + return Sqlite3.sqlite3_bind_text(stmt, index, val, n); +#else + return Sqlite3.sqlite3_bind_text(stmt, index, val, n, null); +#endif + } + + public static int BindBlob(Sqlite3Statement stmt, int index, byte[] val, int n, IntPtr free) + { +#if USE_WP8_NATIVE_SQLITE + return Sqlite3.sqlite3_bind_blob(stmt, index, val, n); +#else + return Sqlite3.sqlite3_bind_blob(stmt, index, val, n, null); +#endif + } + + public static int ColumnCount(Sqlite3Statement stmt) + { + return Sqlite3.sqlite3_column_count(stmt); + } + + public static string ColumnName(Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_name(stmt, index); + } + + public static string ColumnName16(Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_name(stmt, index); + } + + public static ColType ColumnType(Sqlite3Statement stmt, int index) + { + return (ColType)Sqlite3.sqlite3_column_type(stmt, index); + } + + public static int ColumnInt(Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_int(stmt, index); + } + + public static long ColumnInt64(Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_int64(stmt, index); + } + + public static double ColumnDouble(Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_double(stmt, index); + } + + public static string ColumnText(Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_text(stmt, index); + } + + public static string ColumnText16(Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_text(stmt, index); + } + + public static byte[] ColumnBlob(Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_blob(stmt, index); + } + + public static int ColumnBytes(Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_bytes(stmt, index); + } + + public static string ColumnString(Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_text(stmt, index); + } + + public static byte[] ColumnByteArray(Sqlite3Statement stmt, int index) + { + return ColumnBlob(stmt, index); + } + + public static Result EnableLoadExtension(Sqlite3DatabaseHandle db, int onoff) + { + return (Result)Sqlite3.sqlite3_enable_load_extension(db, onoff); + } + + public static ExtendedResult ExtendedErrCode(Sqlite3DatabaseHandle db) + { + return (ExtendedResult)Sqlite3.sqlite3_extended_errcode(db); + } +#endif + + public enum ColType : int + { + Integer = 1, + Float = 2, + Text = 3, + Blob = 4, + Null = 5 + } + } +} diff --git a/NadekoBot/SQLiteAsync.cs b/NadekoBot/SQLiteAsync.cs new file mode 100644 index 00000000..79b91cba --- /dev/null +++ b/NadekoBot/SQLiteAsync.cs @@ -0,0 +1,503 @@ +// +// Copyright (c) 2012 Krueger Systems, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. +// + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Threading; +using System.Threading.Tasks; + +namespace SQLite +{ + public partial class SQLiteAsyncConnection + { + SQLiteConnectionString _connectionString; + SQLiteOpenFlags _openFlags; + + public SQLiteAsyncConnection(string databasePath, bool storeDateTimeAsTicks = false) + : this(databasePath, SQLiteOpenFlags.ReadWrite | SQLiteOpenFlags.Create, storeDateTimeAsTicks) + { + } + + public SQLiteAsyncConnection(string databasePath, SQLiteOpenFlags openFlags, bool storeDateTimeAsTicks = false) + { + _openFlags = openFlags; + _connectionString = new SQLiteConnectionString(databasePath, storeDateTimeAsTicks); + } + + SQLiteConnectionWithLock GetConnection () + { + return SQLiteConnectionPool.Shared.GetConnection (_connectionString, _openFlags); + } + + public Task CreateTableAsync () + where T : new () + { + return CreateTablesAsync (typeof (T)); + } + + public Task CreateTablesAsync () + where T : new () + where T2 : new () + { + return CreateTablesAsync (typeof (T), typeof (T2)); + } + + public Task CreateTablesAsync () + where T : new () + where T2 : new () + where T3 : new () + { + return CreateTablesAsync (typeof (T), typeof (T2), typeof (T3)); + } + + public Task CreateTablesAsync () + where T : new () + where T2 : new () + where T3 : new () + where T4 : new () + { + return CreateTablesAsync (typeof (T), typeof (T2), typeof (T3), typeof (T4)); + } + + public Task CreateTablesAsync () + where T : new () + where T2 : new () + where T3 : new () + where T4 : new () + where T5 : new () + { + return CreateTablesAsync (typeof (T), typeof (T2), typeof (T3), typeof (T4), typeof (T5)); + } + + public Task CreateTablesAsync (params Type[] types) + { + return Task.Factory.StartNew (() => { + CreateTablesResult result = new CreateTablesResult (); + var conn = GetConnection (); + using (conn.Lock ()) { + foreach (Type type in types) { + int aResult = conn.CreateTable (type); + result.Results[type] = aResult; + } + } + return result; + }); + } + + public Task DropTableAsync () + where T : new () + { + return Task.Factory.StartNew (() => { + var conn = GetConnection (); + using (conn.Lock ()) { + return conn.DropTable (); + } + }); + } + + public Task InsertAsync (object item) + { + return Task.Factory.StartNew (() => { + var conn = GetConnection (); + using (conn.Lock ()) { + return conn.Insert (item); + } + }); + } + + public Task UpdateAsync (object item) + { + return Task.Factory.StartNew (() => { + var conn = GetConnection (); + using (conn.Lock ()) { + return conn.Update (item); + } + }); + } + + public Task DeleteAsync (object item) + { + return Task.Factory.StartNew (() => { + var conn = GetConnection (); + using (conn.Lock ()) { + return conn.Delete (item); + } + }); + } + + public Task GetAsync(object pk) + where T : new() + { + return Task.Factory.StartNew(() => + { + var conn = GetConnection(); + using (conn.Lock()) + { + return conn.Get(pk); + } + }); + } + + public Task FindAsync (object pk) + where T : new () + { + return Task.Factory.StartNew (() => { + var conn = GetConnection (); + using (conn.Lock ()) { + return conn.Find (pk); + } + }); + } + + public Task GetAsync (Expression> predicate) + where T : new() + { + return Task.Factory.StartNew(() => + { + var conn = GetConnection(); + using (conn.Lock()) + { + return conn.Get (predicate); + } + }); + } + + public Task FindAsync (Expression> predicate) + where T : new () + { + return Task.Factory.StartNew (() => { + var conn = GetConnection (); + using (conn.Lock ()) { + return conn.Find (predicate); + } + }); + } + + public Task ExecuteAsync (string query, params object[] args) + { + return Task.Factory.StartNew (() => { + var conn = GetConnection (); + using (conn.Lock ()) { + return conn.Execute (query, args); + } + }); + } + + public Task InsertAllAsync (IEnumerable items) + { + return Task.Factory.StartNew (() => { + var conn = GetConnection (); + using (conn.Lock ()) { + return conn.InsertAll (items); + } + }); + } + + public Task UpdateAllAsync (IEnumerable items) + { + return Task.Factory.StartNew (() => { + var conn = GetConnection (); + using (conn.Lock ()) { + return conn.UpdateAll (items); + } + }); + } + + [Obsolete("Will cause a deadlock if any call in action ends up in a different thread. Use RunInTransactionAsync(Action) instead.")] + public Task RunInTransactionAsync (Action action) + { + return Task.Factory.StartNew (() => { + var conn = this.GetConnection (); + using (conn.Lock ()) { + conn.BeginTransaction (); + try { + action (this); + conn.Commit (); + } + catch (Exception) { + conn.Rollback (); + throw; + } + } + }); + } + + public Task RunInTransactionAsync(Action action) + { + return Task.Factory.StartNew(() => + { + var conn = this.GetConnection(); + using (conn.Lock()) + { + conn.BeginTransaction(); + try + { + action(conn); + conn.Commit(); + } + catch (Exception) + { + conn.Rollback(); + throw; + } + } + }); + } + + public AsyncTableQuery Table () + where T : new () + { + // + // This isn't async as the underlying connection doesn't go out to the database + // until the query is performed. The Async methods are on the query iteself. + // + var conn = GetConnection (); + return new AsyncTableQuery (conn.Table ()); + } + + public Task ExecuteScalarAsync (string sql, params object[] args) + { + return Task.Factory.StartNew (() => { + var conn = GetConnection (); + using (conn.Lock ()) { + var command = conn.CreateCommand (sql, args); + return command.ExecuteScalar (); + } + }); + } + + public Task> QueryAsync (string sql, params object[] args) + where T : new () + { + return Task>.Factory.StartNew (() => { + var conn = GetConnection (); + using (conn.Lock ()) { + return conn.Query (sql, args); + } + }); + } + } + + // + // TODO: Bind to AsyncConnection.GetConnection instead so that delayed + // execution can still work after a Pool.Reset. + // + public class AsyncTableQuery + where T : new () + { + TableQuery _innerQuery; + + public AsyncTableQuery (TableQuery innerQuery) + { + _innerQuery = innerQuery; + } + + public AsyncTableQuery Where (Expression> predExpr) + { + return new AsyncTableQuery (_innerQuery.Where (predExpr)); + } + + public AsyncTableQuery Skip (int n) + { + return new AsyncTableQuery (_innerQuery.Skip (n)); + } + + public AsyncTableQuery Take (int n) + { + return new AsyncTableQuery (_innerQuery.Take (n)); + } + + public AsyncTableQuery OrderBy (Expression> orderExpr) + { + return new AsyncTableQuery (_innerQuery.OrderBy (orderExpr)); + } + + public AsyncTableQuery OrderByDescending (Expression> orderExpr) + { + return new AsyncTableQuery (_innerQuery.OrderByDescending (orderExpr)); + } + + public Task> ToListAsync () + { + return Task.Factory.StartNew (() => { + using (((SQLiteConnectionWithLock)_innerQuery.Connection).Lock ()) { + return _innerQuery.ToList (); + } + }); + } + + public Task CountAsync () + { + return Task.Factory.StartNew (() => { + using (((SQLiteConnectionWithLock)_innerQuery.Connection).Lock ()) { + return _innerQuery.Count (); + } + }); + } + + public Task ElementAtAsync (int index) + { + return Task.Factory.StartNew (() => { + using (((SQLiteConnectionWithLock)_innerQuery.Connection).Lock ()) { + return _innerQuery.ElementAt (index); + } + }); + } + + public Task FirstAsync () + { + return Task.Factory.StartNew(() => { + using (((SQLiteConnectionWithLock)_innerQuery.Connection).Lock ()) { + return _innerQuery.First (); + } + }); + } + + public Task FirstOrDefaultAsync () + { + return Task.Factory.StartNew(() => { + using (((SQLiteConnectionWithLock)_innerQuery.Connection).Lock ()) { + return _innerQuery.FirstOrDefault (); + } + }); + } + } + + public class CreateTablesResult + { + public Dictionary Results { get; private set; } + + internal CreateTablesResult () + { + this.Results = new Dictionary (); + } + } + + class SQLiteConnectionPool + { + class Entry + { + public SQLiteConnectionString ConnectionString { get; private set; } + public SQLiteConnectionWithLock Connection { get; private set; } + + public Entry (SQLiteConnectionString connectionString, SQLiteOpenFlags openFlags) + { + ConnectionString = connectionString; + Connection = new SQLiteConnectionWithLock (connectionString, openFlags); + } + + public void OnApplicationSuspended () + { + Connection.Dispose (); + Connection = null; + } + } + + readonly Dictionary _entries = new Dictionary (); + readonly object _entriesLock = new object (); + + static readonly SQLiteConnectionPool _shared = new SQLiteConnectionPool (); + + /// + /// Gets the singleton instance of the connection tool. + /// + public static SQLiteConnectionPool Shared + { + get + { + return _shared; + } + } + + public SQLiteConnectionWithLock GetConnection (SQLiteConnectionString connectionString, SQLiteOpenFlags openFlags) + { + lock (_entriesLock) { + Entry entry; + string key = connectionString.ConnectionString; + + if (!_entries.TryGetValue (key, out entry)) { + entry = new Entry (connectionString, openFlags); + _entries[key] = entry; + } + + return entry.Connection; + } + } + + /// + /// Closes all connections managed by this pool. + /// + public void Reset () + { + lock (_entriesLock) { + foreach (var entry in _entries.Values) { + entry.OnApplicationSuspended (); + } + _entries.Clear (); + } + } + + /// + /// Call this method when the application is suspended. + /// + /// Behaviour here is to close any open connections. + public void ApplicationSuspended () + { + Reset (); + } + } + + class SQLiteConnectionWithLock : SQLiteConnection + { + readonly object _lockPoint = new object (); + + public SQLiteConnectionWithLock (SQLiteConnectionString connectionString, SQLiteOpenFlags openFlags) + : base (connectionString.DatabasePath, openFlags, connectionString.StoreDateTimeAsTicks) + { + } + + public IDisposable Lock () + { + return new LockWrapper (_lockPoint); + } + + private class LockWrapper : IDisposable + { + object _lockPoint; + + public LockWrapper (object lockPoint) + { + _lockPoint = lockPoint; + Monitor.Enter (_lockPoint); + } + + public void Dispose () + { + Monitor.Exit (_lockPoint); + } + } + } +} + diff --git a/NadekoBot/packages.config b/NadekoBot/packages.config index c28cfd27..ec861bf0 100644 --- a/NadekoBot/packages.config +++ b/NadekoBot/packages.config @@ -10,6 +10,7 @@ + \ No newline at end of file