Created
February 14, 2020 06:53
-
-
Save Horusiath/c1ce452cb2228d6f8affe13c32f8ae71 to your computer and use it in GitHub Desktop.
Demo presenting how to extend C# LINQ syntax over tasks and how to build our own async/await-capable type.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Diagnostics; | |
using System.IO; | |
using System.Runtime.InteropServices; | |
using System.Threading.Tasks; | |
namespace Demo | |
{ | |
/// <summary>
/// Unit type: a value that carries no information. It is used as the type argument of
/// Promise&lt;T&gt; / Task&lt;T&gt; for operations that only signal completion.
/// Every instance is equal to every other instance.
/// </summary>
public readonly struct Void : IEquatable<Void>
{
    /// <summary>Always true: <see cref="Void"/> has a single value.</summary>
    public bool Equals(Void other) => true;

    /// <summary>True exactly when <paramref name="obj"/> is a (boxed) <see cref="Void"/>.</summary>
    public override bool Equals(object? obj) => obj is Void;

    /// <summary>Constant hash, consistent with all instances being equal.</summary>
    public override int GetHashCode() => 0;

    /// <summary>Always true.</summary>
    public static bool operator ==(Void left, Void right) => true;

    /// <summary>Always false.</summary>
    public static bool operator !=(Void left, Void right) => false;
}
class Program
{
    /// <summary>
    /// Entry point: runs the Promise-based async example, then waits for a key press.
    /// </summary>
    static async Task Main(string[] args)
    {
        await PromiseAsyncExample();

        // Alternative: the same pipeline expressed as a LINQ query over Task<T>.
        //var line = await LinqSample();
        //Console.Write(line);

        Console.Read();
    }

    /// <summary>
    /// Reads "sample.txt", appends " fibers!", writes the result to "test2.txt" and echoes it to the console.
    /// Declared as returning Promise&lt;Void&gt;, so the compiler drives it through the async method builder
    /// attached to Promise&lt;T&gt;.
    /// </summary>
    static async Promise<Void> PromiseAsyncExample()
    {
        var text = await File.ReadAllTextAsync("sample.txt");
        var modified = text + " fibers!";
        await File.WriteAllTextAsync("test2.txt", modified);
        Console.Write(modified);
        return default;
    }

    /// <summary>
    /// Reads "sample.txt", appends " world" and writes the result to "test2.txt", written as a LINQ query
    /// over Task&lt;T&gt; (binds to the Select/SelectMany extension methods for tasks).
    /// </summary>
    /// <returns>A task producing the modified text; it faults if either the read or the write fails.</returns>
    static Task<string> LinqSample() =>
        from line in File.ReadAllTextAsync("sample.txt")
        let modified = line + " world"
        from _ in AsVoid(File.WriteAllTextAsync("test2.txt", modified))
        select modified;

    /// <summary>
    /// Adapts a non-generic <see cref="Task"/> to Task&lt;Void&gt; so it can appear in a query's <c>from</c> clause.
    /// </summary>
    /// <remarks>
    /// Awaiting (instead of a bare ContinueWith that ignores the antecedent) rethrows the original exception,
    /// so a failed or cancelled write faults/cancels the returned task instead of being reported as success.
    /// </remarks>
    private static async Task<Void> AsVoid(Task task)
    {
        await task;
        return default;
    }
}
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Runtime.CompilerServices; | |
using System.Runtime.ExceptionServices; | |
using System.Threading; | |
namespace Demo | |
{ | |
/// <summary>
/// Async method builder for Promise&lt;T&gt;. The compiler uses it to run methods declared as
/// <c>async Promise&lt;T&gt;</c> and to complete the promise they return.
/// </summary>
/// <typeparam name="T">Result type of the async method.</typeparam>
public struct PromiseAsyncMethodBuilder<T>
{
    // The promise handed back to the caller. When the compiler emits the state machine as a struct
    // (release builds), creating a delegate from stateMachine.MoveNext boxes a COPY of the state
    // machine, and therefore of this builder. The promise must already exist when Create returns, so
    // that every copy shares the same instance. Otherwise each copy would lazily allocate its own, and
    // the caller would await a promise that nothing ever completes.
    private Promise<T>? promise;

    #region mandatory methods for async state machine builder

    /// <summary>Creates a builder together with the promise it will complete.</summary>
    public static PromiseAsyncMethodBuilder<T> Create() =>
        new PromiseAsyncMethodBuilder<T> { promise = new Promise<T>() };

    /// <summary>The promise returned to the caller of the async method.</summary>
    /// <remarks>Falls back to lazy allocation for a default-constructed builder.</remarks>
    public Promise<T> Task => promise ??= new Promise<T>();

    /// <summary>Fails the promise with the exception that escaped the async method body.</summary>
    public void SetException(Exception e) => Task.TrySetException(e);

    /// <summary>Completes the promise with the async method's return value.</summary>
    public void SetResult(T result) => Task.TrySetResult(result);

    /// <summary>Schedules the state machine's next step once <paramref name="awaiter"/> completes.</summary>
    public void AwaitOnCompleted<TAwaiter, TStateMachine>(ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : INotifyCompletion
        where TStateMachine : IAsyncStateMachine
    {
        awaiter.OnCompleted(stateMachine.MoveNext);
    }

    /// <summary>Same as <see cref="AwaitOnCompleted{TAwaiter, TStateMachine}"/> for critical awaiters (e.g. Task awaiters).</summary>
    public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : ICriticalNotifyCompletion
        where TStateMachine : IAsyncStateMachine
    {
        awaiter.UnsafeOnCompleted(stateMachine.MoveNext);
    }

    /// <summary>
    /// Starts the async method by queueing its first step to the thread pool, rather than running it
    /// synchronously on the caller's thread.
    /// </summary>
    public void Start<TStateMachine>(ref TStateMachine stateMachine) where TStateMachine : IAsyncStateMachine
    {
        Action move = stateMachine.MoveNext;
        ThreadPool.QueueUserWorkItem(_ => move());
    }

    /// <summary>Intentionally empty: this builder keeps no state-machine reference that would need updating.</summary>
    public void SetStateMachine(IAsyncStateMachine stateMachine)
    {
        // nothing to do
    }

    #endregion
}
/// <summary>
/// Awaiter produced by Promise&lt;T&gt;.GetAwaiter(); adapts a promise to the C# await pattern.
/// </summary>
/// <typeparam name="T">Result type of the awaited promise.</typeparam>
public readonly struct PromiseAwaiter<T> : INotifyCompletion
{
    // The promise this awaiter observes.
    private readonly Promise<T> target;

    /// <summary>Wraps <paramref name="fiber"/> for awaiting.</summary>
    public PromiseAwaiter(Promise<T> fiber) => target = fiber;

    #region mandatory awaiter methods

    /// <summary>Whether the promise has left the pending state.</summary>
    public bool IsCompleted
    {
        get { return target.IsCompleted; }
    }

    /// <summary>Returns the promise's result, or rethrows its failure (delegates to Promise&lt;T&gt;.Result).</summary>
    public T GetResult()
    {
        return target.Result;
    }

    /// <summary>Registers <paramref name="continuation"/> to run when the promise completes.</summary>
    public void OnCompleted(Action continuation)
    {
        target.RegisterContinuation(continuation);
    }

    #endregion
}
/// <summary>
/// Lifecycle state of a Promise&lt;T&gt;. The default value is <see cref="Pending"/>.
/// </summary>
public enum PromiseStatus : int
{
    /// <summary>Not completed yet.</summary>
    Pending = 0,

    /// <summary>Completed with a result.</summary>
    Success = 1,

    /// <summary>Completed with an exception.</summary>
    Failed = 2,
}
/// <summary>
/// A single-assignment container for either a result of type <typeparamref name="T"/> or an exception.
/// It is awaitable (via GetAwaiter) and, through the AsyncMethodBuilder attribute, can also be the
/// return type of an <c>async</c> method.
/// </summary>
/// <remarks>
/// PromiseAsyncMethodBuilder runs async bodies on the thread pool, so a promise may be completed on one
/// thread while another thread registers a continuation. The Pending -> Success/Failed transition and
/// continuation registration are therefore serialised by a private lock. Without it, a continuation
/// registered concurrently with completion could be lost and its awaiter would hang.
/// </remarks>
/// <typeparam name="T">Type of the successful result.</typeparam>
[AsyncMethodBuilder(typeof(PromiseAsyncMethodBuilder<>))]
public sealed class Promise<T>
{
    // Serialises completion with continuation registration.
    private readonly object gate = new object();

    // Written last when completing. Volatile so that lock-free readers (IsCompleted, Result) which
    // observe a completed status also observe the result/Exception written before it.
    private volatile PromiseStatus status;
    private T result;

    // Continuations registered while pending (multicast, in registration order); only touched under gate.
    private Action continuation;

    /// <summary>Creates a promise already completed successfully with <paramref name="result"/>.</summary>
    public Promise(T result)
    {
        this.result = result;
        this.status = PromiseStatus.Success;
    }

    /// <summary>Creates a promise already failed with <paramref name="exception"/>.</summary>
    public Promise(Exception exception)
    {
        this.Exception = exception;
        this.status = PromiseStatus.Failed;
    }

    /// <summary>Creates a pending promise, to be completed via TrySetResult or TrySetException.</summary>
    public Promise()
    {
        this.status = PromiseStatus.Pending;
    }

    /// <summary>The successful result.</summary>
    /// <remarks>If the promise failed, its exception is rethrown with the original stack trace preserved.</remarks>
    /// <exception cref="InvalidOperationException">The promise is still pending.</exception>
    public T Result
    {
        get
        {
            switch (status)
            {
                case PromiseStatus.Success: return result;
                case PromiseStatus.Failed:
                    ExceptionDispatchInfo.Capture(Exception).Throw();
                    return default;
                default:
                    throw new InvalidOperationException("Fiber didn't complete");
            }
        }
    }

    /// <summary>The exception this promise failed with; null unless it failed.</summary>
    public Exception Exception { get; private set; }

    /// <summary>Whether the promise has completed (successfully or not).</summary>
    public bool IsCompleted => status != PromiseStatus.Pending;

    /// <summary>Supports <c>await promise</c>.</summary>
    public PromiseAwaiter<T> GetAwaiter() => new PromiseAwaiter<T>(this);

    /// <summary>Completes the promise with <paramref name="result"/> unless it has already completed.</summary>
    /// <returns>True if this call completed the promise; false if it was already completed.</returns>
    internal bool TrySetResult(T result) => TryComplete(PromiseStatus.Success, result, null);

    /// <summary>Fails the promise with <paramref name="exception"/> unless it has already completed.</summary>
    /// <returns>True if this call completed the promise; false if it was already completed.</returns>
    internal bool TrySetException(Exception exception) => TryComplete(PromiseStatus.Failed, default, exception);

    /// <summary>
    /// Runs <paramref name="cont"/> once the promise completes: immediately (on the calling thread) if it
    /// already has, otherwise on the thread that completes it. Multiple continuations run in registration order.
    /// </summary>
    internal void RegisterContinuation(Action cont)
    {
        lock (gate)
        {
            if (status == PromiseStatus.Pending)
            {
                continuation += cont;
                return;
            }
        }

        // Already completed: run outside the lock so the continuation may freely use this promise.
        cont();
    }

    // Performs the single Pending -> outcome transition, then runs the registered continuations.
    private bool TryComplete(PromiseStatus outcome, T value, Exception error)
    {
        Action toRun;
        lock (gate)
        {
            if (status != PromiseStatus.Pending) return false;

            result = value;
            Exception = error;
            status = outcome; // volatile write last: publishes result/Exception to lock-free readers
            toRun = continuation;
            continuation = null;
        }

        // Invoke outside the lock: continuations (e.g. a state machine's MoveNext) run arbitrary code.
        toRun?.Invoke();
        return true;
    }
}
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Threading.Tasks; | |
namespace Demo | |
{ | |
/// <summary>
/// LINQ query-pattern operators over Task&lt;T&gt;, so that <c>from x in task ... select ...</c> composes tasks.
/// </summary>
/// <remarks>
/// Each operator awaits its inputs, so the returned task always completes and mirrors the failing stage:
/// <list type="bullet">
/// <item>A fault surfaces as the original (unwrapped) exception, exactly as with a plain <c>await</c>.</item>
/// <item>Cancellation stays cancellation.</item>
/// <item>Exceptions thrown by <c>map</c> / <c>resultSelect</c> fault the returned task instead of being lost.</item>
/// </list>
/// Continuations do not capture a synchronization context.
/// </remarks>
public static class TaskExtensions
{
    /// <summary>Projects the result of <paramref name="task"/> through <paramref name="map"/> (query <c>select</c> / <c>let</c>).</summary>
    /// <returns>A task producing <c>map(await task)</c>.</returns>
    public static async Task<T2> Select<T1, T2>(this Task<T1> task, Func<T1, T2> map)
    {
        var value = await task.ConfigureAwait(false);
        return map(value);
    }

    /// <summary>Chains <paramref name="map"/> after <paramref name="task"/> (monadic bind).</summary>
    /// <returns>A task producing the result of the task returned by <paramref name="map"/>.</returns>
    public static async Task<T2> SelectMany<T1, T2>(this Task<T1> task, Func<T1, Task<T2>> map)
    {
        var value = await task.ConfigureAwait(false);
        return await map(value).ConfigureAwait(false);
    }

    /// <summary>
    /// Chains <paramref name="map"/> after <paramref name="task"/> and combines both results with
    /// <paramref name="resultSelect"/>. This is the overload the compiler uses for consecutive <c>from</c> clauses.
    /// </summary>
    /// <returns>A task producing <c>resultSelect(first, second)</c>.</returns>
    public static async Task<TResult> SelectMany<T1, T2, TResult>(this Task<T1> task, Func<T1, Task<T2>> map, Func<T1, T2, TResult> resultSelect)
    {
        var first = await task.ConfigureAwait(false);
        var second = await map(first).ConfigureAwait(false);
        return resultSelect(first, second);
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment