Created November 22, 2014 00:07
Save cdemi/59066d217a257eb5b353 to your computer and use it in GitHub Desktop.
Converts an IEnumerable<T> into an IDataReader.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Linq.Expressions; | |
using System.Reflection; | |
namespace System.Data
{
    /// <summary>
    /// IDataReader that can be used for "reading" an <see cref="IEnumerable{T}"/> collection.
    /// Each call to <see cref="Read"/> advances the underlying enumerator by one item, and the configured
    /// public properties/fields of that item are exposed as the record's columns (in the order requested).
    /// </summary>
    /// <typeparam name="T">Element type of the enumerated collection.</typeparam>
    public class EnumerableDataReader<T> : IDataReader
    {
        private readonly List<BaseField> m_Fields = new List<BaseField>();
        private T m_Current;
        private IEnumerator<T> m_Enumerator;
        // True while m_Current holds a valid item (the last MoveNext succeeded).
        private bool m_EnumeratorState;
        private bool m_Closed;

        /// <summary>
        /// Constructor
        /// </summary>
        /// <param name="collection">The collection to be read</param>
        /// <param name="fields">
        /// The list of public field/properties to read from each T (in order), OR if no fields are given
        /// (or <c>null</c> is passed) only one field will be available: T itself
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="collection"/> is null.</exception>
        /// <exception cref="NullReferenceException">
        /// A requested name is not a public property or field of T, or the collection returned a null enumerator.
        /// </exception>
        public EnumerableDataReader(IEnumerable<T> collection, params string[] fields)
        {
            if (collection == null)
                throw new ArgumentNullException("collection");

            // Validate the field list before asking for the enumerator, so an invalid name cannot
            // leave behind an enumerator that the caller never receives and therefore cannot dispose.
            SetFields(fields ?? new string[0]);

            m_Enumerator = collection.GetEnumerator();
            // NOTE: NullReferenceException (rather than InvalidOperationException) is kept because
            // existing callers may already catch this type.
            if (m_Enumerator == null)
                throw new NullReferenceException("collection does not implement GetEnumerator");
        }

        /// <summary>
        /// Resolves each requested name to a public property (preferred) or public field of T.
        /// With an empty list, the item itself becomes the single column.
        /// </summary>
        private void SetFields(ICollection<string> fields)
        {
            if (fields.Count == 0)
            {
                m_Fields.Add(new Self());
                return;
            }

            Type type = typeof (T);
            foreach (string field in fields)
            {
                PropertyInfo pInfo = type.GetProperty(field);
                if (pInfo != null)
                {
                    m_Fields.Add(new Member(pInfo, pInfo.PropertyType));
                    continue;
                }

                FieldInfo fInfo = type.GetField(field);
                if (fInfo != null)
                {
                    m_Fields.Add(new Member(fInfo, fInfo.FieldType));
                    continue;
                }

                // NOTE: exception type kept for backward compatibility with existing callers.
                throw new NullReferenceException(
                    string.Format(
                        "EnumerableDataReader<T>: Missing property or field '{0}' in Type '{1}'.", field,
                        type.Name));
            }
        }

        /// <summary>
        /// Throws <see cref="IndexOutOfRangeException"/> when <paramref name="i"/> is not a valid column ordinal.
        /// </summary>
        private void CheckOrdinal(int i)
        {
            if (i < 0 || i >= m_Fields.Count)
                throw new IndexOutOfRangeException();
        }

        #region IDisposable Members

        /// <summary>Same as <see cref="Close"/>: disposes the underlying enumerator and closes the reader.</summary>
        public void Dispose()
        {
            Close();
        }

        #endregion

        #region IDataReader Members

        /// <summary>
        /// Closes the reader and disposes the underlying enumerator. Safe to call more than once.
        /// After closing, <see cref="Read"/> and <see cref="GetValue"/> throw <see cref="InvalidOperationException"/>.
        /// </summary>
        public void Close()
        {
            if (m_Enumerator != null)
            {
                m_Enumerator.Dispose();
                m_Enumerator = null;
            }
            m_Current = default(T);
            m_EnumeratorState = false;
            m_Closed = true;
        }

        /// <summary>Always 0: the reader exposes a flat, non-nested result.</summary>
        public int Depth
        {
            get { return 0; }
        }

        /// <summary>
        /// Returns an empty <see cref="DataTable"/> with one column per field (name and type).
        /// Nullable&lt;X&gt; members are reported as X, because <see cref="DataColumn"/> rejects Nullable types.
        /// NOTE(review): this is not the standard IDataReader schema-table layout (one row per column with
        /// ColumnName/DataType/...); the existing shape is kept for compatibility with current callers.
        /// </summary>
        public DataTable GetSchemaTable()
        {
            var dt = new DataTable();
            foreach (BaseField field in m_Fields)
            {
                Type columnType = Nullable.GetUnderlyingType(field.Type) ?? field.Type;
                dt.Columns.Add(new DataColumn(field.Name, columnType));
            }
            return dt;
        }

        /// <summary>True after <see cref="Close"/> or <see cref="Dispose"/> has been called.</summary>
        public bool IsClosed
        {
            get { return m_Closed; }
        }

        /// <summary>Always false: an enumerable has exactly one result set.</summary>
        public bool NextResult()
        {
            return false;
        }

        /// <summary>
        /// Advances to the next item of the collection.
        /// </summary>
        /// <returns>True if an item is available; false at the end of the collection.</returns>
        /// <exception cref="InvalidOperationException">The reader is closed.</exception>
        public bool Read()
        {
            if (IsClosed)
                throw new InvalidOperationException("DataReader is closed");
            m_EnumeratorState = m_Enumerator.MoveNext();
            m_Current = m_EnumeratorState ? m_Enumerator.Current : default(T);
            return m_EnumeratorState;
        }

        /// <summary>Always -1: reading an enumerable affects no records.</summary>
        public int RecordsAffected
        {
            get { return -1; }
        }

        #endregion

        #region IDataRecord Members

        /// <summary>Number of exposed fields (1 when the reader exposes T itself).</summary>
        public int FieldCount
        {
            get { return m_Fields.Count; }
        }

        /// <summary>Declared CLR type of the property/field at <paramref name="i"/> (or T itself).</summary>
        public Type GetFieldType(int i)
        {
            CheckOrdinal(i);
            return m_Fields[i].Type;
        }

        /// <summary>Short CLR type name of the field at <paramref name="i"/>.</summary>
        public string GetDataTypeName(int i)
        {
            return GetFieldType(i).Name;
        }

        /// <summary>Property/field name at <paramref name="i"/> (empty string when exposing T itself).</summary>
        public string GetName(int i)
        {
            CheckOrdinal(i);
            return m_Fields[i].Name;
        }

        /// <summary>
        /// Returns the ordinal of the named field. An exact (case-sensitive) match wins; otherwise the first
        /// case-insensitive match is returned.
        /// </summary>
        /// <exception cref="IndexOutOfRangeException">No field has that name.</exception>
        public int GetOrdinal(string name)
        {
            for (int i = 0; i < m_Fields.Count; i++)
                if (string.Equals(m_Fields[i].Name, name, StringComparison.Ordinal))
                    return i;
            for (int i = 0; i < m_Fields.Count; i++)
                if (string.Equals(m_Fields[i].Name, name, StringComparison.OrdinalIgnoreCase))
                    return i;
            throw new IndexOutOfRangeException(string.Format("Field '{0}' was not found.", name));
        }

        /// <summary>True when the current value is null or <see cref="DBNull"/>.</summary>
        public bool IsDBNull(int i)
        {
            object value = GetValue(i);
            return value == null || value is DBNull;
        }

        public object this[string name]
        {
            get { return GetValue(GetOrdinal(name)); }
        }

        public object this[int i]
        {
            get { return GetValue(i); }
        }

        /// <summary>
        /// Returns the (boxed) value of field <paramref name="i"/> for the current item. Null members are
        /// returned as null, not <see cref="DBNull.Value"/>.
        /// </summary>
        /// <exception cref="InvalidOperationException">The reader is closed or not positioned on an item.</exception>
        /// <exception cref="IndexOutOfRangeException"><paramref name="i"/> is not a valid ordinal.</exception>
        public object GetValue(int i)
        {
            if (IsClosed || !m_EnumeratorState)
                throw new InvalidOperationException("DataReader is closed or has reached the end of the enumerator");
            CheckOrdinal(i);
            return m_Fields[i].GetValue(m_Current);
        }

        /// <summary>
        /// Copies up to <c>values.Length</c> field values of the current item into <paramref name="values"/>.
        /// </summary>
        /// <returns>The number of values copied.</returns>
        public int GetValues(object[] values)
        {
            if (values == null)
                throw new ArgumentNullException("values");
            int length = Math.Min(m_Fields.Count, values.Length);
            for (int i = 0; i < length; i++)
                values[i] = GetValue(i);
            return length;
        }

        // Typed accessors: unbox/cast the value directly, so a mismatched type throws InvalidCastException
        // and (for value types) a null member throws NullReferenceException.

        public bool GetBoolean(int i)
        {
            return (bool) GetValue(i);
        }

        public byte GetByte(int i)
        {
            return (byte) GetValue(i);
        }

        public char GetChar(int i)
        {
            return (char) GetValue(i);
        }

        public DateTime GetDateTime(int i)
        {
            return (DateTime) GetValue(i);
        }

        public decimal GetDecimal(int i)
        {
            return (decimal) GetValue(i);
        }

        public double GetDouble(int i)
        {
            return (double) GetValue(i);
        }

        public float GetFloat(int i)
        {
            return (float) GetValue(i);
        }

        public Guid GetGuid(int i)
        {
            return (Guid) GetValue(i);
        }

        public short GetInt16(int i)
        {
            return (short) GetValue(i);
        }

        public int GetInt32(int i)
        {
            return (int) GetValue(i);
        }

        public long GetInt64(int i)
        {
            return (long) GetValue(i);
        }

        public string GetString(int i)
        {
            return (string) GetValue(i);
        }

        /// <summary>
        /// Copies bytes from a <c>byte[]</c> field, starting at <paramref name="fieldOffset"/>.
        /// With a null <paramref name="buffer"/> the total length of the field is returned.
        /// </summary>
        /// <returns>The number of bytes copied (or the field length when buffer is null).</returns>
        public long GetBytes(int i, long fieldOffset, byte[] buffer, int bufferoffset, int length)
        {
            return CopyArray((byte[]) GetValue(i), fieldOffset, buffer, bufferoffset, length);
        }

        /// <summary>
        /// Copies characters from a <c>string</c> or <c>char[]</c> field, starting at <paramref name="fieldoffset"/>.
        /// With a null <paramref name="buffer"/> the total length of the field is returned.
        /// </summary>
        /// <returns>The number of chars copied (or the field length when buffer is null).</returns>
        public long GetChars(int i, long fieldoffset, char[] buffer, int bufferoffset, int length)
        {
            object value = GetValue(i);
            var text = value as string;
            char[] chars = text != null ? text.ToCharArray() : (char[]) value;
            return CopyArray(chars, fieldoffset, buffer, bufferoffset, length);
        }

        /// <summary>Not supported: fields are never nested readers.</summary>
        public IDataReader GetData(int i)
        {
            throw new NotSupportedException();
        }

        /// <summary>
        /// Shared implementation of GetBytes/GetChars: copies at most <paramref name="length"/> elements of
        /// <paramref name="source"/>, starting at <paramref name="sourceOffset"/>, into <paramref name="buffer"/>.
        /// </summary>
        private static long CopyArray<TElement>(TElement[] source, long sourceOffset, TElement[] buffer,
            int bufferOffset, int length)
        {
            if (source == null)
                throw new InvalidCastException("Cannot read a null value as an array.");
            // As with SqlDataReader/DbDataReader, a null buffer requests the total field length.
            if (buffer == null)
                return source.Length;
            if (sourceOffset < 0 || sourceOffset > source.Length)
                throw new ArgumentOutOfRangeException("fieldOffset");

            long count = Math.Min((long) length, source.Length - sourceOffset);
            if (count <= 0)
                return 0;
            Array.Copy(source, sourceOffset, buffer, bufferOffset, count);
            return count;
        }

        #endregion

        #region Helper Classes

        /// <summary>One exposed column: its name, its declared type, and how to read it from a T.</summary>
        private abstract class BaseField
        {
            // Compiled getters are cached per member name. This static lives in the closed generic type, so it
            // is shared by every reader over the same T; readers may be constructed concurrently, hence the lock.
            private static readonly Dictionary<string, Func<T, object>> s_Getters =
                new Dictionary<string, Func<T, object>>();
            private static readonly object s_GettersLock = new object();

            public abstract Type Type { get; }
            public abstract string Name { get; }
            public abstract object GetValue(T instance);

            /// <summary>
            /// Returns a cached compiled delegate that reads <paramref name="member"/> (a public property or field
            /// of T) and boxes the result, avoiding per-call reflection (PropertyInfo/FieldInfo.GetValue).
            /// </summary>
            protected static Func<T, object> GetOrCreateGetter(MemberInfo member)
            {
                lock (s_GettersLock)
                {
                    Func<T, object> getter;
                    if (!s_Getters.TryGetValue(member.Name, out getter))
                    {
                        ParameterExpression instance = Expression.Parameter(typeof (T), "instance");
                        Expression access = Expression.MakeMemberAccess(instance, member);
                        Expression boxed = Expression.Convert(access, typeof (object));
                        getter = Expression.Lambda<Func<T, object>>(boxed, instance).Compile();
                        s_Getters.Add(member.Name, getter);
                    }
                    return getter;
                }
            }
        }

        /// <summary>A column backed by a public property or field of T.</summary>
        private sealed class Member : BaseField
        {
            private readonly string m_Name;
            private readonly Type m_Type;
            private readonly Func<T, object> m_Getter;

            public Member(MemberInfo info, Type memberType)
            {
                m_Name = info.Name;
                m_Type = memberType;
                m_Getter = GetOrCreateGetter(info);
            }

            public override Type Type
            {
                get { return m_Type; }
            }

            public override string Name
            {
                get { return m_Name; }
            }

            public override object GetValue(T instance)
            {
                return m_Getter(instance);
            }
        }

        /// <summary>The single unnamed column used when no field names are given: the item itself.</summary>
        private sealed class Self : BaseField
        {
            private readonly Type m_Type;

            public Self()
            {
                m_Type = typeof (T);
            }

            public override Type Type
            {
                get { return m_Type; }
            }

            public override string Name
            {
                get { return string.Empty; }
            }

            public override object GetValue(T instance)
            {
                return instance;
            }
        }

        #endregion
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.