Skip to content

Instantly share code, notes, and snippets.

@maartenba
Created March 30, 2012 17:22
Show Gist options
  • Save maartenba/2253114 to your computer and use it in GitHub Desktop.
Save maartenba/2253114 to your computer and use it in GitHub Desktop.
A fresh draft of CachedQueryable<T>
/// <summary>
/// Console demo for <c>CachedQueryable&lt;T&gt;</c>: runs three lookups that share one
/// cache list and prints the name of each entity found.
/// </summary>
class Program
{
    /// <summary>
    /// Runs the lookups in order (Id 1, Id 1 again, then Id 2 combined with a name filter),
    /// writes each resulting name to the console and waits for a key press.
    /// </summary>
    /// <param name="args">Command-line arguments; not used.</param>
    static void Main(string[] args)
    {
        var people = new List<Person>
        {
            new Person { Id = 1, Name = "Maarten" },
            new Person { Id = 2, Name = "Xavier" }
        };

        // One cache instance is handed to every query wrapper created below.
        var sharedCache = new List<Person>();

        Person firstLookup = Wrap(sharedCache, people).Where(p => p.Id == 1).FirstOrDefault();
        Console.WriteLine(firstLookup.Name);

        // Same identifier lookup again, issued through a brand-new wrapper.
        Person repeatedLookup = Wrap(sharedCache, people).Where(p => p.Id == 1).FirstOrDefault();
        Console.WriteLine(repeatedLookup.Name);

        // Identifier combined with a non-identifier property.
        Person mixedLookup = Wrap(sharedCache, people).Where(p => p.Id == 2 && p.Name == "Xavier").FirstOrDefault();
        Console.WriteLine(mixedLookup.Name);

        Console.Read();
    }

    /// <summary>
    /// Builds a fresh cached query over <paramref name="people"/> that uses
    /// <paramref name="cache"/> and treats "Id" as the identifier property.
    /// </summary>
    private static IQueryable<Person> Wrap(List<Person> cache, List<Person> people)
    {
        return new CachedQueryable<Person>(cache, people.AsQueryable(), new string[] { "Id" });
    }
}
/// <summary>
/// An <see cref="IQueryable{T}"/> decorator that tries to answer identifier lookups from a
/// caller-supplied entity cache and otherwise runs the wrapped query and caches what it returns.
/// </summary>
/// <remarks>
/// <para>
/// A query is answered from the cache only when its operator chain starts, directly on the
/// source, with one or more <c>Where</c> calls whose predicates consist solely of equality
/// comparisons (optionally joined with <c>&amp;&amp;</c>) between an identifier property of the
/// element and a value that does not depend on the element, and the cache holds at least one
/// matching entity. Every other query, and every cache miss, goes to the wrapped queryable.
/// </para>
/// <para>
/// NOTE(review): a cache hit is treated as a complete answer, which assumes that a matching
/// identifier comparison selects the same entities in the cache as in the source - confirm
/// this holds for the identifier properties passed in (e.g. composite keys).
/// </para>
/// <para>
/// On a cache miss the wrapped queryable is enumerated completely, its results are added to the
/// cache, and the terminal operator (for example <c>FirstOrDefault</c>) is evaluated in memory
/// over those results. The cache collection itself is used as the lock object for reads and writes.
/// </para>
/// </remarks>
/// <typeparam name="T">Element type of the wrapped query and of the cache.</typeparam>
class CachedQueryable<T>
    : IQueryable<T>, IOrderedQueryable<T>
{
    private readonly ICollection<T> entityCache;
    private readonly IQueryable<T> delegateQueryable;
    private readonly string[] identifierProperties;

    /// <summary>Creates a cached view over <paramref name="delegateQueryable"/>.</summary>
    /// <param name="entityCache">Collection that receives fetched entities and is searched for identifier lookups.</param>
    /// <param name="delegateQueryable">The query that produces entities when the cache cannot answer.</param>
    /// <param name="identifierProperties">
    /// Names of the properties of <typeparamref name="T"/> regarded as identifiers.
    /// <c>null</c> is treated as an empty list, which means the cache is never consulted.
    /// </param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="entityCache"/> or <paramref name="delegateQueryable"/> is <c>null</c>.
    /// </exception>
    public CachedQueryable(ICollection<T> entityCache, IQueryable<T> delegateQueryable, string[] identifierProperties)
    {
        if (entityCache == null)
        {
            throw new ArgumentNullException("entityCache");
        }
        if (delegateQueryable == null)
        {
            throw new ArgumentNullException("delegateQueryable");
        }

        this.entityCache = entityCache;
        this.delegateQueryable = delegateQueryable;
        this.identifierProperties = identifierProperties ?? new string[0];
    }

    /// <summary>Executes the query (from the cache when possible) and enumerates its results.</summary>
    public IEnumerator<T> GetEnumerator()
    {
        IEnumerable<T> results = CreateProvider().Execute<IEnumerable<T>>(Expression);
        return results.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }

    /// <summary>The expression tree of the wrapped queryable.</summary>
    public Expression Expression
    {
        get { return delegateQueryable.Expression; }
    }

    /// <summary>The element type reported by the wrapped queryable.</summary>
    public Type ElementType
    {
        get { return delegateQueryable.ElementType; }
    }

    /// <summary>A cache-aware provider bound to this instance's cache and wrapped queryable.</summary>
    public IQueryProvider Provider
    {
        get { return CreateProvider(); }
    }

    private QueryProvider CreateProvider()
    {
        return new QueryProvider(entityCache, delegateQueryable, identifierProperties);
    }

    /// <summary>
    /// Query provider that keeps composed queries cache-aware and decides, per execution,
    /// whether the cache or the wrapped queryable supplies the data.
    /// </summary>
    private class QueryProvider : IQueryProvider
    {
        private readonly ICollection<T> entityCache;
        private readonly IQueryable<T> delegateQueryable;
        private readonly string[] identifierProperties;

        public QueryProvider(ICollection<T> entityCache, IQueryable<T> delegateQueryable, string[] identifierProperties)
        {
            this.entityCache = entityCache;
            this.delegateQueryable = delegateQueryable;
            this.identifierProperties = identifierProperties;
        }

        /// <summary>
        /// Non-generic query creation. Stays cache-aware when the expression yields a sequence
        /// assignable to <c>IQueryable&lt;T&gt;</c>; otherwise hands over to the wrapped provider.
        /// </summary>
        public IQueryable CreateQuery(Expression expression)
        {
            if (expression != null && typeof(IQueryable<T>).IsAssignableFrom(expression.Type))
            {
                return CreateQuery<T>(expression);
            }
            return delegateQueryable.Provider.CreateQuery(expression);
        }

        /// <summary>
        /// Creates the composed query. The result is wrapped in a <see cref="CachedQueryable{T}"/>
        /// only while the cache can hold the new element type; projections to another type
        /// are returned as plain queries of the wrapped provider.
        /// </summary>
        public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
        {
            IQueryable<TElement> inner = delegateQueryable.Provider.CreateQuery<TElement>(expression);

            // ICollection<T> is invariant, so this is non-null only when the cache really stores TElement.
            ICollection<TElement> typedCache = entityCache as ICollection<TElement>;
            if (typedCache == null)
            {
                return inner;
            }
            return new CachedQueryable<TElement>(typedCache, inner, identifierProperties);
        }

        /// <summary>Executes <paramref name="expression"/> and returns the boxed result.</summary>
        public object Execute(Expression expression)
        {
            IQueryProvider executor;
            Expression executable = Prepare(expression, out executor);
            return executor.Execute(executable);
        }

        /// <summary>Executes <paramref name="expression"/>, preferring the cache for identifier lookups.</summary>
        public TResult Execute<TResult>(Expression expression)
        {
            IQueryProvider executor;
            Expression executable = Prepare(expression, out executor);
            return executor.Execute<TResult>(executable);
        }

        /// <summary>
        /// Picks the data source for one execution and returns the expression to run together
        /// with the provider that must run it.
        /// </summary>
        private Expression Prepare(Expression expression, out IQueryProvider executor)
        {
            Expression cacheBound = TryBindToCache(expression, out executor);
            if (cacheBound != null)
            {
                return cacheBound;
            }
            return BindToFetchedResults(expression, out executor);
        }

        /// <summary>
        /// Returns <paramref name="expression"/> re-rooted on a snapshot of the cache when the
        /// query is a pure identifier lookup and the cache has a match; otherwise <c>null</c>.
        /// </summary>
        private Expression TryBindToCache(Expression expression, out IQueryProvider executor)
        {
            executor = null;

            WhereFinder finder = new WhereFinder();
            if (!finder.TryFind(expression))
            {
                return null;
            }

            WhereAnalyzer analyzer = new WhereAnalyzer(identifierProperties);
            foreach (Expression<Func<T, bool>> predicate in finder.Predicates)
            {
                if (!analyzer.QueriesOnlyIdentifierProperties(predicate))
                {
                    return null;
                }
            }

            // Copy under the same lock the writer uses, so a concurrent fill cannot
            // invalidate the enumeration below.
            List<T> snapshot;
            lock (entityCache)
            {
                snapshot = entityCache.ToList();
            }

            IQueryable<T> cached = snapshot.AsQueryable();
            IQueryable<T> matches = cached;
            foreach (Expression<Func<T, bool>> predicate in finder.Predicates)
            {
                matches = matches.Where(predicate);
            }
            if (!matches.Any())
            {
                return null;
            }

            // Fully qualified: inside this nested class the simple name "Expression" also
            // denotes the outer class's Expression property.
            Expression cacheSource = System.Linq.Expressions.Expression.Constant(cached, typeof(IQueryable<T>));
            executor = cached.Provider;
            return new NodeReplacer(finder.Root, cacheSource).Visit(expression);
        }

        /// <summary>
        /// Runs the wrapped query, stores its results in the cache and returns
        /// <paramref name="expression"/> re-rooted on those results.
        /// </summary>
        private Expression BindToFetchedResults(Expression expression, out IQueryProvider executor)
        {
            List<T> fetched = delegateQueryable.ToList();
            lock (entityCache)
            {
                foreach (T entity in fetched)
                {
                    if (!entityCache.Contains(entity))
                    {
                        entityCache.Add(entity);
                    }
                }
            }

            IQueryable<T> local = fetched.AsQueryable();
            Expression fetchedSource = System.Linq.Expressions.Expression.Constant(local, typeof(IQueryable<T>));
            NodeReplacer replacer = new NodeReplacer(delegateQueryable.Expression, fetchedSource);
            Expression rewritten = replacer.Visit(expression);

            if (!replacer.Replaced)
            {
                // The expression is not built on top of the wrapped query's expression node,
                // so it cannot be evaluated over the fetched list; let the wrapped provider run it.
                executor = delegateQueryable.Provider;
                return expression;
            }

            executor = local.Provider;
            return rewritten;
        }
    }

    /// <summary>
    /// Walks a chain of <see cref="Queryable"/> operator calls from the outermost call down to
    /// its source and collects the <c>Where</c> predicates applied directly to that source.
    /// </summary>
    private class WhereFinder
    {
        private readonly IList<Expression<Func<T, bool>>> predicates = new List<Expression<Func<T, bool>>>();

        /// <summary>Predicates of the leading <c>Where</c> calls found by <see cref="TryFind"/>.</summary>
        public IList<Expression<Func<T, bool>>> Predicates
        {
            get { return predicates; }
        }

        /// <summary>The expression the operator chain is built on; set by <see cref="TryFind"/>.</summary>
        public Expression Root { get; private set; }

        /// <summary>
        /// Analyses <paramref name="expression"/>. Returns <c>true</c> when the chain contains at
        /// least one <c>Where</c>, every <c>Where</c> takes an <c>Expression&lt;Func&lt;T, bool&gt;&gt;</c>
        /// and no other operator sits between the source and a <c>Where</c>.
        /// </summary>
        public bool TryFind(Expression expression)
        {
            predicates.Clear();
            Root = null;

            bool sawWhere = false;
            Expression current = expression;
            MethodCallExpression call = current as MethodCallExpression;
            while (call != null && call.Method.DeclaringType == typeof(Queryable) && call.Arguments.Count > 0)
            {
                if (call.Method.Name == "Where")
                {
                    // Null for the indexed overload and for filters over another element type.
                    Expression<Func<T, bool>> predicate = StripQuotes(call.Arguments[1]) as Expression<Func<T, bool>>;
                    if (predicate == null)
                    {
                        return false;
                    }
                    predicates.Add(predicate);
                    sawWhere = true;
                }
                else if (sawWhere)
                {
                    // An operator runs before the filter (e.g. Skip(1).Where(...)); its outcome
                    // depends on the whole source, so a cache subset cannot stand in for it.
                    return false;
                }

                current = call.Arguments[0];
                call = current as MethodCallExpression;
            }

            Root = current;
            return sawWhere;
        }

        private static Expression StripQuotes(Expression node)
        {
            while (node.NodeType == ExpressionType.Quote)
            {
                node = ((UnaryExpression)node).Operand;
            }
            return node;
        }
    }

    /// <summary>
    /// Decides whether a filter predicate only performs equality comparisons on identifier properties.
    /// </summary>
    private class WhereAnalyzer
    {
        private readonly string[] identifierProperties;

        public WhereAnalyzer(string[] identifierProperties)
        {
            this.identifierProperties = identifierProperties;
        }

        /// <summary>
        /// Returns <c>true</c> when the body of <paramref name="predicate"/> is an equality
        /// comparison, or an <c>&amp;&amp;</c> combination of such comparisons, between an identifier
        /// property of the lambda parameter and an operand that does not use the parameter.
        /// </summary>
        public bool QueriesOnlyIdentifierProperties(Expression<Func<T, bool>> predicate)
        {
            return IsIdentifierFilter(predicate.Body, predicate.Parameters[0]);
        }

        private bool IsIdentifierFilter(Expression node, ParameterExpression parameter)
        {
            BinaryExpression binary = node as BinaryExpression;
            if (binary == null)
            {
                return false;
            }

            if (binary.NodeType == ExpressionType.AndAlso)
            {
                return IsIdentifierFilter(binary.Left, parameter)
                    && IsIdentifierFilter(binary.Right, parameter);
            }

            if (binary.NodeType != ExpressionType.Equal)
            {
                return false;
            }

            // The identifier may be written on either side of the comparison.
            return IsIdentifierLookup(binary.Left, binary.Right, parameter)
                || IsIdentifierLookup(binary.Right, binary.Left, parameter);
        }

        private bool IsIdentifierLookup(Expression candidate, Expression comparand, ParameterExpression parameter)
        {
            MemberExpression member = candidate as MemberExpression;
            if (member == null || member.Expression != parameter)
            {
                // Not a direct member of the element, e.g. a captured local (a member of a closure object).
                return false;
            }
            if (!identifierProperties.Contains(member.Member.Name))
            {
                return false;
            }
            return !ParameterFinder.IsReferencedBy(comparand, parameter);
        }

        /// <summary>Detects whether an expression uses a given lambda parameter.</summary>
        private class ParameterFinder : ExpressionVisitor
        {
            private readonly ParameterExpression parameter;
            private bool found;

            private ParameterFinder(ParameterExpression parameter)
            {
                this.parameter = parameter;
            }

            public static bool IsReferencedBy(Expression expression, ParameterExpression parameter)
            {
                ParameterFinder finder = new ParameterFinder(parameter);
                finder.Visit(expression);
                return finder.found;
            }

            protected override Expression VisitParameter(ParameterExpression node)
            {
                if (node == parameter)
                {
                    found = true;
                }
                return node;
            }
        }
    }

    /// <summary>
    /// Replaces the first occurrence (by reference) of one expression node with another.
    /// Only the first occurrence is replaced because the visitor reaches the source of an
    /// operator chain before any other argument of the chain's operators.
    /// </summary>
    private class NodeReplacer : ExpressionVisitor
    {
        private readonly Expression target;
        private readonly Expression replacement;

        public NodeReplacer(Expression target, Expression replacement)
        {
            this.target = target;
            this.replacement = replacement;
        }

        /// <summary><c>true</c> once the target node has been found and replaced.</summary>
        public bool Replaced { get; private set; }

        public override Expression Visit(Expression node)
        {
            if (!Replaced && node != null && node == target)
            {
                Replaced = true;
                return replacement;
            }
            return base.Visit(node);
        }
    }
}
/// <summary>
/// Sample entity for the demo: a person with a numeric id and a name.
/// </summary>
internal class Person
{
    /// <summary>Numeric identifier of the person.</summary>
    public int Id { get; set; }

    /// <summary>Name of the person.</summary>
    public string Name { get; set; }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment