Packages can be loaded/unloaded. IUnloadableService interface added whose method Unload, if service implements it, will be called when the module is unloaded.
This commit is contained in:
		| @@ -1,4 +1,6 @@ | ||||
| namespace NadekoBot.Services | ||||
| using System.Threading.Tasks; | ||||
|  | ||||
| namespace NadekoBot.Services | ||||
| { | ||||
|     /// <summary> | ||||
|     /// All services must implement this interface in order to be auto-discovered by the DI system | ||||
| @@ -7,4 +9,12 @@ | ||||
|     { | ||||
|          | ||||
|     } | ||||
|  | ||||
|     /// <summary> | ||||
|     /// All services which require cleanup after they are unloaded must implement this interface | ||||
|     /// </summary> | ||||
|     public interface IUnloadableService | ||||
|     { | ||||
|         Task Unload(); | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -61,17 +61,6 @@ namespace NadekoBot | ||||
|             if (shardId < 0) | ||||
|                 throw new ArgumentOutOfRangeException(nameof(shardId)); | ||||
|  | ||||
|             //var obj = JsonConvert.DeserializeObject<Dictionary<string, CommandData2>>(File.ReadAllText("./data/command_strings.json")) | ||||
|             //    .ToDictionary(x => x.Key, x => new CommandData2 | ||||
|             //    { | ||||
|             //        Cmd = x.Value.Cmd, | ||||
|             //        Desc = x.Value.Desc, | ||||
|             //        Usage = x.Value.Usage.Select(y => y.Substring(1, y.Length - 2)).ToArray(), | ||||
|             //    }); | ||||
|  | ||||
|             //File.WriteAllText("./data/command_strings.json", JsonConvert.SerializeObject(obj, Formatting.Indented)); | ||||
|              | ||||
|  | ||||
|             LogSetup.SetupLogger(); | ||||
|             _log = LogManager.GetCurrentClassLogger(); | ||||
|             TerribleElevatedPermissionCheck(); | ||||
| @@ -142,19 +131,17 @@ namespace NadekoBot | ||||
|                 //var localization = new Localization(_botConfig.Locale, AllGuildConfigs.ToDictionary(x => x.GuildId, x => x.Locale), Db); | ||||
|  | ||||
|                 //initialize Services | ||||
|                 Services = new NServiceProvider.ServiceProviderBuilder() | ||||
|                 Services = new NServiceProvider() | ||||
|                     .AddManual<IBotCredentials>(Credentials) | ||||
|                     .AddManual(_db) | ||||
|                     .AddManual(Client) | ||||
|                     .AddManual(CommandService) | ||||
|                     .AddManual(botConfigProvider) | ||||
|                     //.AddManual<ILocalization>(localization) | ||||
|                     .AddManual<IEnumerable<GuildConfig>>(AllGuildConfigs) //todo wrap this | ||||
|                     .AddManual<NadekoBot>(this) | ||||
|                     .AddManual<IUnitOfWork>(uow) | ||||
|                     .AddManual<IDataCache>(new RedisCache(Client.CurrentUser.Id)) | ||||
|                     .LoadFrom(Assembly.GetEntryAssembly()) | ||||
|                     .Build(); | ||||
|                     .AddManual<IDataCache>(new RedisCache(Client.CurrentUser.Id)); | ||||
|                 Services.LoadFrom(Assembly.GetAssembly(typeof(CommandHandler))); | ||||
|  | ||||
|                 var commandHandler = Services.GetService<CommandHandler>(); | ||||
|                 commandHandler.AddServices(Services); | ||||
| @@ -163,12 +150,13 @@ namespace NadekoBot | ||||
|                 CommandService.AddTypeReader<PermissionAction>(new PermissionActionTypeReader()); | ||||
|                 CommandService.AddTypeReader<CommandInfo>(new CommandTypeReader()); | ||||
|                 //todo module dependency | ||||
|                 //CommandService.AddTypeReader<CommandOrCrInfo>(new CommandOrCrTypeReader()); | ||||
|                 CommandService.AddTypeReader<CommandOrCrInfo>(new CommandOrCrTypeReader()); | ||||
|                 CommandService.AddTypeReader<ModuleInfo>(new ModuleTypeReader(CommandService)); | ||||
|                 CommandService.AddTypeReader<ModuleOrCrInfo>(new ModuleOrCrTypeReader(CommandService)); | ||||
|                 CommandService.AddTypeReader<IGuild>(new GuildTypeReader(Client)); | ||||
|                 //CommandService.AddTypeReader<GuildDateTime>(new GuildDateTimeTypeReader()); | ||||
|             } | ||||
|             Services.Unload(typeof(IUnitOfWork)); // unload it after the startup | ||||
|         } | ||||
|  | ||||
|         private async Task LoginAsync(string token) | ||||
| @@ -193,7 +181,7 @@ namespace NadekoBot | ||||
|                     } | ||||
|                     finally | ||||
|                     { | ||||
|                          | ||||
|  | ||||
|                     } | ||||
|                 }); | ||||
|                 return Task.CompletedTask; | ||||
| @@ -225,8 +213,8 @@ namespace NadekoBot | ||||
|  | ||||
|         public async Task RunAsync(params string[] args) | ||||
|         { | ||||
|             if(Client.ShardId == 0) | ||||
|             _log.Info("Starting NadekoBot v" + StatsService.BotVersion); | ||||
|             if (Client.ShardId == 0) | ||||
|                 _log.Info("Starting NadekoBot v" + StatsService.BotVersion); | ||||
|  | ||||
|             var sw = Stopwatch.StartNew(); | ||||
|  | ||||
| @@ -255,7 +243,7 @@ namespace NadekoBot | ||||
| #endif | ||||
|             //unload modules which are not available on the public bot | ||||
|  | ||||
|             if(isPublicNadeko) | ||||
|             if (isPublicNadeko) | ||||
|                 CommandService | ||||
|                     .Modules | ||||
|                     .ToArray() | ||||
| @@ -372,5 +360,80 @@ namespace NadekoBot | ||||
|             var sub = Services.GetService<IDataCache>().Redis.GetSubscriber(); | ||||
|             return sub.PublishAsync(Client.CurrentUser.Id + "_status.game_set", JsonConvert.SerializeObject(obj)); | ||||
|         } | ||||
|  | ||||
|         private readonly Dictionary<string, IEnumerable<ModuleInfo>> _packageModules = new Dictionary<string, IEnumerable<ModuleInfo>>(); | ||||
|         private readonly Dictionary<string, IEnumerable<Type>> _packageTypes = new Dictionary<string, IEnumerable<Type>>(); | ||||
|         private readonly SemaphoreSlim _packageLocker = new SemaphoreSlim(1, 1); | ||||
|  | ||||
|         /// <summary> | ||||
|         /// Unloads a package | ||||
|         /// </summary> | ||||
|         /// <param name="name">Package name. Case sensitive.</param> | ||||
|         /// <returns>Whether the unload is successful.</returns> | ||||
|         public async Task<bool> UnloadPackage(string name) | ||||
|         { | ||||
|             await _packageLocker.WaitAsync().ConfigureAwait(false); | ||||
|             try | ||||
|             { | ||||
|                 if (!_packageModules.TryGetValue(name, out var modules)) | ||||
|                     return false; | ||||
|  | ||||
|                 var i = 0; | ||||
|                 foreach (var m in modules) | ||||
|                 { | ||||
|                     await CommandService.RemoveModuleAsync(m).ConfigureAwait(false); | ||||
|                     i++; | ||||
|                 } | ||||
|                 _log.Info("Unloaded {0} modules.", i); | ||||
|  | ||||
|                 if (_packageTypes.TryGetValue(name, out var types)) | ||||
|                 { | ||||
|                     i = 0; | ||||
|                     foreach (var t in types) | ||||
|                     { | ||||
|                         var obj = Services.Unload(t); | ||||
|                         if (obj is IUnloadableService s) | ||||
|                             await s.Unload().ConfigureAwait(false); | ||||
|                         i++; | ||||
|                     } | ||||
|  | ||||
|                     _log.Info("Unloaded {0} types.", i); | ||||
|                 } | ||||
|                 return true; | ||||
|             } | ||||
|             finally | ||||
|             { | ||||
|                 _packageLocker.Release(); | ||||
|             } | ||||
|         } | ||||
|         /// <summary> | ||||
|         /// Loads a package | ||||
|         /// </summary> | ||||
|         /// <param name="name">Name of the package to load. Case sensitive.</param> | ||||
|         /// <returns>Whether the load is successful.</returns> | ||||
|         public async Task<bool> LoadPackage(string name) | ||||
|         { | ||||
|             await _packageLocker.WaitAsync().ConfigureAwait(false); | ||||
|             try | ||||
|             { | ||||
|                 if (_packageModules.ContainsKey(name)) | ||||
|                     return false; | ||||
|  | ||||
|                 var package = Assembly.LoadFile(Path.Combine(AppContext.BaseDirectory, | ||||
|                                                 "modules", | ||||
|                                                 $"NadekoBot.Modules.{name}", | ||||
|                                                 $"NadekoBot.Modules.{name}.dll")); | ||||
|                 var types = Services.LoadFrom(package); | ||||
|                 var added = await CommandService.AddModulesAsync(package).ConfigureAwait(false); | ||||
|                 _log.Info("Loaded {0} modules and {1} types.", added.Count(), types.Count()); | ||||
|                 _packageModules.Add(name, added); | ||||
|                 _packageTypes.Add(name, types); | ||||
|                 return true; | ||||
|             } | ||||
|             finally | ||||
|             { | ||||
|                 _packageLocker.Release(); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -1,6 +1,5 @@ | ||||
| using System; | ||||
| using System.Collections; | ||||
| using System.Collections.Concurrent; | ||||
| using System.Collections.Generic; | ||||
| using System.Collections.Immutable; | ||||
| using System.Reflection; | ||||
| @@ -17,57 +16,83 @@ namespace NadekoBot.Services | ||||
|     public interface INServiceProvider : IServiceProvider, IEnumerable<object> | ||||
|     { | ||||
|         T GetService<T>(); | ||||
|         IEnumerable<Type> LoadFrom(Assembly assembly); | ||||
|         INServiceProvider AddManual<T>(T obj); | ||||
|         object Unload(Type t); | ||||
|     } | ||||
|  | ||||
|     public class NServiceProvider : INServiceProvider | ||||
|     { | ||||
|         public class ServiceProviderBuilder | ||||
|         private readonly object _locker = new object(); | ||||
|         private readonly Logger _log; | ||||
|  | ||||
|         public readonly Dictionary<Type, object> _services = new Dictionary<Type, object>(); | ||||
|         public IReadOnlyDictionary<Type, object> Services => _services; | ||||
|  | ||||
|         public NServiceProvider() | ||||
|         { | ||||
|             private ConcurrentDictionary<Type, object> _dict = new ConcurrentDictionary<Type, object>(); | ||||
|             private readonly Logger _log; | ||||
|             _log = LogManager.GetCurrentClassLogger(); | ||||
|         } | ||||
|  | ||||
|             public ServiceProviderBuilder() | ||||
|         public T GetService<T>() | ||||
|         { | ||||
|             return (T)((IServiceProvider)(this)).GetService(typeof(T)); | ||||
|         } | ||||
|  | ||||
|         object IServiceProvider.GetService(Type serviceType) | ||||
|         { | ||||
|             _services.TryGetValue(serviceType, out var toReturn); | ||||
|             return toReturn; | ||||
|         } | ||||
|  | ||||
|         public INServiceProvider AddManual<T>(T obj) | ||||
|         { | ||||
|             lock (_locker) | ||||
|             { | ||||
|                 _log = LogManager.GetCurrentClassLogger(); | ||||
|                 _services.TryAdd(typeof(T), obj); | ||||
|             } | ||||
|             return this; | ||||
|         } | ||||
|  | ||||
|             public ServiceProviderBuilder AddManual<T>(T obj) | ||||
|         public IEnumerable<Type> LoadFrom(Assembly assembly) | ||||
|         { | ||||
|             List<Type> addedTypes = new List<Type>(); | ||||
|  | ||||
|             Type[] allTypes; | ||||
|             try | ||||
|             { | ||||
|                 _dict.TryAdd(typeof(T), obj); | ||||
|                 return this; | ||||
|                 allTypes = assembly.GetTypes(); | ||||
|             } | ||||
|  | ||||
|             public NServiceProvider Build() | ||||
|             catch (ReflectionTypeLoadException ex) | ||||
|             { | ||||
|                 return new NServiceProvider(_dict); | ||||
|                 Console.WriteLine(ex.LoaderExceptions[0]); | ||||
|                 return Enumerable.Empty<Type>(); | ||||
|             } | ||||
|  | ||||
|             public ServiceProviderBuilder LoadFrom(Assembly assembly) | ||||
|             { | ||||
|                 var allTypes = assembly.GetTypes(); | ||||
|                 var services = new Queue<Type>(allTypes | ||||
|                         .Where(x => x.GetInterfaces().Contains(typeof(INService))  | ||||
|                             && !x.GetTypeInfo().IsInterface && !x.GetTypeInfo().IsAbstract | ||||
|  | ||||
|             var services = new Queue<Type>(allTypes | ||||
|                     .Where(x => x.GetInterfaces().Contains(typeof(INService)) | ||||
|                         && !x.GetTypeInfo().IsInterface && !x.GetTypeInfo().IsAbstract | ||||
| #if GLOBAL_NADEKO | ||||
|                             && x.GetTypeInfo().GetCustomAttribute<NoPublicBot>() == null | ||||
|                         && x.GetTypeInfo().GetCustomAttribute<NoPublicBot>() == null | ||||
| #endif | ||||
|                             ) | ||||
|                         .ToArray()); | ||||
|                     .ToArray()); | ||||
|  | ||||
|                 var interfaces = new HashSet<Type>(allTypes | ||||
|                         .Where(x => x.GetInterfaces().Contains(typeof(INService))  | ||||
|                             && x.GetTypeInfo().IsInterface)); | ||||
|             addedTypes.AddRange(services); | ||||
|  | ||||
|                 var alreadyFailed = new Dictionary<Type, int>(); | ||||
|             var interfaces = new HashSet<Type>(allTypes | ||||
|                     .Where(x => x.GetInterfaces().Contains(typeof(INService)) | ||||
|                         && x.GetTypeInfo().IsInterface)); | ||||
|  | ||||
|             var alreadyFailed = new Dictionary<Type, int>(); | ||||
|             lock (_locker) | ||||
|             { | ||||
|                 var sw = Stopwatch.StartNew(); | ||||
|                 var swInstance = new Stopwatch(); | ||||
|                 while (services.Count > 0) | ||||
|                 { | ||||
|                     var type = services.Dequeue(); //get a type i need to make an instance of | ||||
|  | ||||
|                     if (_dict.TryGetValue(type, out _)) // if that type is already instantiated, skip | ||||
|                     if (_services.TryGetValue(type, out _)) // if that type is already instantiated, skip | ||||
|                         continue; | ||||
|  | ||||
|                     var ctor = type.GetConstructors()[0]; | ||||
| @@ -79,7 +104,7 @@ namespace NadekoBot.Services | ||||
|                     var args = new List<object>(argTypes.Length); | ||||
|                     foreach (var arg in argTypes) //get constructor arguments from the dictionary of already instantiated types | ||||
|                     { | ||||
|                         if (_dict.TryGetValue(arg, out var argObj)) //if i got current one, add it to the list of instances and move on | ||||
|                         if (_services.TryGetValue(arg, out var argObj)) //if i got current one, add it to the list of instances and move on | ||||
|                             args.Add(argObj); | ||||
|                         else //if i failed getting it, add it to the end, and break | ||||
|                         { | ||||
| @@ -97,7 +122,7 @@ namespace NadekoBot.Services | ||||
|                     } | ||||
|                     if (args.Count != argTypes.Length) | ||||
|                         continue; | ||||
|                     // _log.Info("Loading " + type.Name); | ||||
|  | ||||
|                     swInstance.Restart(); | ||||
|                     var instance = ctor.Invoke(args.ToArray()); | ||||
|                     swInstance.Stop(); | ||||
| @@ -105,38 +130,34 @@ namespace NadekoBot.Services | ||||
|                         _log.Info($"{type.Name} took {swInstance.Elapsed.TotalSeconds:F2}s to load."); | ||||
|                     var interfaceType = interfaces.FirstOrDefault(x => instance.GetType().GetInterfaces().Contains(x)); | ||||
|                     if (interfaceType != null) | ||||
|                         _dict.TryAdd(interfaceType, instance); | ||||
|                     { | ||||
|                         addedTypes.Add(interfaceType); | ||||
|                         _services.TryAdd(interfaceType, instance); | ||||
|                     } | ||||
|  | ||||
|                     _dict.TryAdd(type, instance); | ||||
|                     _services.TryAdd(type, instance); | ||||
|                 } | ||||
|                 sw.Stop(); | ||||
|                 _log.Info($"All services loaded in {sw.Elapsed.TotalSeconds:F2}s"); | ||||
|  | ||||
|                 return this; | ||||
|             } | ||||
|             return addedTypes; | ||||
|         } | ||||
|  | ||||
|         private readonly ImmutableDictionary<Type, object> _services; | ||||
|  | ||||
|         private NServiceProvider() { } | ||||
|         public NServiceProvider(IDictionary<Type, object> services) | ||||
|         public object Unload(Type t) | ||||
|         { | ||||
|             this._services = services.ToImmutableDictionary(); | ||||
|         } | ||||
|  | ||||
|         public T GetService<T>() | ||||
|         { | ||||
|             return (T)((IServiceProvider)(this)).GetService(typeof(T)); | ||||
|         } | ||||
|  | ||||
|         object IServiceProvider.GetService(Type serviceType) | ||||
|         { | ||||
|             _services.TryGetValue(serviceType, out var toReturn); | ||||
|             return toReturn; | ||||
|             lock (_locker) | ||||
|             { | ||||
|                 if (_services.TryGetValue(t, out var obj)) | ||||
|                 { | ||||
|                     _services.Remove(t); | ||||
|                     return obj; | ||||
|                 } | ||||
|             } | ||||
|             return null; | ||||
|         } | ||||
|  | ||||
|         IEnumerator IEnumerable.GetEnumerator() => _services.Values.GetEnumerator(); | ||||
|  | ||||
|         public IEnumerator<object> GetEnumerator() => _services.Values.GetEnumerator(); | ||||
|     } | ||||
| } | ||||
| } | ||||
		Reference in New Issue
	
	Block a user