Skip to content

Instantly share code, notes, and snippets.

@sixten
Created October 10, 2016 21:31
Show Gist options
  • Save sixten/2bd8ec7a6ad868cd709e518b000bcf72 to your computer and use it in GitHub Desktop.
Save sixten/2bd8ec7a6ad868cd709e518b000bcf72 to your computer and use it in GitHub Desktop.
Modified fake DbSet code, allowing expression visitors to be plugged into LINQ execution.
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Data.Entity;
using System.Data.Entity.Infrastructure;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;
namespace Fakes
{
/// <summary>
/// In-memory test double for an Entity Framework set. Entities live in an
/// <see cref="ObservableCollection{T}"/>, and LINQ queries against the set are
/// routed through a <see cref="FakeAsyncQueryProvider{TEntity}"/> so that the
/// registered <see cref="ExpressionVisitors"/> can rewrite them.
/// </summary>
/// <typeparam name="TEntity">The entity type held by the set.</typeparam>
public class FakeDbSet<TEntity> : IDbSet<TEntity>, IDbAsyncEnumerable<TEntity>
    where TEntity : class
{
    private readonly ObservableCollection<TEntity> _data;
    private readonly IQueryable _query;
    private readonly Func<object[], Func<TEntity, bool>> _findPredicateFactory;

    /// <summary>
    /// Visitors handed to the query provider each time <c>IQueryable.Provider</c>
    /// is read. The provider copies the list, so a visitor added later only
    /// affects queries started after it was added.
    /// </summary>
    public ICollection<ExpressionVisitor> ExpressionVisitors { get; private set; }

    /// <summary>
    /// Creates a set without <see cref="Find"/> support.
    /// </summary>
    public FakeDbSet() : this(findPredicateFactory: null)
    {
    }

    /// <summary>
    /// Creates a set whose <see cref="Find"/> is driven by the given factory.
    /// </summary>
    /// <param name="findPredicateFactory">
    /// Turns the key values passed to <see cref="Find"/> into a predicate that
    /// selects the matching entity; may be null, in which case Find throws.
    /// </param>
    public FakeDbSet( Func<object[], Func<TEntity, bool>> findPredicateFactory )
    {
        _data = new ObservableCollection<TEntity>();
        _query = _data.AsQueryable();
        _findPredicateFactory = findPredicateFactory;
        ExpressionVisitors = new List<ExpressionVisitor>();
    }

    /// <summary>Appends <paramref name="item"/> to the set and returns it.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="item"/> is null.</exception>
    public TEntity Add( TEntity item )
    {
        // A null element would later blow up inside Find predicates and queries.
        if( null == item ) {
            throw new ArgumentNullException("item");
        }
        _data.Add(item);
        return item;
    }

    /// <summary>Removes <paramref name="item"/> from the set (if present) and returns it.</summary>
    public TEntity Remove( TEntity item )
    {
        _data.Remove(item);
        return item;
    }

    /// <summary>
    /// Makes <paramref name="item"/> part of the set and returns it. Attaching an
    /// entity that is already in the set leaves the set unchanged.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="item"/> is null.</exception>
    public TEntity Attach( TEntity item )
    {
        if( null == item ) {
            throw new ArgumentNullException("item");
        }
        // Avoid storing the same entity twice when it is attached repeatedly
        // or was previously added.
        if( !_data.Contains(item) ) {
            _data.Add(item);
        }
        return item;
    }

    /// <summary>Creates a new <typeparamref name="TEntity"/> without adding it to the set.</summary>
    public TEntity Create()
    {
        return Activator.CreateInstance<TEntity>();
    }

    /// <summary>Creates a new instance of a derived entity type without adding it to the set.</summary>
    public TDerivedEntity Create<TDerivedEntity>()
        where TDerivedEntity : class, TEntity
    {
        return Activator.CreateInstance<TDerivedEntity>();
    }

    /// <summary>The backing collection itself (not a copy).</summary>
    public ObservableCollection<TEntity> Local
    {
        get { return _data; }
    }

    /// <summary>
    /// Returns the first entity matching the predicate that the configured
    /// factory builds from <paramref name="keyValues"/>, or null when none matches.
    /// </summary>
    /// <exception cref="InvalidOperationException">No predicate factory was supplied.</exception>
    public TEntity Find( params object[] keyValues )
    {
        if( null == _findPredicateFactory ) {
            throw new InvalidOperationException("Must configure this FakeDbSet with a predicate factory to enable Find().");
        }
        return _data.FirstOrDefault(_findPredicateFactory(keyValues));
    }

    Type IQueryable.ElementType
    {
        get { return _query.ElementType; }
    }

    Expression IQueryable.Expression
    {
        get { return _query.Expression; }
    }

    IQueryProvider IQueryable.Provider
    {
        // A fresh provider per access picks up the current visitor list.
        get { return new FakeAsyncQueryProvider<TEntity>(_query.Provider, ExpressionVisitors); }
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return _data.GetEnumerator();
    }

    IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator()
    {
        return _data.GetEnumerator();
    }

    /// <summary>Wraps the backing collection's enumerator in an async enumerator.</summary>
    public IDbAsyncEnumerator<TEntity> GetAsyncEnumerator()
    {
        return new TestDbAsyncEnumerator<TEntity>(_data.GetEnumerator());
    }

    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
    {
        return GetAsyncEnumerator();
    }
}
/// <summary>
/// Query provider that runs every expression through a fixed list of
/// <see cref="ExpressionVisitor"/>s before handing it on, and that answers the
/// async execution calls synchronously with already-completed tasks.
/// </summary>
/// <typeparam name="TEntity">Element type used for non-generic queries.</typeparam>
internal class FakeAsyncQueryProvider<TEntity> : IDbAsyncQueryProvider
{
    private readonly IQueryProvider _inner;
    private readonly IEnumerable<ExpressionVisitor> _visitors;

    /// <summary>Creates a provider that applies no visitors.</summary>
    internal FakeAsyncQueryProvider( IQueryProvider inner )
        : this(inner, expressionProcessors: new ExpressionVisitor[0])
    {
    }

    /// <summary>Creates a provider that applies the given visitors, in order.</summary>
    /// <param name="inner">Provider that actually executes expressions.</param>
    /// <param name="expressionProcessors">
    /// Visitors to apply; copied here, so later changes to the caller's
    /// collection are not seen. Null is treated as "no visitors".
    /// </param>
    internal FakeAsyncQueryProvider( IQueryProvider inner, IEnumerable<ExpressionVisitor> expressionProcessors )
    {
        _inner = inner;
        _visitors = (expressionProcessors ?? Enumerable.Empty<ExpressionVisitor>()).ToList();
    }

    /// <summary>Feeds the expression through each visitor; one visitor's output is the next one's input.</summary>
    private Expression ProcessExpression( Expression expression )
    {
        Expression fixedExp = expression;
        foreach( var visitor in _visitors ) {
            fixedExp = visitor.Visit(fixedExp);
        }
        return fixedExp;
    }

    /// <summary>Builds a query over the visitor-processed expression.</summary>
    public IQueryable CreateQuery( Expression expression )
    {
        // NOTE(review): the result is always typed as TEntity, whatever element
        // type the expression itself produces - confirm callers only use this
        // overload for TEntity sequences.
        var newExpression = ProcessExpression(expression);
        return new TestDbAsyncEnumerable<TEntity>(newExpression, _visitors);
    }

    /// <summary>Builds a typed query over the visitor-processed expression.</summary>
    public IQueryable<TElement> CreateQuery<TElement>( Expression expression )
    {
        var newExpression = ProcessExpression(expression);
        return new TestDbAsyncEnumerable<TElement>(newExpression, _visitors);
    }

    /// <summary>Executes the visitor-processed expression on the inner provider.</summary>
    public object Execute( Expression expression )
    {
        // Immediately-executed operators (First, Count, Max, ...) arrive here
        // rather than through CreateQuery, so they need the same processing.
        return _inner.Execute(ProcessExpression(expression));
    }

    /// <summary>Executes the visitor-processed expression on the inner provider.</summary>
    public TResult Execute<TResult>( Expression expression )
    {
        return _inner.Execute<TResult>(ProcessExpression(expression));
    }

    /// <summary>
    /// Runs <see cref="Execute(Expression)"/> synchronously and wraps the result
    /// in a completed task. The cancellation token is not observed.
    /// </summary>
    public Task<object> ExecuteAsync( Expression expression, CancellationToken cancellationToken )
    {
        return Task.FromResult(Execute(expression));
    }

    /// <summary>
    /// Runs <see cref="Execute{TResult}(Expression)"/> synchronously and wraps the
    /// result in a completed task. The cancellation token is not observed.
    /// </summary>
    public Task<TResult> ExecuteAsync<TResult>( Expression expression, CancellationToken cancellationToken )
    {
        return Task.FromResult(Execute<TResult>(expression));
    }
}
/// <summary>
/// An <see cref="EnumerableQuery{T}"/> that can also be enumerated through the
/// Entity Framework async enumeration interfaces, and whose query provider
/// carries a set of expression visitors along to follow-on queries.
/// </summary>
/// <typeparam name="T">Element type of the sequence.</typeparam>
internal class TestDbAsyncEnumerable<T> : EnumerableQuery<T>, IDbAsyncEnumerable<T>, IQueryable<T>
{
    // Visitors passed to every provider this query hands out.
    private readonly IEnumerable<ExpressionVisitor> _expressionVisitors;

    /// <summary>Wraps an existing sequence; no visitors are applied.</summary>
    public TestDbAsyncEnumerable( IEnumerable<T> enumerable )
        : base(enumerable)
    {
        _expressionVisitors = new ExpressionVisitor[0];
    }

    /// <summary>Wraps a query expression together with the visitors to keep applying.</summary>
    public TestDbAsyncEnumerable( Expression expression, IEnumerable<ExpressionVisitor> visitors )
        : base(expression)
    {
        _expressionVisitors = visitors;
    }

    /// <summary>Provider that wraps this query and forwards the stored visitors.</summary>
    IQueryProvider IQueryable.Provider
    {
        get
        {
            var provider = new FakeAsyncQueryProvider<T>(this, _expressionVisitors);
            return provider;
        }
    }

    /// <summary>Returns an async enumerator over this query's own enumerator.</summary>
    public IDbAsyncEnumerator<T> GetAsyncEnumerator()
    {
        IEnumerable<T> sequence = this;
        IEnumerator<T> enumerator = sequence.GetEnumerator();
        return new TestDbAsyncEnumerator<T>(enumerator);
    }

    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
    {
        return GetAsyncEnumerator();
    }
}
/// <summary>
/// Exposes an ordinary <see cref="IEnumerator{T}"/> through the async
/// enumerator interface by advancing it synchronously.
/// </summary>
/// <typeparam name="T">Element type being enumerated.</typeparam>
internal class TestDbAsyncEnumerator<T> : IDbAsyncEnumerator<T>
{
    private readonly IEnumerator<T> _enumerator;

    /// <summary>Wraps <paramref name="inner"/>; this object disposes it when disposed.</summary>
    public TestDbAsyncEnumerator( IEnumerator<T> inner )
    {
        _enumerator = inner;
    }

    /// <summary>The wrapped enumerator's current element.</summary>
    public T Current
    {
        get { return _enumerator.Current; }
    }

    object IDbAsyncEnumerator.Current
    {
        get { return Current; }
    }

    /// <summary>
    /// Advances the wrapped enumerator right away and returns the outcome as a
    /// completed task. The cancellation token is not consulted.
    /// </summary>
    public Task<bool> MoveNextAsync( CancellationToken cancellationToken )
    {
        bool advanced = _enumerator.MoveNext();
        return Task.FromResult(advanced);
    }

    /// <summary>Disposes the wrapped enumerator.</summary>
    public void Dispose()
    {
        _enumerator.Dispose();
    }
}
}
// Example wiring: a fake set of Foo whose Find() accepts exactly one Int64 key.
var foosSet = new FakeDbSet<Foo>(keyValues => {
    // Anything other than a single long key matches no entity.
    if( keyValues.Length != 1 || !(keyValues[0] is long) ) {
        return f => false;
    }

    var id = (long)keyValues[0];
    return f => f.Id == id;
});

// Register the visitor on the set, then expose the set through the fake context.
foosSet.ExpressionVisitors.Add(new VersionOrderingVisitor());
fakeContext.Foos = foosSet;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
namespace Fakes
{
/// <summary>
/// Rewrites <c>Queryable.OrderBy</c> / <c>Queryable.OrderByDescending</c> calls
/// that sort <c>Foo</c> by a <c>byte[]</c> key so that they use the overload
/// taking an explicit comparer, supplying a byte-wise array comparer.
/// All other method calls are left to the base visitor.
/// </summary>
internal class VersionOrderingVisitor : System.Linq.Expressions.ExpressionVisitor
{
    // The comparer holds no state, so one instance serves every rewritten call.
    private static readonly IComparer<byte[]> SharedComparer = new ByteArrayComparer();

    private readonly MethodInfo _orderBy2Method;
    private readonly MethodInfo _orderBy3Method;
    private readonly MethodInfo _orderByDesc2Method;
    private readonly MethodInfo _orderByDesc3Method;

    public VersionOrderingVisitor()
    {
        _orderBy2Method = ResolveOrderingMethod("OrderBy", 2);
        _orderBy3Method = ResolveOrderingMethod("OrderBy", 3);
        _orderByDesc2Method = ResolveOrderingMethod("OrderByDescending", 2);
        _orderByDesc3Method = ResolveOrderingMethod("OrderByDescending", 3);
    }

    /// <summary>
    /// Finds the generic <c>Queryable</c> method with the given name and
    /// parameter count and closes it over (Foo, byte[]).
    /// </summary>
    private static MethodInfo ResolveOrderingMethod( string name, int parameterCount )
    {
        return typeof(System.Linq.Queryable).GetMethods()
            .Single(m => m.Name == name
                && m.IsGenericMethodDefinition
                && m.GetParameters().Length == parameterCount)
            .MakeGenericMethod(typeof(Foo), typeof(byte[]));
    }

    /// <summary>
    /// Replaces a two-argument ordering call with its three-argument
    /// counterpart, appending the comparer as the last argument.
    /// </summary>
    protected override Expression VisitMethodCall( MethodCallExpression node )
    {
        MethodInfo replacement = null;
        if( node.Method.Equals(_orderBy2Method) ) {
            replacement = _orderBy3Method;
        }
        else if( node.Method.Equals(_orderByDesc2Method) ) {
            replacement = _orderByDesc3Method;
        }

        if( null == replacement ) {
            return base.VisitMethodCall(node);
        }

        // Visit the original arguments first so nested ordering calls are
        // rewritten as well, then add the comparer.
        var comparerExpression = Expression.Constant(SharedComparer, typeof(IComparer<byte[]>));
        var arguments = Visit(node.Arguments).Concat(new Expression[] { comparerExpression });
        return Expression.Call(node.Object, replacement, arguments);
    }

    /// <summary>
    /// Orders byte arrays element by element. Null sorts before any array, and
    /// when one array is a prefix of the other the shorter one sorts first.
    /// </summary>
    private class ByteArrayComparer : IComparer<byte[]>
    {
        public int Compare( byte[] b1, byte[] b2 )
        {
            if( ReferenceEquals(b1, b2) ) {
                return 0;
            }
            if( null == b1 ) {
                return -1;
            }
            if( null == b2 ) {
                return 1;
            }

            var commonLength = b1.Length < b2.Length ? b1.Length : b2.Length;
            for( var i = 0; i < commonLength; i++ ) {
                var difference = b1[i].CompareTo(b2[i]);
                if( 0 != difference ) {
                    return difference;
                }
            }
            return b1.Length.CompareTo(b2.Length);
        }
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment