Skip to content

Instantly share code, notes, and snippets.

@sean-gilliam
Created August 19, 2024 20:23
Show Gist options
  • Save sean-gilliam/b188956675d37e69c79999546ba0b2ae to your computer and use it in GitHub Desktop.
Save sean-gilliam/b188956675d37e69c79999546ba0b2ae to your computer and use it in GitHub Desktop.

Entity Framework Core - Mocking DbContexts with async methods.

When unit-testing code that runs asynchronous EF Core queries (for example `ToListAsync` or `FirstOrDefaultAsync`) against a mocked `DbContext`/`DbSet`, the following exception is thrown:

System.InvalidOperationException: 'The provider for the source 'IQueryable' doesn't implement 'IAsyncQueryProvider'. Only providers that implement 'IAsyncQueryProvider' can be used for Entity Framework asynchronous operations.'

To avoid this exception, create the following helper classes. They wrap the in-memory test data in a query provider that implements `IAsyncQueryProvider`:

namespace Your.Namespace.Here;

using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Query;

// for reference:
// https://learn.microsoft.com/en-us/ef/ef6/fundamentals/testing/mocking?redirectedfrom=MSDN#testing-with-async-queries
// https://stackoverflow.com/a/40491640
// https://stackoverflow.com/a/58314109

/// <summary>
/// Test double for EF Core's <see cref="IAsyncQueryProvider"/>. It wraps a synchronous
/// <see cref="IQueryProvider"/> (for example the provider of <c>list.AsQueryable()</c>) so that
/// EF Core async operators such as <c>FirstOrDefaultAsync</c> or <c>CountAsync</c> can execute
/// against in-memory data.
/// </summary>
/// <typeparam name="TEntity">
/// Element type of the root query. Only used as a fallback by the non-generic
/// <see cref="CreateQuery(Expression)"/> when no element type can be read from the expression.
/// </typeparam>
internal class TestAsyncQueryProvider<TEntity> : IAsyncQueryProvider
{
    private readonly IQueryProvider _inner;

    /// <param name="inner">The synchronous provider that actually evaluates the expression trees.</param>
    internal TestAsyncQueryProvider(IQueryProvider inner) => _inner = inner;

    /// <summary>
    /// Creates a query for <paramref name="expression"/>. The element type is taken from the
    /// expression's own type (operators such as <c>Select</c> change it), falling back to
    /// <typeparamref name="TEntity"/> only if no <c>IEnumerable&lt;T&gt;</c> can be found.
    /// </summary>
    public IQueryable CreateQuery(Expression expression)
    {
        var elementType = FindElementType(expression.Type) ?? typeof(TEntity);
        var queryType = typeof(TestAsyncEnumerable<>).MakeGenericType(elementType);
        return (IQueryable)Activator.CreateInstance(queryType, expression)!;
    }

    /// <summary>Creates a typed query that stays async-capable for further composition.</summary>
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression) => new TestAsyncEnumerable<TElement>(expression);

    /// <summary>Evaluates <paramref name="expression"/> synchronously via the inner provider.</summary>
    public object Execute(Expression expression) => _inner.Execute(expression);

    /// <summary>Evaluates <paramref name="expression"/> synchronously via the inner provider.</summary>
    public TResult Execute<TResult>(Expression expression) => _inner.Execute<TResult>(expression);

    /// <summary>
    /// Wraps <paramref name="expression"/> in an async-enumerable query.
    /// Not an <see cref="IAsyncQueryProvider"/> member in EF Core 3.0+; retained for direct callers.
    /// </summary>
    public IAsyncEnumerable<TResult> ExecuteAsync<TResult>(Expression expression) => new TestAsyncEnumerable<TResult>(expression);

    /// <summary>
    /// Task-returning convenience overload; EF Core itself goes through the explicit interface
    /// implementation instead. Results, failures and cancellation are all reported through the
    /// returned task.
    /// </summary>
    public Task<TResult> ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken) =>
        ExecuteAsTask<TResult>(expression, cancellationToken);

    /// <summary>
    /// Entry point for EF Core async terminal operators, which request a <c>Task&lt;T&gt;</c>.
    /// <typeparamref name="TResult"/> is unwrapped to <c>T</c>, the query runs synchronously on the
    /// inner provider, and the outcome comes back as a completed, faulted or canceled task.
    /// </summary>
    /// <exception cref="NotSupportedException"><typeparamref name="TResult"/> is not <c>Task&lt;T&gt;</c>.</exception>
    TResult IAsyncQueryProvider.ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken)
    {
        var resultType = typeof(TResult);
        if (!resultType.IsGenericType || resultType.GetGenericTypeDefinition() != typeof(Task<>))
        {
            throw new NotSupportedException(
                $"TestAsyncQueryProvider only supports Task<T> results, but '{resultType}' was requested.");
        }

        // T is only known at runtime, so close the generic helper over it via reflection.
        // ExecuteAsTask never throws (it returns a faulted task instead), so Invoke cannot wrap a
        // query failure in a TargetInvocationException and callers see the original exception type.
        var elementType = resultType.GetGenericArguments()[0];
        var executeAsTask = typeof(TestAsyncQueryProvider<TEntity>)
            .GetMethod(nameof(ExecuteAsTask), System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic)!
            .MakeGenericMethod(elementType);

        return (TResult)executeAsTask.Invoke(this, new object[] { expression, cancellationToken })!;
    }

    /// <summary>
    /// Runs <paramref name="expression"/> on the inner provider and captures the result, the
    /// exception, or the cancellation in a task, so failures surface when the task is awaited.
    /// </summary>
    private Task<T> ExecuteAsTask<T>(Expression expression, CancellationToken cancellationToken)
    {
        if (cancellationToken.IsCancellationRequested)
        {
            return Task.FromCanceled<T>(cancellationToken);
        }

        try
        {
            return Task.FromResult(_inner.Execute<T>(expression));
        }
        catch (Exception ex)
        {
            // Deliberately broad: any query failure (e.g. "Sequence contains no elements") is
            // reported through the task rather than thrown synchronously.
            return Task.FromException<T>(ex);
        }
    }

    // Returns T when `type` is IEnumerable<T> or implements it; otherwise null.
    private static Type? FindElementType(Type type)
    {
        if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IEnumerable<>))
        {
            return type.GetGenericArguments()[0];
        }

        foreach (var candidate in type.GetInterfaces())
        {
            if (candidate.IsGenericType && candidate.GetGenericTypeDefinition() == typeof(IEnumerable<>))
            {
                return candidate.GetGenericArguments()[0];
            }
        }

        return null;
    }
}

/// <summary>
/// In-memory queryable that is also an <see cref="IAsyncEnumerable{T}"/>, so EF Core async
/// operators can enumerate it. Its <see cref="IQueryable.Provider"/> is a
/// <see cref="TestAsyncQueryProvider{TEntity}"/> wrapping this query, which keeps composed
/// queries (Where, Select, ...) async-capable as well.
/// </summary>
internal class TestAsyncEnumerable<T> : EnumerableQuery<T>, IAsyncEnumerable<T>, IQueryable<T>
{
    /// <param name="enumerable">The in-memory sequence to expose as a query.</param>
    public TestAsyncEnumerable(IEnumerable<T> enumerable)
        : base(enumerable)
    {
    }

    /// <param name="expression">An expression tree describing the query.</param>
    public TestAsyncEnumerable(Expression expression)
        : base(expression)
    {
    }

    // A new async-aware provider over this query is returned on every access.
    IQueryProvider IQueryable.Provider
    {
        get { return new TestAsyncQueryProvider<T>(this); }
    }

    /// <summary>
    /// Returns an async enumerator over the synchronously evaluated query results.
    /// Not part of the .NET <see cref="IAsyncEnumerable{T}"/> interface; kept for callers that use it
    /// directly (it appears to mirror an older async-enumerable shape — confirm before removing).
    /// </summary>
    public IAsyncEnumerator<T> GetEnumerator()
    {
        IEnumerable<T> results = this;
        return new TestAsyncEnumerator<T>(results.GetEnumerator());
    }

    /// <summary>
    /// <see cref="IAsyncEnumerable{T}"/> entry point. <paramref name="cancellationToken"/> is not observed.
    /// </summary>
    public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
    {
        return GetEnumerator();
    }
}

/// <summary>
/// Adapts a synchronous <see cref="IEnumerator{T}"/> to <see cref="IAsyncEnumerator{T}"/>.
/// Every operation completes synchronously on the wrapped enumerator.
/// </summary>
internal class TestAsyncEnumerator<T> : IAsyncEnumerator<T>
{
    private readonly IEnumerator<T> _inner;

    /// <param name="inner">The synchronous enumerator to adapt; it is disposed with this instance.</param>
    public TestAsyncEnumerator(IEnumerator<T> inner)
    {
        _inner = inner;
    }

    /// <summary>The element at the wrapped enumerator's current position.</summary>
    public T Current
    {
        get { return _inner.Current; }
    }

    /// <summary>Advances the wrapped enumerator; the returned value task is already completed.</summary>
    public ValueTask<bool> MoveNextAsync()
    {
        bool advanced = _inner.MoveNext();
        return new ValueTask<bool>(advanced);
    }

    /// <summary>
    /// Task-based advance kept for callers using the older shape; <paramref name="cancellationToken"/> is not observed.
    /// </summary>
    public Task<bool> MoveNext(CancellationToken cancellationToken)
    {
        bool advanced = _inner.MoveNext();
        return Task.FromResult(advanced);
    }

    /// <summary>Disposes the wrapped enumerator and returns an already-completed value task.</summary>
    public ValueTask DisposeAsync()
    {
        Dispose();
        return default;
    }

    /// <summary>Disposes the wrapped enumerator.</summary>
    public void Dispose()
    {
        _inner.Dispose();
    }
}

To use these classes in your mock setup, configure the mocked `DbSet` like so:

Note: the `Notification` class is just an example; substitute your own entity type.

/// <summary>
/// Builds a mocked <c>DbSet&lt;Notification&gt;</c> backed by an in-memory list so that both
/// synchronous LINQ and EF Core async operators can run against it, then wires it into the mocked context.
/// </summary>
private void Setup()
{
	...

	// Data for mock
	_notifications = new List<Notification>
	{
		... data for mock ...
	}.AsQueryable();

	// Create mock queryable dbset
	var mockSet = new Mock<DbSet<Notification>>();

	// Async setup.
	// Return a factory lambda so every enumeration gets a fresh enumerator: a single shared instance
	// is exhausted after the first query, and every later query would see no rows.
	// It.IsAny matches calls that pass a real CancellationToken, not only `default`.
	mockSet.As<IAsyncEnumerable<Notification>>()
	    .Setup(m => m.GetAsyncEnumerator(It.IsAny<CancellationToken>()))
	    .Returns(() => new TestAsyncEnumerator<Notification>(_notifications.GetEnumerator()));

	// The async-aware provider lets async terminal operators (FirstOrDefaultAsync, CountAsync, ...) execute.
	mockSet.As<IQueryable<Notification>>()
	    .Setup(m => m.Provider)
	    .Returns(new TestAsyncQueryProvider<Notification>(_notifications.Provider));

	// Sync-only alternative: if no async queries are used, the plain in-memory provider is enough.
	//mockSet.As<IQueryable<Notification>>().Setup(m => m.Provider).Returns(_notifications.Provider);

	mockSet.As<IQueryable<Notification>>().Setup(m => m.Expression).Returns(_notifications.Expression);
	mockSet.As<IQueryable<Notification>>().Setup(m => m.ElementType).Returns(_notifications.ElementType);
	// Factory lambda again, so repeated synchronous enumeration also starts from the beginning.
	mockSet.As<IQueryable<Notification>>().Setup(m => m.GetEnumerator()).Returns(() => _notifications.GetEnumerator());

	// Setup mocks
	_mockContext.Setup(c => c.Notifications).Returns(mockSet.Object);

	...
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment