Skip to content

Instantly share code, notes, and snippets.

@neuecc
Created April 8, 2021 21:52
Show Gist options
  • Save neuecc/058b5ef0c299505215f05ff2e7f283da to your computer and use it in GitHub Desktop.
Save neuecc/058b5ef0c299505215f05ff2e7f283da to your computer and use it in GitHub Desktop.
using Microsoft.Extensions.DependencyInjection;
using PropertyInjection;
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
namespace Microsoft.Extensions.Hosting
{
/// <summary>
/// Extension methods that replace the host's service provider with one that additionally
/// fills <c>[Inject]</c>-annotated fields and properties of transient services.
/// </summary>
public static class HostBuilderPropertyInjectionExtensions
{
    /// <summary>
    /// Uses the property-injection service provider with default <see cref="ServiceProviderOptions"/>.
    /// </summary>
    /// <param name="hostBuilder">The host builder to configure.</param>
    /// <returns>The same <paramref name="hostBuilder"/> for chaining.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="hostBuilder"/> is null.</exception>
    public static IHostBuilder UsePropertyInjectionServiceProvider(this IHostBuilder hostBuilder)
    {
        // No option changes: delegate to the context-aware overload with a no-op callback.
        return UsePropertyInjectionServiceProvider(hostBuilder, (context, options) => { });
    }

    /// <summary>
    /// Uses the property-injection service provider, letting the caller adjust the
    /// <see cref="ServiceProviderOptions"/> used to build the inner container.
    /// </summary>
    /// <param name="hostBuilder">The host builder to configure.</param>
    /// <param name="configure">Callback that mutates the options before the container is built.</param>
    /// <returns>The same <paramref name="hostBuilder"/> for chaining.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="hostBuilder"/> or <paramref name="configure"/> is null.</exception>
    public static IHostBuilder UsePropertyInjectionServiceProvider(this IHostBuilder hostBuilder, Action<ServiceProviderOptions> configure)
    {
        if (configure == null) throw new ArgumentNullException(nameof(configure));
        return UsePropertyInjectionServiceProvider(hostBuilder, (context, options) => configure(options));
    }

    /// <summary>
    /// Uses the property-injection service provider, letting the caller adjust the
    /// <see cref="ServiceProviderOptions"/> with access to the <see cref="HostBuilderContext"/>.
    /// </summary>
    /// <param name="hostBuilder">The host builder to configure.</param>
    /// <param name="configure">Callback that mutates the options before the container is built.</param>
    /// <returns>The same <paramref name="hostBuilder"/> for chaining.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="hostBuilder"/> or <paramref name="configure"/> is null.</exception>
    public static IHostBuilder UsePropertyInjectionServiceProvider(this IHostBuilder hostBuilder, Action<HostBuilderContext, ServiceProviderOptions> configure)
    {
        // Validate eagerly; otherwise a null callback would only surface as a
        // NullReferenceException when the host later invokes the factory delegate.
        if (hostBuilder == null) throw new ArgumentNullException(nameof(hostBuilder));
        if (configure == null) throw new ArgumentNullException(nameof(configure));

        return hostBuilder.UseServiceProviderFactory(context =>
        {
            var serviceProviderOptions = new ServiceProviderOptions();
            configure(context, serviceProviderOptions);
            var defaultFactory = new DefaultServiceProviderFactory(serviceProviderOptions);
            return new PropertyInjectionServiceProviderFactory(defaultFactory);
        });
    }
}
}
namespace PropertyInjection
{
/// <summary>
/// <see cref="IServiceProviderFactory{TContainerBuilder}"/> that builds the container through a
/// <see cref="DefaultServiceProviderFactory"/> and wraps the result in a provider that performs
/// <c>[Inject]</c> member injection on resolved transient services.
/// </summary>
public class PropertyInjectionServiceProviderFactory : IServiceProviderFactory<IServiceCollection>
{
    readonly DefaultServiceProviderFactory defaultServiceProviderFactory;

    /// <param name="defaultServiceProviderFactory">
    /// Factory (carrying the configured <see cref="ServiceProviderOptions"/>) used to create the builder and the inner provider.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="defaultServiceProviderFactory"/> is null.</exception>
    public PropertyInjectionServiceProviderFactory(DefaultServiceProviderFactory defaultServiceProviderFactory)
    {
        this.defaultServiceProviderFactory = defaultServiceProviderFactory ?? throw new ArgumentNullException(nameof(defaultServiceProviderFactory));
    }

    /// <summary>Delegates to the wrapped default factory.</summary>
    public IServiceCollection CreateBuilder(Microsoft.Extensions.DependencyInjection.IServiceCollection services)
    {
        return defaultServiceProviderFactory.CreateBuilder(services);
    }

    /// <summary>
    /// Precomputes the injectors for <paramref name="containerBuilder"/>, builds the inner provider,
    /// and returns the wrapping provider.
    /// </summary>
    public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder)
    {
        var dict = new PropertyInjectorDictionary(containerBuilder);
        var wrapper = new PropertyInjectionServiceProvider(dict);
        // Build through the wrapped default factory so the ServiceProviderOptions it was created with
        // (ValidateScopes / ValidateOnBuild) are honored; a bare BuildServiceProvider() call ignores them.
        var provider = defaultServiceProviderFactory.CreateServiceProvider(containerBuilder);
        wrapper.SetInnerProvider(provider);
        return wrapper;
    }
}
/// <summary>
/// Read-only lookup from a service type to a compiled delegate that assigns the <c>[Inject]</c>-annotated
/// fields and properties of that service's implementation. Built once from the service collection; only
/// transient registrations whose implementation has at least one such member get an entry.
/// </summary>
internal class PropertyInjectorDictionary
{
    // nongeneric GetRequiredService(provider, type)
    static readonly MethodInfo getRequiredServiceMethodInfo = typeof(Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions)
        .GetMethods(BindingFlags.Static | BindingFlags.Public | BindingFlags.InvokeMethod)
        .First(x => x.Name == nameof(Microsoft.Extensions.DependencyInjection.ServiceProviderServiceExtensions.GetRequiredService) && x.GetParameters().Length == 2);

    // Instance members declared directly on one type of the hierarchy, any accessibility.
    const BindingFlags DeclaredInstanceMembers = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly;

    const float LoadFactor = 0.75f;

    // Each slot holds a small array of the entries whose hash maps to that slot.
    readonly HashTuple[][] table;

    // table.Length - 1; table.Length is a power of two, so (hash & indexFor) is a valid slot index.
    readonly int indexFor;

    /// <summary>Builds the injectors for the registrations in <paramref name="values"/>.</summary>
    /// <param name="values">The service collection the inner provider is built from.</param>
    public PropertyInjectorDictionary(IServiceCollection values)
    {
        var initialCapacity = (int)((float)values.Count / LoadFactor);
        // make power of 2 (and use mask)
        // see: Hashing https://en.wikipedia.org/wiki/Hash_table
        var capacity = 1;
        while (capacity < initialCapacity)
        {
            capacity <<= 1;
        }
        table = new HashTuple[capacity][];
        indexFor = table.Length - 1;

        // The container resolves a service type to its LAST registration. Walk the collection backwards and
        // let only the first descriptor seen per service type decide; otherwise an earlier registration's
        // injector could be applied to the instance produced by a later, different registration.
        var decided = new HashSet<Type>();
        for (var i = values.Count - 1; i >= 0; i--)
        {
            var item = values[i];
            if (!decided.Add(item.ServiceType))
            {
                continue;
            }
            if (item.Lifetime != ServiceLifetime.Transient)
            {
                continue;
            }
            if (typeof(System.Collections.IEnumerable).IsAssignableFrom(item.ServiceType))
            {
                continue;
            }
            var implType = item.ImplementationType ?? item.ServiceType;
            // Open generic registrations (e.g. typeof(Repo<>)) cannot be compiled into a typed delegate
            // (Expression.Variable rejects them), and a lookup for a constructed type never equals them.
            if (implType.ContainsGenericParameters)
            {
                continue;
            }
            if (!TryCreateInjector(implType, out var injector))
            {
                continue;
            }
            Add(item.ServiceType, injector);
        }
    }

    // Appends an entry to serviceType's slot; the constructor adds each service type at most once.
    void Add(Type serviceType, Action<object, IServiceProvider> injector)
    {
        var index = serviceType.GetHashCode() & indexFor;
        var bucket = table[index];
        if (bucket == null)
        {
            bucket = new HashTuple[1];
        }
        else
        {
            Array.Resize(ref bucket, bucket.Length + 1);
        }
        bucket[bucket.Length - 1] = new HashTuple(serviceType, injector);
        table[index] = bucket;
    }

    // Matched by simple name so both Microsoft.AspNetCore.Components.InjectAttribute and PropertyInjection.InjectAttribute count.
    static bool HasInjectAttribute(MemberInfo member)
    {
        return member.GetCustomAttributes(true).Any(y => y.GetType().Name == "InjectAttribute");
    }

    // (T)provider.GetRequiredService(typeof(T))
    static Expression Resolve(ParameterExpression provider, Type type)
    {
        return Expression.Convert(Expression.Call(getRequiredServiceMethodInfo, provider, Expression.Constant(type, typeof(Type))), type);
    }

    /// <summary>
    /// Compiles a delegate that casts its first argument to <paramref name="implementationType"/> and assigns every
    /// <c>[Inject]</c> field and property setter from the provider passed as its second argument.
    /// </summary>
    /// <returns>false (and a null injector) when the type has no injectable members.</returns>
    static bool TryCreateInjector(Type implementationType, [MaybeNullWhen(false)] out Action<object, IServiceProvider> injector)
    {
        var fields = new List<FieldInfo>();
        var setters = new List<MethodInfo>();
        var seenSetters = new HashSet<MethodInfo>();

        // GetFields/GetProperties on a derived type do not return private members of base classes
        // (FlattenHierarchy only affects static members), and an inherited property's private setter is
        // hidden when the property is reflected through the derived type. Inspecting every type in the
        // hierarchy with DeclaredOnly sees each member on the type that declares it.
        for (var type = implementationType; type != null && type != typeof(object); type = type.BaseType)
        {
            fields.AddRange(type.GetFields(DeclaredInstanceMembers).Where(HasInjectAttribute));

            foreach (var property in type.GetProperties(DeclaredInstanceMembers))
            {
                if (property.GetIndexParameters().Length != 0 || !HasInjectAttribute(property))
                {
                    continue;
                }
                var setter = property.GetSetMethod(true);
                // An overridden virtual property is declared at several levels; GetBaseDefinition identifies
                // them as one slot so it is assigned once (the most derived declaration is seen first).
                if (setter != null && seenSetters.Add(setter.GetBaseDefinition()))
                {
                    setters.Add(setter);
                }
            }
        }

        if (fields.Count == 0 && setters.Count == 0)
        {
            injector = default;
            return false;
        }

        var args0 = Expression.Parameter(typeof(object), "instance");
        var args1 = Expression.Parameter(typeof(IServiceProvider), "provider");
        var local = Expression.Variable(implementationType, "obj");

        var blockBody = new List<Expression>();
        blockBody.Add(Expression.Assign(local, Expression.Convert(args0, implementationType)));
        foreach (var field in fields)
        {
            // obj.foo = (Foo)provider.GetRequiredService(typeof(Foo))
            blockBody.Add(Expression.Assign(Expression.Field(local, field), Resolve(args1, field.FieldType)));
        }
        foreach (var setter in setters)
        {
            // obj.Foo = (Foo)provider.GetRequiredService(typeof(Foo))
            blockBody.Add(Expression.Call(local, setter, Resolve(args1, setter.GetParameters()[0].ParameterType)));
        }

        var bodyExpression = Expression.Block(new[] { local }, blockBody);
        var lambda = Expression.Lambda<Action<object, IServiceProvider>>(bodyExpression, args0, args1);
        injector = lambda.Compile();
        return true;
    }

    /// <summary>Looks up the injector registered for exactly <paramref name="serviceType"/>.</summary>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public bool TryGetInjector(Type serviceType, [MaybeNullWhen(false)] out Action<object, IServiceProvider> injector)
    {
        var bucket = table[serviceType.GetHashCode() & indexFor];
        if (bucket != null)
        {
            for (int i = 0; i < bucket.Length; i++)
            {
                if (bucket[i].Type == serviceType)
                {
                    injector = bucket[i].Value;
                    return true;
                }
            }
        }
        injector = null;
        return false;
    }

    // Table entry: service type plus its compiled injector.
    [StructLayout(LayoutKind.Auto)]
    internal readonly struct HashTuple
    {
        public readonly Type Type;
        public readonly Action<object, IServiceProvider> Value;

        public HashTuple(Type type, Action<object, IServiceProvider> value)
        {
            Type = type;
            Value = value;
        }
    }
}
/// <summary>
/// <see cref="IServiceProvider"/> that delegates resolution to an inner provider and then runs the
/// precomputed <c>[Inject]</c> injector (if any) for the requested service type on the result.
/// </summary>
internal class PropertyInjectionServiceProvider : IServiceProvider, IDisposable, IAsyncDisposable
{
    IServiceProvider innerProvider;
    readonly PropertyInjectorDictionary propertyInjectionDictionary;

    public PropertyInjectionServiceProvider(PropertyInjectorDictionary propertyInjectionDictionary)
    {
        // Assigned later through SetInnerProvider, once the inner container has been built.
        this.innerProvider = null!;
        this.propertyInjectionDictionary = propertyInjectionDictionary;
    }

    /// <summary>Sets the provider that actually constructs services.</summary>
    public void SetInnerProvider(IServiceProvider serviceProvider)
    {
        this.innerProvider = serviceProvider;
    }

    /// <summary>
    /// Resolves <paramref name="serviceType"/> from the inner provider and injects its <c>[Inject]</c> members.
    /// </summary>
    /// <returns>The service, or null when the inner provider returns null.</returns>
    public object? GetService(Type serviceType)
    {
        // NOTE: injected members are resolved from innerProvider, not from this wrapper, so their own
        // [Inject] members are not filled; likewise anything that obtains the inner IServiceProvider
        // (rather than this wrapper) bypasses injection entirely.
        var service = innerProvider.GetService(serviceType);
        if (service == null)
        {
            return null;
        }
        if (propertyInjectionDictionary.TryGetInjector(serviceType, out var injector))
        {
            injector(service, innerProvider);
        }
        return service;
    }

    /// <summary>
    /// Disposes the inner provider. Callers that dispose a provider only after an
    /// <c>is IDisposable</c> check would otherwise never reach the inner container.
    /// </summary>
    public void Dispose()
    {
        (innerProvider as IDisposable)?.Dispose();
    }

    /// <summary>Asynchronously disposes the inner provider, falling back to <see cref="Dispose"/>.</summary>
    public System.Threading.Tasks.ValueTask DisposeAsync()
    {
        if (innerProvider is IAsyncDisposable asyncDisposable)
        {
            return asyncDisposable.DisposeAsync();
        }
        Dispose();
        return default;
    }
}
/// <summary>
/// Marks an instance field or property that the property-injection provider should assign
/// from the container after the owning transient service has been resolved.
/// </summary>
[AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, Inherited = true, AllowMultiple = false)]
public sealed class InjectAttribute : Attribute
{
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment