Created
October 10, 2016 21:31
-
-
Save sixten/2bd8ec7a6ad868cd709e518b000bcf72 to your computer and use it in GitHub Desktop.
Modified fake DbSet code, allowing expression visitors to be plugged into LINQ execution.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Collections.ObjectModel; | |
using System.Data.Entity; | |
using System.Data.Entity.Infrastructure; | |
using System.Linq; | |
using System.Linq.Expressions; | |
using System.Threading; | |
using System.Threading.Tasks; | |
namespace Fakes | |
{ | |
/// <summary>
/// In-memory stand-in for an Entity Framework <see cref="IDbSet{TEntity}"/> that supports both
/// synchronous and asynchronous LINQ queries. The visitors in <see cref="ExpressionVisitors"/>
/// are passed to the query provider this set exposes, which uses them to rewrite query
/// expressions (e.g. to patch constructs LINQ to Objects cannot evaluate as written).
/// </summary>
/// <typeparam name="TEntity">The entity type held by the set.</typeparam>
public class FakeDbSet<TEntity> : IDbSet<TEntity>, IDbAsyncEnumerable<TEntity>
    where TEntity : class
{
    // Backing store; also exposed as-is through Local.
    private readonly ObservableCollection<TEntity> _data;

    // LINQ-to-Objects view over _data; its expression/provider seed every query on this set.
    private readonly IQueryable<TEntity> _query;

    // Maps Find() key values to a match predicate; null means Find() is not configured.
    private readonly Func<object[], Func<TEntity, bool>> _findPredicateFactory;

    /// <summary>Expression visitors handed to the query provider, applied in insertion order.</summary>
    public ICollection<ExpressionVisitor> ExpressionVisitors { get; private set; }

    /// <summary>Creates an empty set on which <see cref="Find"/> is not supported.</summary>
    public FakeDbSet() : this(findPredicateFactory: null)
    {
    }

    /// <summary>Creates an empty set.</summary>
    /// <param name="findPredicateFactory">
    /// Builds the predicate <see cref="Find"/> uses for a given set of key values; may be null,
    /// in which case <see cref="Find"/> throws.
    /// </param>
    public FakeDbSet( Func<object[], Func<TEntity, bool>> findPredicateFactory )
    {
        _data = new ObservableCollection<TEntity>();
        _query = _data.AsQueryable();
        _findPredicateFactory = findPredicateFactory;
        ExpressionVisitors = new List<ExpressionVisitor>();
    }

    /// <summary>Adds the entity unless this exact instance is already in the set.</summary>
    /// <returns>The same <paramref name="item"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="item"/> is null.</exception>
    public TEntity Add( TEntity item )
    {
        if( null == item ) {
            throw new ArgumentNullException("item");
        }
        AddIfAbsent(item);
        return item;
    }

    /// <summary>Removes the entity if present.</summary>
    /// <returns>The same <paramref name="item"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="item"/> is null.</exception>
    public TEntity Remove( TEntity item )
    {
        if( null == item ) {
            throw new ArgumentNullException("item");
        }
        _data.Remove(item);
        return item;
    }

    /// <summary>Attaches the entity unless this exact instance is already in the set.</summary>
    /// <returns>The same <paramref name="item"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="item"/> is null.</exception>
    public TEntity Attach( TEntity item )
    {
        if( null == item ) {
            throw new ArgumentNullException("item");
        }
        AddIfAbsent(item);
        return item;
    }

    /// <summary>Creates a new, untracked instance via its parameterless constructor.</summary>
    public TEntity Create()
    {
        return Activator.CreateInstance<TEntity>();
    }

    /// <summary>Creates a new, untracked instance of a derived type via its parameterless constructor.</summary>
    public TDerivedEntity Create<TDerivedEntity>()
        where TDerivedEntity : class, TEntity
    {
        return Activator.CreateInstance<TDerivedEntity>();
    }

    /// <summary>The live backing collection of entities in this set.</summary>
    public ObservableCollection<TEntity> Local
    {
        get { return _data; }
    }

    /// <summary>
    /// Returns the first entity matching the predicate built from <paramref name="keyValues"/>,
    /// or null when none matches.
    /// </summary>
    /// <exception cref="InvalidOperationException">No predicate factory was configured.</exception>
    public TEntity Find( params object[] keyValues )
    {
        if( null == _findPredicateFactory ) {
            throw new InvalidOperationException("Must configure this FakeDbSet with a predicate factory to enable Find().");
        }
        return _data.FirstOrDefault(_findPredicateFactory(keyValues));
    }

    Type IQueryable.ElementType
    {
        get { return _query.ElementType; }
    }

    Expression IQueryable.Expression
    {
        get { return _query.Expression; }
    }

    // A fresh provider per access snapshots the current ExpressionVisitors.
    IQueryProvider IQueryable.Provider
    {
        get { return new FakeAsyncQueryProvider<TEntity>(_query.Provider, ExpressionVisitors); }
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return _data.GetEnumerator();
    }

    IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator()
    {
        return _data.GetEnumerator();
    }

    /// <summary>Enumerates the backing collection through the async enumerator interface.</summary>
    public IDbAsyncEnumerator<TEntity> GetAsyncEnumerator()
    {
        return new TestDbAsyncEnumerator<TEntity>(_data.GetEnumerator());
    }

    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
    {
        return GetAsyncEnumerator();
    }

    // EF6's Add/Attach do not store an already-tracked entity twice. Mirror that by reference
    // identity (not Equals), because the context tracks instances, not values.
    private void AddIfAbsent( TEntity item )
    {
        if( !_data.Any(existing => ReferenceEquals(existing, item)) ) {
            _data.Add(item);
        }
    }
}
/// <summary>
/// Query provider for <see cref="FakeDbSet{TEntity}"/> and the queries derived from it.
/// Every expression it creates or executes is first run through the configured
/// <see cref="ExpressionVisitor"/>s (in order) and then evaluated by an inner provider.
/// Implements <see cref="IDbAsyncQueryProvider"/> so Entity Framework's async query
/// extensions can run against the fake.
/// </summary>
internal class FakeAsyncQueryProvider<TEntity> : IDbAsyncQueryProvider
{
    // Provider that evaluates the (already rewritten) expressions.
    private readonly IQueryProvider _inner;

    // Snapshot of the visitors, copied at construction so later edits to the source collection don't leak in.
    private readonly IEnumerable<ExpressionVisitor> _visitors;

    internal FakeAsyncQueryProvider( IQueryProvider inner )
        : this(inner, expressionProcessors: new ExpressionVisitor[0])
    {
    }

    /// <param name="inner">Provider used to evaluate rewritten expressions.</param>
    /// <param name="expressionProcessors">Visitors applied, in order, to every expression.</param>
    /// <exception cref="ArgumentNullException">Either argument is null.</exception>
    internal FakeAsyncQueryProvider( IQueryProvider inner, IEnumerable<ExpressionVisitor> expressionProcessors )
    {
        if( null == inner ) {
            throw new ArgumentNullException("inner");
        }
        if( null == expressionProcessors ) {
            throw new ArgumentNullException("expressionProcessors");
        }
        _inner = inner;
        _visitors = expressionProcessors.ToList();
    }

    // Feeds the expression through every visitor, each receiving the previous one's output.
    private Expression ProcessExpression( Expression expression )
    {
        return _visitors.Aggregate(expression, (current, visitor) => visitor.Visit(current));
    }

    /// <summary>Creates a query whose element type is taken from the rewritten expression.</summary>
    public IQueryable CreateQuery( Expression expression )
    {
        var newExpression = ProcessExpression(expression);
        // The non-generic entry point must not assume TEntity: projections such as Select
        // yield sequences of a different element type.
        var queryType = typeof(TestDbAsyncEnumerable<>).MakeGenericType(GetElementType(newExpression.Type));
        return (IQueryable)Activator.CreateInstance(queryType, newExpression, _visitors);
    }

    /// <summary>Creates a typed query over the rewritten expression.</summary>
    public IQueryable<TElement> CreateQuery<TElement>( Expression expression )
    {
        var newExpression = ProcessExpression(expression);
        return new TestDbAsyncEnumerable<TElement>(newExpression, _visitors);
    }

    /// <summary>Rewrites and evaluates a scalar/terminal expression (First, Count, ...).</summary>
    public object Execute( Expression expression )
    {
        return _inner.Execute(ProcessExpression(expression));
    }

    /// <summary>Rewrites and evaluates a scalar/terminal expression (First, Count, ...).</summary>
    public TResult Execute<TResult>( Expression expression )
    {
        return _inner.Execute<TResult>(ProcessExpression(expression));
    }

    /// <summary>
    /// Evaluates synchronously and returns a completed task: canceled if the token is already
    /// canceled, faulted if evaluation throws.
    /// </summary>
    public Task<object> ExecuteAsync( Expression expression, CancellationToken cancellationToken )
    {
        return RunCompleted(() => Execute(expression), cancellationToken);
    }

    /// <summary>
    /// Evaluates synchronously and returns a completed task: canceled if the token is already
    /// canceled, faulted if evaluation throws.
    /// </summary>
    public Task<TResult> ExecuteAsync<TResult>( Expression expression, CancellationToken cancellationToken )
    {
        return RunCompleted(() => Execute<TResult>(expression), cancellationToken);
    }

    // Runs work inline but reports its outcome through the task, as an async API would,
    // instead of throwing synchronously to the caller.
    private static Task<T> RunCompleted<T>( Func<T> work, CancellationToken cancellationToken )
    {
        var completion = new TaskCompletionSource<T>();
        if( cancellationToken.IsCancellationRequested ) {
            completion.SetCanceled();
            return completion.Task;
        }
        try {
            completion.SetResult(work());
        }
        catch( Exception ex ) {
            completion.SetException(ex);
        }
        return completion.Task;
    }

    // Returns T when sequenceType is or implements IEnumerable<T>; otherwise falls back to
    // TEntity, which is what the non-generic CreateQuery previously always used.
    private static Type GetElementType( Type sequenceType )
    {
        if( sequenceType.IsGenericType && sequenceType.GetGenericTypeDefinition() == typeof(IEnumerable<>) ) {
            return sequenceType.GetGenericArguments()[0];
        }
        var enumerableInterface = sequenceType.GetInterfaces()
            .FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>));
        return enumerableInterface != null
            ? enumerableInterface.GetGenericArguments()[0]
            : typeof(TEntity);
    }
}
/// <summary>
/// LINQ-to-Objects query (<see cref="EnumerableQuery{T}"/>) that can also be enumerated through
/// Entity Framework's async enumerator interface. Its provider is built with the same expression
/// visitors this query was created with, so operators composed onto it go through them too.
/// </summary>
internal class TestDbAsyncEnumerable<T> : EnumerableQuery<T>, IDbAsyncEnumerable<T>, IQueryable<T>
{
    // Visitors forwarded to every provider this query hands out.
    private readonly IEnumerable<ExpressionVisitor> _visitors;

    /// <summary>Wraps an in-memory sequence with no visitors attached.</summary>
    public TestDbAsyncEnumerable( IEnumerable<T> enumerable )
        : base(enumerable)
    {
        _visitors = Enumerable.Empty<ExpressionVisitor>();
    }

    /// <summary>Represents the query described by <paramref name="expression"/>.</summary>
    /// <param name="expression">Expression evaluated by LINQ to Objects on enumeration.</param>
    /// <param name="visitors">Visitors forwarded to this query's provider.</param>
    public TestDbAsyncEnumerable( Expression expression, IEnumerable<ExpressionVisitor> visitors )
        : base(expression)
    {
        _visitors = visitors;
    }

    /// <summary>Evaluates the query and exposes its results through an async enumerator.</summary>
    public IDbAsyncEnumerator<T> GetAsyncEnumerator()
    {
        IEnumerable<T> sequence = this;
        var results = sequence.GetEnumerator();
        return new TestDbAsyncEnumerator<T>(results);
    }

    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
    {
        return ((IDbAsyncEnumerable<T>)this).GetAsyncEnumerator();
    }

    IQueryProvider IQueryable.Provider
    {
        get
        {
            var provider = new FakeAsyncQueryProvider<T>(this, _visitors);
            return provider;
        }
    }
}
/// <summary>
/// Adapts a synchronous <see cref="IEnumerator{T}"/> to Entity Framework's
/// <see cref="IDbAsyncEnumerator{T}"/>. Each step runs inline and is reported through an
/// already-completed task.
/// </summary>
internal class TestDbAsyncEnumerator<T> : IDbAsyncEnumerator<T>
{
    // The wrapped synchronous enumerator; disposed together with this instance.
    private readonly IEnumerator<T> _inner;

    /// <exception cref="ArgumentNullException"><paramref name="inner"/> is null.</exception>
    public TestDbAsyncEnumerator( IEnumerator<T> inner )
    {
        if( null == inner ) {
            throw new ArgumentNullException("inner");
        }
        _inner = inner;
    }

    /// <summary>Disposes the wrapped enumerator.</summary>
    public void Dispose()
    {
        _inner.Dispose();
    }

    /// <summary>
    /// Advances the wrapped enumerator. Returns a canceled task if the token is already canceled,
    /// and a faulted task (rather than a synchronous throw) if advancing fails.
    /// </summary>
    public Task<bool> MoveNextAsync( CancellationToken cancellationToken )
    {
        var completion = new TaskCompletionSource<bool>();
        if( cancellationToken.IsCancellationRequested ) {
            completion.SetCanceled();
            return completion.Task;
        }
        try {
            completion.SetResult(_inner.MoveNext());
        }
        catch( Exception ex ) {
            completion.SetException(ex);
        }
        return completion.Task;
    }

    /// <summary>The element at the current position of the wrapped enumerator.</summary>
    public T Current
    {
        get { return _inner.Current; }
    }

    object IDbAsyncEnumerator.Current
    {
        get { return Current; }
    }
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Find() support: a Foo is located by a single long key; any other key shape matches nothing.
var foosSet = new FakeDbSet<Foo>(keyValues => {
    bool isSingleLongKey = keyValues.Length == 1 && keyValues[0] is long;
    if( !isSingleLongKey ) {
        return f => false;
    }
    long id = (long)keyValues[0];
    return f => f.Id == id;
});
// Let LINQ to Objects sort by byte[] keys: VersionOrderingVisitor supplies a comparer for them.
foosSet.ExpressionVisitors.Add(new VersionOrderingVisitor());
fakeContext.Foos = foosSet;
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Collections; | |
using System.Collections.Generic; | |
using System.Linq; | |
using System.Linq.Expressions; | |
using System.Reflection; | |
namespace Fakes | |
{ | |
/// <summary>
/// Rewrites <see cref="Queryable"/> ordering calls whose key type is <c>byte[]</c>
/// (OrderBy, OrderByDescending, ThenBy, ThenByDescending), which LINQ to Objects cannot sort
/// because <c>byte[]</c> is not <see cref="System.IComparable"/>. Each such call is replaced by
/// its <see cref="IComparer{T}"/> overload with a lexicographic byte-array comparer. This works
/// for any source element type.
/// </summary>
internal class VersionOrderingVisitor : System.Linq.Expressions.ExpressionVisitor
{
    // Name of each supported ordering operator -> generic definition of its
    // (source, keySelector, comparer) overload.
    private static readonly Dictionary<string, MethodInfo> ComparerOverloads =
        FindComparerOverloads("OrderBy", "OrderByDescending", "ThenBy", "ThenByDescending");

    // Immutable constant injected as the comparer argument; safe to share across rewrites.
    private static readonly ConstantExpression ComparerArgument =
        Expression.Constant(new ByteArrayComparer(), typeof(IComparer<byte[]>));

    public VersionOrderingVisitor()
    {
    }

    /// <summary>
    /// Replaces a two-argument byte[]-keyed ordering call with its comparer overload. Arguments
    /// are visited first, so nested orderings are rewritten as well. Other calls are visited normally.
    /// </summary>
    protected override Expression VisitMethodCall( MethodCallExpression node )
    {
        MethodInfo comparerOverload;
        if( IsByteArrayKeyedOrdering(node) && ComparerOverloads.TryGetValue(node.Method.Name, out comparerOverload) ) {
            var typedOverload = comparerOverload.MakeGenericMethod(node.Method.GetGenericArguments());
            var arguments = Visit(node.Arguments).Concat(new Expression[] { ComparerArgument });
            return Expression.Call(node.Object, typedOverload, arguments);
        }
        return base.VisitMethodCall(node);
    }

    // True for a generic Queryable call of the (source, keySelector) shape whose TKey is byte[].
    private static bool IsByteArrayKeyedOrdering( MethodCallExpression node )
    {
        var method = node.Method;
        if( method.DeclaringType != typeof(Queryable) || !method.IsGenericMethod || node.Arguments.Count != 2 ) {
            return false;
        }
        var typeArguments = method.GetGenericArguments();
        return typeArguments.Length == 2 && typeArguments[1] == typeof(byte[]);
    }

    // Looks up the three-parameter generic definition of each named Queryable operator.
    private static Dictionary<string, MethodInfo> FindComparerOverloads( params string[] names )
    {
        var overloads = new Dictionary<string, MethodInfo>();
        var candidates = typeof(Queryable).GetMethods().Where(m => m.IsGenericMethodDefinition).ToList();
        foreach( var name in names ) {
            var operatorName = name;
            overloads.Add(operatorName, candidates.Single(m => m.Name == operatorName && m.GetParameters().Length == 3));
        }
        return overloads;
    }

    /// <summary>
    /// Total lexicographic order over byte arrays. Null sorts before any array, and a proper
    /// prefix sorts before its extensions. For equal-length arrays this matches element-wise
    /// structural comparison.
    /// </summary>
    private class ByteArrayComparer : IComparer<byte[]>
    {
        public int Compare( byte[] b1, byte[] b2 )
        {
            if( ReferenceEquals(b1, b2) ) {
                return 0;
            }
            if( null == b1 ) {
                return -1;
            }
            if( null == b2 ) {
                return 1;
            }
            int common = System.Math.Min(b1.Length, b2.Length);
            for( int i = 0; i < common; i++ ) {
                int diff = b1[i].CompareTo(b2[i]);
                if( diff != 0 ) {
                    return diff;
                }
            }
            // Structural comparison would throw on unequal lengths; order the shorter one first instead.
            return b1.Length.CompareTo(b2.Length);
        }
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment