Last active
July 23, 2019 20:39
-
-
Save daveaglick/158eccf94a3b98d9e705739a89127d40 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections; | |
using System.Collections.Generic; | |
using System.Collections.Immutable; | |
using System.Linq; | |
using System.Runtime.CompilerServices; | |
using System.Threading; | |
using System.Threading.Tasks; | |
namespace DocumentQuery | |
{ | |
class Program
{
    /// <summary>
    /// Demonstrates chaining synchronous <c>Where</c> and asynchronous <c>WhereAsync</c>
    /// filters over an <see cref="ExecutionContext"/>, first sequentially and then in parallel.
    /// Every variant keeps the documents whose content starts with "A" and ends with "Z".
    /// </summary>
    static async Task Main(string[] args)
    {
        // The token source owns disposable resources; release them when the demo ends.
        using (CancellationTokenSource cts = new CancellationTokenSource())
        {
            ExecutionContext context = new ExecutionContext
            {
                CancellationToken = cts.Token,
                // Start sequential so the first three queries exercise the non-parallel
                // code paths; parallelism is switched on further down.
                AsParallel = false,
                InputDocuments = ImmutableArray.Create<IDocument>(
                    new Document { Content = "ABC" },
                    new Document { Content = "AYZ" },
                    new Document { Content = "ZYZ" })
            };

            // Non-parallel non-async chained where
            List<IDocument> result =
                context
                    .Where(x => x.Content.StartsWith("A", StringComparison.Ordinal))
                    .Where(x => x.Content.EndsWith("Z", StringComparison.Ordinal))
                    .ToList();

            // Non-parallel async then sync chained where
            List<IDocument> result2 =
                (await context
                    .WhereAsync(x => Task.FromResult(x.Content.StartsWith("A", StringComparison.Ordinal))))
                    .Where(x => x.Content.EndsWith("Z", StringComparison.Ordinal))
                    .ToList();

            // Non-parallel async chained where
            List<IDocument> result3 =
                (await context
                    .WhereAsync(x => Task.FromResult(x.Content.StartsWith("A", StringComparison.Ordinal)))
                    .WhereAsync(x => Task.FromResult(x.Content.EndsWith("Z", StringComparison.Ordinal))))
                    .ToList();

            context.AsParallel = true;

            // Parallel non-async chained where
            List<IDocument> result4 =
                context
                    .Where(x => x.Content.StartsWith("A", StringComparison.Ordinal))
                    .Where(x => x.Content.EndsWith("Z", StringComparison.Ordinal))
                    .ToList();

            // Parallel async then sync chained where
            List<IDocument> result5 =
                (await context
                    .WhereAsync(x => Task.FromResult(x.Content.StartsWith("A", StringComparison.Ordinal))))
                    .Where(x => x.Content.EndsWith("Z", StringComparison.Ordinal))
                    .ToList();

            // Parallel async chained where
            List<IDocument> result6 =
                (await context
                    .WhereAsync(x => Task.FromResult(x.Content.StartsWith("A", StringComparison.Ordinal)))
                    .WhereAsync(x => Task.FromResult(x.Content.EndsWith("Z", StringComparison.Ordinal))))
                    .ToList();
        }
    }
}
/// <summary>
/// Filtering operators over <see cref="IDocumentQuery"/>. Each operator checks the context's
/// cancellation token per document and writes predicate exceptions to the console before rethrowing.
/// </summary>
public static class IDocumentQueryExtensions
{
    /// <summary>
    /// Lazily filters <paramref name="query"/> with a synchronous predicate. When the context
    /// requests parallelism (or the source is already a PLINQ query) the filter runs through PLINQ,
    /// preserving input order.
    /// </summary>
    /// <param name="query">The query to filter.</param>
    /// <param name="predicate">Returns true for documents to keep.</param>
    /// <returns>A deferred query; the predicate runs when the result is enumerated.</returns>
    public static IDocumentQuery Where(this IDocumentQuery query, Func<IDocument, bool> predicate)
    {
        return query.GetQuery(source =>
        {
            Func<IDocument, bool> wrapped = CancelAndTrace(predicate, query.Context);
            ParallelQuery<IDocument> parallel = GetParallelQuery(source, query);
            return parallel != null ? parallel.Where(wrapped) : source.Where(wrapped);
        });
    }

    /// <summary>
    /// Starts filtering <paramref name="query"/> with an asynchronous predicate; await the
    /// result (or chain another <c>WhereAsync</c>) to obtain the filtered query.
    /// </summary>
    public static AwaitableDocumentQuery WhereAsync(this IDocumentQuery query, Func<IDocument, Task<bool>> predicate) =>
        new AwaitableDocumentQuery(query, WhereAsync(query.Context, predicate));

    /// <summary>
    /// Chains an asynchronous filter after a pending asynchronous query without awaiting in between.
    /// </summary>
    public static AwaitableDocumentQuery WhereAsync(this AwaitableDocumentQuery awaitableQuery, Func<IDocument, Task<bool>> predicate) =>
        new AwaitableDocumentQuery(awaitableQuery, WhereAsync(awaitableQuery.Context, predicate));

    /// <summary>
    /// Builds the asynchronous filter step. In parallel mode every document is evaluated on the
    /// thread pool concurrently; otherwise documents are evaluated one after another.
    /// Output order matches input order in both modes.
    /// </summary>
    private static Func<IEnumerable<IDocument>, Task<IEnumerable<IDocument>>> WhereAsync(IExecutionContext context, Func<IDocument, Task<bool>> predicate) => async source =>
    {
        // Wrap the predicate once instead of once per document.
        Func<IDocument, Task<bool>> wrapped = CancelAndTraceAsync(predicate, context);

        if (context.AsParallel)
        {
            // Task.WhenAll returns results in the order the tasks were supplied, so ordering is preserved.
            (IDocument Document, bool Result)[] evaluated = await Task.WhenAll(
                source.Select(x => Task.Run(
                    async () => (Document: x, Result: await wrapped(x)),
                    context.CancellationToken)));
            return evaluated
                .Where(x => x.Result)
                .Select(x => x.Document);
        }

        List<IDocument> results = new List<IDocument>();
        foreach (IDocument item in source)
        {
            if (await wrapped(item))
            {
                results.Add(item);
            }
        }
        return results;
    };

    /// <summary>
    /// Wraps a synchronous per-document function with a cancellation check and exception tracing.
    /// </summary>
    private static Func<IDocument, TResult> CancelAndTrace<TResult>(Func<IDocument, TResult> func, IExecutionContext context) =>
        x =>
        {
            context.CancellationToken.ThrowIfCancellationRequested();
            try
            {
                return func(x);
            }
            catch (Exception ex)
            {
                // Placeholder for context-level tracing; currently written to the console.
                Console.WriteLine(ex);
                throw;
            }
        };

    /// <summary>
    /// Wraps an asynchronous per-document function with a cancellation check and exception tracing.
    /// </summary>
    private static Func<IDocument, Task<TResult>> CancelAndTraceAsync<TResult>(Func<IDocument, Task<TResult>> func, IExecutionContext context) =>
        async x =>
        {
            context.CancellationToken.ThrowIfCancellationRequested();
            try
            {
                // Await inside the try so faults carried by the returned task are traced as well,
                // not only exceptions thrown before the task is produced.
                return await func(x);
            }
            catch (Exception ex)
            {
                // Placeholder for context-level tracing; currently written to the console.
                Console.WriteLine(ex);
                throw;
            }
        };

    /// <summary>
    /// Returns <paramref name="source"/> as a PLINQ query if it already is one; otherwise, when the
    /// context requests parallelism, an ordered, cancellable PLINQ query over it; otherwise null.
    /// </summary>
    private static ParallelQuery<IDocument> GetParallelQuery(IEnumerable<IDocument> source, IDocumentQuery query) =>
        (source as ParallelQuery<IDocument>) ?? (query.Context.AsParallel ? source.AsParallel().AsOrdered().WithCancellation(query.Context.CancellationToken) : null);
}
/// <summary>
/// A document query whose filtering step is asynchronous. The step starts running as soon as the
/// instance is constructed; awaiting the instance yields an <see cref="IDocumentQuery"/> over the
/// step's results.
/// </summary>
public class AwaitableDocumentQuery
{
    private readonly Task<IDocumentQuery> _completion;

    internal AwaitableDocumentQuery(IDocumentQuery previousQuery, Func<IEnumerable<IDocument>, Task<IEnumerable<IDocument>>> func)
        : this(GetCompletionTaskAsync(previousQuery, func), previousQuery.Context)
    {
    }

    internal AwaitableDocumentQuery(AwaitableDocumentQuery awaitableQuery, Func<IEnumerable<IDocument>, Task<IEnumerable<IDocument>>> func)
        : this(GetCompletionTaskAsync(awaitableQuery, func), awaitableQuery.Context)
    {
    }

    // Shared initialisation for both public-facing constructors.
    private AwaitableDocumentQuery(Task<IDocumentQuery> completion, IExecutionContext context)
    {
        _completion = completion;
        Context = context;
    }

    /// <summary>The execution context inherited from the upstream query.</summary>
    public IExecutionContext Context { get; }

    /// <summary>Enables <c>await</c> on this query.</summary>
    public DocumentQueryAwaiter GetAwaiter() => new DocumentQueryAwaiter(_completion);

    // Runs the step directly against a synchronous upstream query.
    private static async Task<IDocumentQuery> GetCompletionTaskAsync(IDocumentQuery previousQuery, Func<IEnumerable<IDocument>, Task<IEnumerable<IDocument>>> func)
    {
        IExecutionContext context = previousQuery.Context;
        IEnumerable<IDocument> filtered = await func(previousQuery);
        return new CompletedDocumentQuery(context, filtered);
    }

    // Waits for the upstream asynchronous query, then runs the step on its output.
    private static async Task<IDocumentQuery> GetCompletionTaskAsync(AwaitableDocumentQuery awaitableQuery, Func<IEnumerable<IDocument>, Task<IEnumerable<IDocument>>> func)
    {
        IExecutionContext context = awaitableQuery.Context;
        IDocumentQuery upstream = await awaitableQuery;
        IEnumerable<IDocument> filtered = await func(upstream);
        return new CompletedDocumentQuery(context, filtered);
    }

    /// <summary>Query over the result sequence produced by an asynchronous step.</summary>
    private class CompletedDocumentQuery : IDocumentQuery
    {
        private readonly IEnumerable<IDocument> _documents;

        public CompletedDocumentQuery(IExecutionContext context, IEnumerable<IDocument> results)
        {
            Context = context;
            _documents = results;
        }

        public IExecutionContext Context { get; }

        public IDocumentQuery GetQuery(Func<IEnumerable<IDocument>, IEnumerable<IDocument>> func) =>
            new DocumentQuery(this, func);

        public IEnumerator<IDocument> GetEnumerator() => _documents.GetEnumerator();

        IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
    }
}
/// <summary>
/// Awaiter for <see cref="AwaitableDocumentQuery"/>; forwards to the awaiter of the underlying task.
/// </summary>
public class DocumentQueryAwaiter : INotifyCompletion
{
    private readonly Task<IDocumentQuery> _task;

    internal DocumentQueryAwaiter(Task<IDocumentQuery> task)
    {
        _task = task;
    }

    /// <summary>True once the underlying task has finished (successfully or not).</summary>
    public bool IsCompleted => _task.IsCompleted;

    /// <summary>
    /// Schedules <paramref name="continuation"/> to run when the underlying task completes,
    /// rather than immediately on a pool thread that would then block waiting for the result.
    /// </summary>
    public void OnCompleted(Action continuation) => _task.GetAwaiter().OnCompleted(continuation);

    /// <summary>
    /// Returns the completed query. Failures are rethrown as the original exception,
    /// not wrapped in an <see cref="AggregateException"/> as <c>Task.Result</c> would.
    /// </summary>
    public IDocumentQuery GetResult() => _task.GetAwaiter().GetResult();
}
/// <summary>
/// A composable sequence of documents bound to an execution context.
/// </summary>
public interface IDocumentQuery : IEnumerable<IDocument>
{
    /// <summary>The execution context this query belongs to.</summary>
    IExecutionContext Context { get; }

    /// <summary>
    /// Returns a new query derived from this one by applying <paramref name="func"/> to its documents.
    /// </summary>
    IDocumentQuery GetQuery(Func<IEnumerable<IDocument>, IEnumerable<IDocument>> func);
}
/// <summary>
/// Deferred query node: every enumeration re-applies the stored transform to the upstream query.
/// </summary>
internal class DocumentQuery : IDocumentQuery
{
    private readonly Func<IEnumerable<IDocument>, IEnumerable<IDocument>> _transform;
    private readonly IDocumentQuery _upstream;

    internal DocumentQuery(IDocumentQuery query, Func<IEnumerable<IDocument>, IEnumerable<IDocument>> func)
    {
        _upstream = query;
        _transform = func;
    }

    /// <summary>The context of the upstream query.</summary>
    public IExecutionContext Context => _upstream.Context;

    public IDocumentQuery GetQuery(Func<IEnumerable<IDocument>, IEnumerable<IDocument>> func) =>
        new DocumentQuery(this, func);

    public IEnumerator<IDocument> GetEnumerator()
    {
        IEnumerable<IDocument> transformed = _transform(_upstream);
        return transformed.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
/// <summary>
/// Root of a document query: an indexable list of input documents plus the settings
/// that query operators consult while evaluating.
/// </summary>
public interface IExecutionContext : IReadOnlyList<IDocument>, IDocumentQuery
{
    /// <summary>Token that query operators check before evaluating each document.</summary>
    CancellationToken CancellationToken { get; }

    /// <summary>When true, query operators evaluate documents in parallel.</summary>
    bool AsParallel { get; }
}
/// <summary>
/// Mutable execution context over a fixed set of input documents.
/// NOTE: shares its simple name with <c>System.Threading.ExecutionContext</c>; inside this
/// namespace this type takes precedence.
/// </summary>
public class ExecutionContext : IExecutionContext
{
    private ImmutableArray<IDocument> _inputDocuments = ImmutableArray<IDocument>.Empty;

    /// <summary>When true, query operators evaluate documents in parallel.</summary>
    public bool AsParallel { get; set; }

    /// <summary>
    /// The documents exposed by this context. A default (uninitialized) array is normalized to
    /// empty, so <see cref="Count"/> and enumeration work on an instance with no documents set.
    /// </summary>
    public ImmutableArray<IDocument> InputDocuments
    {
        get => _inputDocuments;
        set => _inputDocuments = value.IsDefault ? ImmutableArray<IDocument>.Empty : value;
    }

    /// <summary>Token checked by query operators before evaluating each document.</summary>
    public CancellationToken CancellationToken { get; set; }

    // IDocumentQuery
    IDocumentQuery IDocumentQuery.GetQuery(Func<IEnumerable<IDocument>, IEnumerable<IDocument>> func)
    {
        return new DocumentQuery(this, func);
    }

    IExecutionContext IDocumentQuery.Context => this;

    // IReadOnlyList<IDocument>
    public IDocument this[int index] => _inputDocuments[index];

    public int Count => _inputDocuments.Length;

    public IEnumerator<IDocument> GetEnumerator() => ((IEnumerable<IDocument>)_inputDocuments).GetEnumerator();

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
/// <summary>
/// A document processed by document queries.
/// </summary>
public interface IDocument
{
    /// <summary>The document's text content; the query predicates in this file filter on it.</summary>
    string Content { get; set; }
}
/// <summary>
/// Simple in-memory <see cref="IDocument"/> whose content is stored in a field.
/// </summary>
public class Document : IDocument
{
    private string _content;

    /// <summary>The document's text content.</summary>
    public string Content
    {
        get => _content;
        set => _content = value;
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment