Skip to content

Instantly share code, notes, and snippets.

@daveaglick
Last active July 23, 2019 20:39
Show Gist options
  • Save daveaglick/158eccf94a3b98d9e705739a89127d40 to your computer and use it in GitHub Desktop.
Save daveaglick/158eccf94a3b98d9e705739a89127d40 to your computer and use it in GitHub Desktop.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
namespace DocumentQuery
{
/// <summary>
/// Demo entry point that runs the same two-stage filter (starts with "A", ends with "Z")
/// through every combination of sequential/parallel execution and sync/async predicates.
/// </summary>
class Program
{
    /// <summary>
    /// Builds an execution context over three sample documents and evaluates the six
    /// query variations. The results are materialized but otherwise unused.
    /// </summary>
    /// <param name="args">Command-line arguments (ignored).</param>
    static async Task Main(string[] args)
    {
        // The token source owns a wait handle, so dispose it when the demo finishes.
        using (CancellationTokenSource cts = new CancellationTokenSource())
        {
            ExecutionContext context = new ExecutionContext
            {
                CancellationToken = cts.Token,
                // Start in sequential mode so the first three scenarios really are
                // "non-parallel"; parallel mode is switched on further down.
                AsParallel = false,
                InputDocuments = ImmutableArray.Create<IDocument>(
                    new Document { Content = "ABC" },
                    new Document { Content = "AYZ" },
                    new Document { Content = "ZYZ" })
            };

            // Non-parallel non-async chained where
            List<IDocument> result =
                context
                    .Where(x => x.Content.StartsWith("A", StringComparison.Ordinal))
                    .Where(x => x.Content.EndsWith("Z", StringComparison.Ordinal))
                    .ToList();

            // Non-parallel async then sync chained where
            List<IDocument> result2 =
                (await context
                    .WhereAsync(x => Task.FromResult(x.Content.StartsWith("A", StringComparison.Ordinal))))
                    .Where(x => x.Content.EndsWith("Z", StringComparison.Ordinal))
                    .ToList();

            // Non-parallel async chained where
            List<IDocument> result3 =
                (await context
                    .WhereAsync(x => Task.FromResult(x.Content.StartsWith("A", StringComparison.Ordinal)))
                    .WhereAsync(x => Task.FromResult(x.Content.EndsWith("Z", StringComparison.Ordinal))))
                    .ToList();

            // Everything below runs with the parallel code paths enabled.
            context.AsParallel = true;

            // Parallel non-async chained where
            List<IDocument> result4 =
                context
                    .Where(x => x.Content.StartsWith("A", StringComparison.Ordinal))
                    .Where(x => x.Content.EndsWith("Z", StringComparison.Ordinal))
                    .ToList();

            // Parallel async then sync chained where
            List<IDocument> result5 =
                (await context
                    .WhereAsync(x => Task.FromResult(x.Content.StartsWith("A", StringComparison.Ordinal))))
                    .Where(x => x.Content.EndsWith("Z", StringComparison.Ordinal))
                    .ToList();

            // Parallel async chained where
            List<IDocument> result6 =
                (await context
                    .WhereAsync(x => Task.FromResult(x.Content.StartsWith("A", StringComparison.Ordinal)))
                    .WhereAsync(x => Task.FromResult(x.Content.EndsWith("Z", StringComparison.Ordinal))))
                    .ToList();
        }
    }
}
/// <summary>
/// Filtering operators for <see cref="IDocumentQuery"/> and <see cref="AwaitableDocumentQuery"/>.
/// Every predicate invocation first checks the context's cancellation token and logs
/// (then rethrows) any exception the predicate produces.
/// </summary>
public static class IDocumentQueryExtensions
{
    /// <summary>
    /// Appends a synchronous filter to <paramref name="query"/>. The filter runs through PLINQ
    /// when the source already is a <see cref="ParallelQuery{TSource}"/> or the context has
    /// <see cref="IExecutionContext.AsParallel"/> set; otherwise it runs through LINQ to Objects.
    /// </summary>
    /// <param name="query">The query to filter.</param>
    /// <param name="predicate">Returns <c>true</c> for documents to keep.</param>
    /// <returns>A new query produced by <see cref="IDocumentQuery.GetQuery"/>.</returns>
    /// <exception cref="ArgumentNullException">An argument is <c>null</c>.</exception>
    public static IDocumentQuery Where(this IDocumentQuery query, Func<IDocument, bool> predicate)
    {
        if (query == null)
        {
            throw new ArgumentNullException(nameof(query));
        }
        if (predicate == null)
        {
            throw new ArgumentNullException(nameof(predicate));
        }

        return query.GetQuery(source =>
        {
            // Build the guarded predicate once per evaluation and share it between both paths.
            Func<IDocument, bool> guarded = CancelAndTrace(predicate, query.Context);
            ParallelQuery<IDocument> parallel = GetParallelQuery(source, query);
            return parallel != null ? parallel.Where(guarded) : source.Where(guarded);
        });
    }

    /// <summary>
    /// Starts an asynchronous filter over <paramref name="query"/>.
    /// </summary>
    /// <param name="query">The query to filter.</param>
    /// <param name="predicate">Asynchronously returns <c>true</c> for documents to keep.</param>
    /// <returns>An awaitable that yields the filtered query.</returns>
    /// <exception cref="ArgumentNullException">An argument is <c>null</c>.</exception>
    public static AwaitableDocumentQuery WhereAsync(this IDocumentQuery query, Func<IDocument, Task<bool>> predicate)
    {
        if (query == null)
        {
            throw new ArgumentNullException(nameof(query));
        }
        if (predicate == null)
        {
            throw new ArgumentNullException(nameof(predicate));
        }

        return new AwaitableDocumentQuery(query, WhereAsync(query.Context, predicate));
    }

    /// <summary>
    /// Chains another asynchronous filter after a pending <see cref="AwaitableDocumentQuery"/>.
    /// </summary>
    /// <param name="awaitableQuery">The pending query whose result should be filtered.</param>
    /// <param name="predicate">Asynchronously returns <c>true</c> for documents to keep.</param>
    /// <returns>An awaitable that yields the filtered query.</returns>
    /// <exception cref="ArgumentNullException">An argument is <c>null</c>.</exception>
    public static AwaitableDocumentQuery WhereAsync(this AwaitableDocumentQuery awaitableQuery, Func<IDocument, Task<bool>> predicate)
    {
        if (awaitableQuery == null)
        {
            throw new ArgumentNullException(nameof(awaitableQuery));
        }
        if (predicate == null)
        {
            throw new ArgumentNullException(nameof(predicate));
        }

        return new AwaitableDocumentQuery(awaitableQuery, WhereAsync(awaitableQuery.Context, predicate));
    }

    /// <summary>
    /// Builds the function that applies <paramref name="predicate"/> to a document sequence.
    /// In parallel mode each document is evaluated in its own <see cref="Task.Run(Func{Task})"/>
    /// and the matches are returned in source order; otherwise documents are evaluated one by one.
    /// </summary>
    private static Func<IEnumerable<IDocument>, Task<IEnumerable<IDocument>>> WhereAsync(IExecutionContext context, Func<IDocument, Task<bool>> predicate)
    {
        // The wrapper does not depend on the item, so create it once instead of per document.
        Func<IDocument, Task<bool>> guarded = CancelAndTraceAsync(predicate, context);

        return async source =>
        {
            if (context.AsParallel)
            {
                // Task.WhenAll returns results in the order of the supplied tasks,
                // so the source ordering is preserved.
                var evaluated = await Task.WhenAll(
                    source.Select(x => Task.Run(async () => (Document: x, Result: await guarded(x)))));
                return evaluated
                    .Where(x => x.Result)
                    .Select(x => x.Document);
            }

            List<IDocument> results = new List<IDocument>();
            foreach (IDocument item in source)
            {
                if (await guarded(item))
                {
                    results.Add(item);
                }
            }
            return results;
        };
    }

    /// <summary>
    /// Wraps a synchronous delegate so each call checks for cancellation first and
    /// writes any thrown exception to the console before rethrowing it.
    /// </summary>
    private static Func<IDocument, TResult> CancelAndTrace<TResult>(Func<IDocument, TResult> func, IExecutionContext context) =>
        x =>
        {
            context.CancellationToken.ThrowIfCancellationRequested();
            try
            {
                return func(x);
            }
            catch (Exception ex)
            {
                // query.Context.Trace(...)
                Console.WriteLine(ex);
                throw;
            }
        };

    /// <summary>
    /// Asynchronous counterpart of <see cref="CancelAndTrace{TResult}"/>. The task returned by
    /// <paramref name="func"/> is awaited inside the <c>try</c> block so that failures surfacing
    /// through a faulted task are logged too, not only exceptions thrown before the task is returned.
    /// </summary>
    private static Func<IDocument, Task<TResult>> CancelAndTraceAsync<TResult>(Func<IDocument, Task<TResult>> func, IExecutionContext context) =>
        async x =>
        {
            context.CancellationToken.ThrowIfCancellationRequested();
            try
            {
                // No need to resume on the captured context just to hand the result back.
                return await func(x).ConfigureAwait(false);
            }
            catch (Exception ex)
            {
                // query.Context.Trace(...)
                Console.WriteLine(ex);
                throw;
            }
        };

    /// <summary>
    /// Returns <paramref name="source"/> as a parallel query: the source itself when it already is
    /// one, a new ordered, cancellable PLINQ query when the context requests parallel execution,
    /// or <c>null</c> when the query should run sequentially.
    /// </summary>
    private static ParallelQuery<IDocument> GetParallelQuery(IEnumerable<IDocument> source, IDocumentQuery query) =>
        (source as ParallelQuery<IDocument>)
            ?? (query.Context.AsParallel
                ? source.AsParallel().AsOrdered().WithCancellation(query.Context.CancellationToken)
                : null);
}
/// <summary>
/// A pending asynchronous query stage. Awaiting an instance yields an
/// <see cref="IDocumentQuery"/> over the documents produced by the stage's function.
/// </summary>
public class AwaitableDocumentQuery
{
    // Task that completes with the query wrapping this stage's output.
    private readonly Task<IDocumentQuery> _task;

    /// <summary>Starts a stage that applies <paramref name="func"/> to an existing query.</summary>
    internal AwaitableDocumentQuery(
        IDocumentQuery previousQuery,
        Func<IEnumerable<IDocument>, Task<IEnumerable<IDocument>>> func)
    {
        _task = GetCompletionTaskAsync(previousQuery, func);
        Context = previousQuery.Context;
    }

    /// <summary>Starts a stage that applies <paramref name="func"/> to the awaited result of an earlier stage.</summary>
    internal AwaitableDocumentQuery(
        AwaitableDocumentQuery awaitableQuery,
        Func<IEnumerable<IDocument>, Task<IEnumerable<IDocument>>> func)
    {
        _task = GetCompletionTaskAsync(awaitableQuery, func);
        Context = awaitableQuery.Context;
    }

    /// <summary>The execution context taken from the query this stage was built on.</summary>
    public IExecutionContext Context { get; }

    /// <summary>Returns the awaiter that lets the compiler <c>await</c> this object.</summary>
    public DocumentQueryAwaiter GetAwaiter()
    {
        return new DocumentQueryAwaiter(_task);
    }

    // Runs func over the previous query and wraps its output in a completed query.
    private static async Task<IDocumentQuery> GetCompletionTaskAsync(
        IDocumentQuery previousQuery,
        Func<IEnumerable<IDocument>, Task<IEnumerable<IDocument>>> func)
    {
        IExecutionContext context = previousQuery.Context;
        IEnumerable<IDocument> documents = await func(previousQuery);
        return new CompletedDocumentQuery(context, documents);
    }

    // Waits for the earlier stage, then runs func over its result and wraps the output.
    private static async Task<IDocumentQuery> GetCompletionTaskAsync(
        AwaitableDocumentQuery awaitableQuery,
        Func<IEnumerable<IDocument>, Task<IEnumerable<IDocument>>> func)
    {
        IExecutionContext context = awaitableQuery.Context;
        IDocumentQuery upstream = await awaitableQuery;
        IEnumerable<IDocument> documents = await func(upstream);
        return new CompletedDocumentQuery(context, documents);
    }

    /// <summary>
    /// Query over the output sequence of a finished asynchronous stage.
    /// </summary>
    private class CompletedDocumentQuery : IDocumentQuery
    {
        private readonly IEnumerable<IDocument> _results;

        public CompletedDocumentQuery(IExecutionContext context, IEnumerable<IDocument> results)
        {
            Context = context;
            _results = results;
        }

        public IExecutionContext Context { get; }

        public IDocumentQuery GetQuery(Func<IEnumerable<IDocument>, IEnumerable<IDocument>> func)
        {
            return new DocumentQuery(this, func);
        }

        public IEnumerator<IDocument> GetEnumerator()
        {
            return _results.GetEnumerator();
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }
    }
}
/// <summary>
/// Awaiter returned by <see cref="AwaitableDocumentQuery.GetAwaiter"/>; it forwards the
/// awaiter pattern to the underlying <see cref="Task{TResult}"/>.
/// </summary>
public class DocumentQueryAwaiter : INotifyCompletion
{
    private readonly Task<IDocumentQuery> _task;

    /// <param name="task">The task that produces the query result.</param>
    /// <exception cref="ArgumentNullException"><paramref name="task"/> is <c>null</c>.</exception>
    internal DocumentQueryAwaiter(Task<IDocumentQuery> task)
    {
        _task = task ?? throw new ArgumentNullException(nameof(task));
    }

    /// <summary>
    /// Schedules <paramref name="continuation"/> to run once the underlying task has completed.
    /// Registering on the task's own awaiter (rather than starting the continuation right away)
    /// means the continuation never has to block a thread waiting for the result.
    /// </summary>
    public void OnCompleted(Action continuation) => _task.GetAwaiter().OnCompleted(continuation);

    /// <summary>Whether the underlying task has finished.</summary>
    public bool IsCompleted => _task.IsCompleted;

    /// <summary>
    /// Returns the completed query. If the task failed, the original exception is rethrown
    /// (not wrapped in <see cref="AggregateException"/>), matching normal <c>await</c> behavior.
    /// </summary>
    public IDocumentQuery GetResult() => _task.GetAwaiter().GetResult();
}
/// <summary>
/// An enumerable sequence of documents that can be extended with further query stages
/// and that exposes the execution context it runs under.
/// </summary>
public interface IDocumentQuery : IEnumerable<IDocument>
{
/// <summary>
/// Creates a new query from this one and <paramref name="func"/>. The implementations in
/// this file return a query that applies <paramref name="func"/> to this query's documents
/// when it is enumerated.
/// </summary>
/// <param name="func">Transformation from the current document sequence to the next one.</param>
IDocumentQuery GetQuery(Func<IEnumerable<IDocument>, IEnumerable<IDocument>> func);
/// <summary>The execution context associated with this query.</summary>
IExecutionContext Context { get; }
}
/// <summary>
/// A query stage that applies a transformation to the documents of the query it was built from.
/// The transformation is invoked each time the stage is enumerated.
/// </summary>
internal class DocumentQuery : IDocumentQuery
{
    private readonly IDocumentQuery _previousQuery;
    private readonly Func<IEnumerable<IDocument>, IEnumerable<IDocument>> _func;

    /// <param name="query">The upstream query supplying the input documents.</param>
    /// <param name="func">The transformation applied to the upstream documents.</param>
    internal DocumentQuery(IDocumentQuery query, Func<IEnumerable<IDocument>, IEnumerable<IDocument>> func)
    {
        _previousQuery = query;
        _func = func;
    }

    /// <summary>The context of the upstream query.</summary>
    public IExecutionContext Context
    {
        get { return _previousQuery.Context; }
    }

    /// <summary>Creates a further stage on top of this one.</summary>
    public IDocumentQuery GetQuery(Func<IEnumerable<IDocument>, IEnumerable<IDocument>> func)
    {
        return new DocumentQuery(this, func);
    }

    /// <summary>Applies the transformation to the upstream query and enumerates the outcome.</summary>
    public IEnumerator<IDocument> GetEnumerator()
    {
        IEnumerable<IDocument> transformed = _func(_previousQuery);
        return transformed.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }
}
/// <summary>
/// Execution settings for document queries. The context is itself a read-only list of
/// documents and an <see cref="IDocumentQuery"/>, so queries can be started directly on it.
/// </summary>
public interface IExecutionContext : IReadOnlyList<IDocument>, IDocumentQuery
{
/// <summary>
/// When <c>true</c>, the query extension methods in this file take their parallel code paths.
/// </summary>
bool AsParallel { get; }
/// <summary>
/// Token that the query extension methods check before each predicate invocation.
/// </summary>
CancellationToken CancellationToken { get; }
}
/// <summary>
/// Default <see cref="IExecutionContext"/>: holds the execution settings and exposes
/// <see cref="InputDocuments"/> as the document list that queries start from.
/// </summary>
public class ExecutionContext : IExecutionContext
{
    // Never left as default(ImmutableArray<T>): a default instance throws on Length,
    // indexing and enumeration, which would break Count, this[int] and GetEnumerator().
    private ImmutableArray<IDocument> _inputDocuments = ImmutableArray<IDocument>.Empty;

    /// <summary>Whether queries on this context should use their parallel code paths.</summary>
    public bool AsParallel { get; set; }

    /// <summary>
    /// The documents exposed by this context. An uninitialized (default) array is
    /// stored as an empty array so the list members stay usable.
    /// </summary>
    public ImmutableArray<IDocument> InputDocuments
    {
        get => _inputDocuments;
        set => _inputDocuments = value.IsDefault ? ImmutableArray<IDocument>.Empty : value;
    }

    /// <summary>Token used to cancel queries running on this context.</summary>
    public CancellationToken CancellationToken { get; set; }

    // IDocumentQuery

    /// <summary>Creates the first query stage over this context's documents.</summary>
    IDocumentQuery IDocumentQuery.GetQuery(Func<IEnumerable<IDocument>, IEnumerable<IDocument>> func)
    {
        return new DocumentQuery(this, func);
    }

    /// <summary>The context of a context is the context itself.</summary>
    IExecutionContext IDocumentQuery.Context => this;

    // IReadOnlyList<IDocument>

    /// <summary>Gets the input document at <paramref name="index"/>.</summary>
    public IDocument this[int index] => _inputDocuments[index];

    /// <summary>Number of input documents.</summary>
    public int Count => _inputDocuments.Length;

    /// <summary>Enumerates the input documents in order.</summary>
    public IEnumerator<IDocument> GetEnumerator() => ((IEnumerable<IDocument>)_inputDocuments).GetEnumerator();

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
/// <summary>
/// A document that queries operate on.
/// </summary>
public interface IDocument
{
/// <summary>The document's content as a string; readable and writable.</summary>
string Content { get; set; }
}
/// <summary>
/// Plain <see cref="IDocument"/> implementation backed by an auto-property.
/// </summary>
public class Document : IDocument
{
/// <summary>The document's content as a string; readable and writable.</summary>
public string Content { get; set; }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment