Packages can be loaded/unloaded. IUnloadableService interface added whose method Unload, if service implements it, will be called when the module is unloaded.
This commit is contained in:
@ -1,4 +1,6 @@
|
||||
namespace NadekoBot.Services
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace NadekoBot.Services
|
||||
{
|
||||
/// <summary>
|
||||
/// All services must implement this interface in order to be auto-discovered by the DI system
|
||||
@ -7,4 +9,12 @@
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// All services which require cleanup after they are unloaded must implement this interface
|
||||
/// </summary>
|
||||
public interface IUnloadableService
|
||||
{
|
||||
Task Unload();
|
||||
}
|
||||
}
|
||||
|
@ -61,17 +61,6 @@ namespace NadekoBot
|
||||
if (shardId < 0)
|
||||
throw new ArgumentOutOfRangeException(nameof(shardId));
|
||||
|
||||
//var obj = JsonConvert.DeserializeObject<Dictionary<string, CommandData2>>(File.ReadAllText("./data/command_strings.json"))
|
||||
// .ToDictionary(x => x.Key, x => new CommandData2
|
||||
// {
|
||||
// Cmd = x.Value.Cmd,
|
||||
// Desc = x.Value.Desc,
|
||||
// Usage = x.Value.Usage.Select(y => y.Substring(1, y.Length - 2)).ToArray(),
|
||||
// });
|
||||
|
||||
//File.WriteAllText("./data/command_strings.json", JsonConvert.SerializeObject(obj, Formatting.Indented));
|
||||
|
||||
|
||||
LogSetup.SetupLogger();
|
||||
_log = LogManager.GetCurrentClassLogger();
|
||||
TerribleElevatedPermissionCheck();
|
||||
@ -142,19 +131,17 @@ namespace NadekoBot
|
||||
//var localization = new Localization(_botConfig.Locale, AllGuildConfigs.ToDictionary(x => x.GuildId, x => x.Locale), Db);
|
||||
|
||||
//initialize Services
|
||||
Services = new NServiceProvider.ServiceProviderBuilder()
|
||||
Services = new NServiceProvider()
|
||||
.AddManual<IBotCredentials>(Credentials)
|
||||
.AddManual(_db)
|
||||
.AddManual(Client)
|
||||
.AddManual(CommandService)
|
||||
.AddManual(botConfigProvider)
|
||||
//.AddManual<ILocalization>(localization)
|
||||
.AddManual<IEnumerable<GuildConfig>>(AllGuildConfigs) //todo wrap this
|
||||
.AddManual<NadekoBot>(this)
|
||||
.AddManual<IUnitOfWork>(uow)
|
||||
.AddManual<IDataCache>(new RedisCache(Client.CurrentUser.Id))
|
||||
.LoadFrom(Assembly.GetEntryAssembly())
|
||||
.Build();
|
||||
.AddManual<IDataCache>(new RedisCache(Client.CurrentUser.Id));
|
||||
Services.LoadFrom(Assembly.GetAssembly(typeof(CommandHandler)));
|
||||
|
||||
var commandHandler = Services.GetService<CommandHandler>();
|
||||
commandHandler.AddServices(Services);
|
||||
@ -163,12 +150,13 @@ namespace NadekoBot
|
||||
CommandService.AddTypeReader<PermissionAction>(new PermissionActionTypeReader());
|
||||
CommandService.AddTypeReader<CommandInfo>(new CommandTypeReader());
|
||||
//todo module dependency
|
||||
//CommandService.AddTypeReader<CommandOrCrInfo>(new CommandOrCrTypeReader());
|
||||
CommandService.AddTypeReader<CommandOrCrInfo>(new CommandOrCrTypeReader());
|
||||
CommandService.AddTypeReader<ModuleInfo>(new ModuleTypeReader(CommandService));
|
||||
CommandService.AddTypeReader<ModuleOrCrInfo>(new ModuleOrCrTypeReader(CommandService));
|
||||
CommandService.AddTypeReader<IGuild>(new GuildTypeReader(Client));
|
||||
//CommandService.AddTypeReader<GuildDateTime>(new GuildDateTimeTypeReader());
|
||||
}
|
||||
Services.Unload(typeof(IUnitOfWork)); // unload it after the startup
|
||||
}
|
||||
|
||||
private async Task LoginAsync(string token)
|
||||
@ -193,7 +181,7 @@ namespace NadekoBot
|
||||
}
|
||||
finally
|
||||
{
|
||||
|
||||
|
||||
}
|
||||
});
|
||||
return Task.CompletedTask;
|
||||
@ -225,8 +213,8 @@ namespace NadekoBot
|
||||
|
||||
public async Task RunAsync(params string[] args)
|
||||
{
|
||||
if(Client.ShardId == 0)
|
||||
_log.Info("Starting NadekoBot v" + StatsService.BotVersion);
|
||||
if (Client.ShardId == 0)
|
||||
_log.Info("Starting NadekoBot v" + StatsService.BotVersion);
|
||||
|
||||
var sw = Stopwatch.StartNew();
|
||||
|
||||
@ -255,7 +243,7 @@ namespace NadekoBot
|
||||
#endif
|
||||
//unload modules which are not available on the public bot
|
||||
|
||||
if(isPublicNadeko)
|
||||
if (isPublicNadeko)
|
||||
CommandService
|
||||
.Modules
|
||||
.ToArray()
|
||||
@ -372,5 +360,80 @@ namespace NadekoBot
|
||||
var sub = Services.GetService<IDataCache>().Redis.GetSubscriber();
|
||||
return sub.PublishAsync(Client.CurrentUser.Id + "_status.game_set", JsonConvert.SerializeObject(obj));
|
||||
}
|
||||
|
||||
private readonly Dictionary<string, IEnumerable<ModuleInfo>> _packageModules = new Dictionary<string, IEnumerable<ModuleInfo>>();
|
||||
private readonly Dictionary<string, IEnumerable<Type>> _packageTypes = new Dictionary<string, IEnumerable<Type>>();
|
||||
private readonly SemaphoreSlim _packageLocker = new SemaphoreSlim(1, 1);
|
||||
|
||||
/// <summary>
|
||||
/// Unloads a package
|
||||
/// </summary>
|
||||
/// <param name="name">Package name. Case sensitive.</param>
|
||||
/// <returns>Whether the unload is successful.</returns>
|
||||
public async Task<bool> UnloadPackage(string name)
|
||||
{
|
||||
await _packageLocker.WaitAsync().ConfigureAwait(false);
|
||||
try
|
||||
{
|
||||
if (!_packageModules.TryGetValue(name, out var modules))
|
||||
return false;
|
||||
|
||||
var i = 0;
|
||||
foreach (var m in modules)
|
||||
{
|
||||
await CommandService.RemoveModuleAsync(m).ConfigureAwait(false);
|
||||
i++;
|
||||
}
|
||||
_log.Info("Unloaded {0} modules.", i);
|
||||
|
||||
if (_packageTypes.TryGetValue(name, out var types))
|
||||
{
|
||||
i = 0;
|
||||
foreach (var t in types)
|
||||
{
|
||||
var obj = Services.Unload(t);
|
||||
if (obj is IUnloadableService s)
|
||||
await s.Unload().ConfigureAwait(false);
|
||||
i++;
|
||||
}
|
||||
|
||||
_log.Info("Unloaded {0} types.", i);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
finally
|
||||
{
|
||||
_packageLocker.Release();
|
||||
}
|
||||
}
|
||||
/// <summary>
|
||||
/// Loads a package
|
||||
/// </summary>
|
||||
/// <param name="name">Name of the package to load. Case sensitive.</param>
|
||||
/// <returns>Whether the load is successful.</returns>
|
||||
public async Task<bool> LoadPackage(string name)
|
||||
{
|
||||
await _packageLocker.WaitAsync().ConfigureAwait(false);
|
||||
try
|
||||
{
|
||||
if (_packageModules.ContainsKey(name))
|
||||
return false;
|
||||
|
||||
var package = Assembly.LoadFile(Path.Combine(AppContext.BaseDirectory,
|
||||
"modules",
|
||||
$"NadekoBot.Modules.{name}",
|
||||
$"NadekoBot.Modules.{name}.dll"));
|
||||
var types = Services.LoadFrom(package);
|
||||
var added = await CommandService.AddModulesAsync(package).ConfigureAwait(false);
|
||||
_log.Info("Loaded {0} modules and {1} types.", added.Count(), types.Count());
|
||||
_packageModules.Add(name, added);
|
||||
_packageTypes.Add(name, types);
|
||||
return true;
|
||||
}
|
||||
finally
|
||||
{
|
||||
_packageLocker.Release();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,5 @@
|
||||
using System;
|
||||
using System.Collections;
|
||||
using System.Collections.Concurrent;
|
||||
using System.Collections.Generic;
|
||||
using System.Collections.Immutable;
|
||||
using System.Reflection;
|
||||
@ -17,57 +16,83 @@ namespace NadekoBot.Services
|
||||
public interface INServiceProvider : IServiceProvider, IEnumerable<object>
|
||||
{
|
||||
T GetService<T>();
|
||||
IEnumerable<Type> LoadFrom(Assembly assembly);
|
||||
INServiceProvider AddManual<T>(T obj);
|
||||
object Unload(Type t);
|
||||
}
|
||||
|
||||
public class NServiceProvider : INServiceProvider
|
||||
{
|
||||
public class ServiceProviderBuilder
|
||||
private readonly object _locker = new object();
|
||||
private readonly Logger _log;
|
||||
|
||||
public readonly Dictionary<Type, object> _services = new Dictionary<Type, object>();
|
||||
public IReadOnlyDictionary<Type, object> Services => _services;
|
||||
|
||||
public NServiceProvider()
|
||||
{
|
||||
private ConcurrentDictionary<Type, object> _dict = new ConcurrentDictionary<Type, object>();
|
||||
private readonly Logger _log;
|
||||
_log = LogManager.GetCurrentClassLogger();
|
||||
}
|
||||
|
||||
public ServiceProviderBuilder()
|
||||
public T GetService<T>()
|
||||
{
|
||||
return (T)((IServiceProvider)(this)).GetService(typeof(T));
|
||||
}
|
||||
|
||||
object IServiceProvider.GetService(Type serviceType)
|
||||
{
|
||||
_services.TryGetValue(serviceType, out var toReturn);
|
||||
return toReturn;
|
||||
}
|
||||
|
||||
public INServiceProvider AddManual<T>(T obj)
|
||||
{
|
||||
lock (_locker)
|
||||
{
|
||||
_log = LogManager.GetCurrentClassLogger();
|
||||
_services.TryAdd(typeof(T), obj);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public ServiceProviderBuilder AddManual<T>(T obj)
|
||||
public IEnumerable<Type> LoadFrom(Assembly assembly)
|
||||
{
|
||||
List<Type> addedTypes = new List<Type>();
|
||||
|
||||
Type[] allTypes;
|
||||
try
|
||||
{
|
||||
_dict.TryAdd(typeof(T), obj);
|
||||
return this;
|
||||
allTypes = assembly.GetTypes();
|
||||
}
|
||||
|
||||
public NServiceProvider Build()
|
||||
catch (ReflectionTypeLoadException ex)
|
||||
{
|
||||
return new NServiceProvider(_dict);
|
||||
Console.WriteLine(ex.LoaderExceptions[0]);
|
||||
return Enumerable.Empty<Type>();
|
||||
}
|
||||
|
||||
public ServiceProviderBuilder LoadFrom(Assembly assembly)
|
||||
{
|
||||
var allTypes = assembly.GetTypes();
|
||||
var services = new Queue<Type>(allTypes
|
||||
.Where(x => x.GetInterfaces().Contains(typeof(INService))
|
||||
&& !x.GetTypeInfo().IsInterface && !x.GetTypeInfo().IsAbstract
|
||||
|
||||
var services = new Queue<Type>(allTypes
|
||||
.Where(x => x.GetInterfaces().Contains(typeof(INService))
|
||||
&& !x.GetTypeInfo().IsInterface && !x.GetTypeInfo().IsAbstract
|
||||
#if GLOBAL_NADEKO
|
||||
&& x.GetTypeInfo().GetCustomAttribute<NoPublicBot>() == null
|
||||
&& x.GetTypeInfo().GetCustomAttribute<NoPublicBot>() == null
|
||||
#endif
|
||||
)
|
||||
.ToArray());
|
||||
.ToArray());
|
||||
|
||||
var interfaces = new HashSet<Type>(allTypes
|
||||
.Where(x => x.GetInterfaces().Contains(typeof(INService))
|
||||
&& x.GetTypeInfo().IsInterface));
|
||||
addedTypes.AddRange(services);
|
||||
|
||||
var alreadyFailed = new Dictionary<Type, int>();
|
||||
var interfaces = new HashSet<Type>(allTypes
|
||||
.Where(x => x.GetInterfaces().Contains(typeof(INService))
|
||||
&& x.GetTypeInfo().IsInterface));
|
||||
|
||||
var alreadyFailed = new Dictionary<Type, int>();
|
||||
lock (_locker)
|
||||
{
|
||||
var sw = Stopwatch.StartNew();
|
||||
var swInstance = new Stopwatch();
|
||||
while (services.Count > 0)
|
||||
{
|
||||
var type = services.Dequeue(); //get a type i need to make an instance of
|
||||
|
||||
if (_dict.TryGetValue(type, out _)) // if that type is already instantiated, skip
|
||||
if (_services.TryGetValue(type, out _)) // if that type is already instantiated, skip
|
||||
continue;
|
||||
|
||||
var ctor = type.GetConstructors()[0];
|
||||
@ -79,7 +104,7 @@ namespace NadekoBot.Services
|
||||
var args = new List<object>(argTypes.Length);
|
||||
foreach (var arg in argTypes) //get constructor arguments from the dictionary of already instantiated types
|
||||
{
|
||||
if (_dict.TryGetValue(arg, out var argObj)) //if i got current one, add it to the list of instances and move on
|
||||
if (_services.TryGetValue(arg, out var argObj)) //if i got current one, add it to the list of instances and move on
|
||||
args.Add(argObj);
|
||||
else //if i failed getting it, add it to the end, and break
|
||||
{
|
||||
@ -97,7 +122,7 @@ namespace NadekoBot.Services
|
||||
}
|
||||
if (args.Count != argTypes.Length)
|
||||
continue;
|
||||
// _log.Info("Loading " + type.Name);
|
||||
|
||||
swInstance.Restart();
|
||||
var instance = ctor.Invoke(args.ToArray());
|
||||
swInstance.Stop();
|
||||
@ -105,38 +130,34 @@ namespace NadekoBot.Services
|
||||
_log.Info($"{type.Name} took {swInstance.Elapsed.TotalSeconds:F2}s to load.");
|
||||
var interfaceType = interfaces.FirstOrDefault(x => instance.GetType().GetInterfaces().Contains(x));
|
||||
if (interfaceType != null)
|
||||
_dict.TryAdd(interfaceType, instance);
|
||||
{
|
||||
addedTypes.Add(interfaceType);
|
||||
_services.TryAdd(interfaceType, instance);
|
||||
}
|
||||
|
||||
_dict.TryAdd(type, instance);
|
||||
_services.TryAdd(type, instance);
|
||||
}
|
||||
sw.Stop();
|
||||
_log.Info($"All services loaded in {sw.Elapsed.TotalSeconds:F2}s");
|
||||
|
||||
return this;
|
||||
}
|
||||
return addedTypes;
|
||||
}
|
||||
|
||||
private readonly ImmutableDictionary<Type, object> _services;
|
||||
|
||||
private NServiceProvider() { }
|
||||
public NServiceProvider(IDictionary<Type, object> services)
|
||||
public object Unload(Type t)
|
||||
{
|
||||
this._services = services.ToImmutableDictionary();
|
||||
}
|
||||
|
||||
public T GetService<T>()
|
||||
{
|
||||
return (T)((IServiceProvider)(this)).GetService(typeof(T));
|
||||
}
|
||||
|
||||
object IServiceProvider.GetService(Type serviceType)
|
||||
{
|
||||
_services.TryGetValue(serviceType, out var toReturn);
|
||||
return toReturn;
|
||||
lock (_locker)
|
||||
{
|
||||
if (_services.TryGetValue(t, out var obj))
|
||||
{
|
||||
_services.Remove(t);
|
||||
return obj;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
IEnumerator IEnumerable.GetEnumerator() => _services.Values.GetEnumerator();
|
||||
|
||||
public IEnumerator<object> GetEnumerator() => _services.Values.GetEnumerator();
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user