Skip to content

Instantly share code, notes, and snippets.

@garuma
Last active May 13, 2019 03:38
Show Gist options
  • Save garuma/5d2527b697c271cc243fa1ddf853cb2d to your computer and use it in GitHub Desktop.
Save garuma/5d2527b697c271cc243fa1ddf853cb2d to your computer and use it in GitHub Desktop.
F# computation expressions in C#. Migrated to https://github.com/garuma/Neteril.ComputationExpression
packages uti id title platforms
id version
MathNet.Numerics.Core
3.17.0
id version
OxyPlot.Core
2.0.0-unstable1013
com.xamarin.workbook
744cd41b-7a5b-44a4-b17d-80a4297f0fc3
computation-expression
DotNetCore

F# computation expressions in C#

A generic re-implementation of F# computation expressions in C# by (ab)using async/await (by @jeremie_laval)

TL;DR: this lifts the tricks from my previous attempt at borrowing F# computation expression concepts in C# into a generic form that can be reused with other builder types.

This code uses the F# definition of computation expression builders (more precisely, the subset made of the Bind/Return/Zero/Combine members) for a given monad type and plugs it into a custom async/await method builder and awaiter.

Basically the idea is to go from this kind of F#:

// Divides `top` by `bottom`, yielding None instead of failing on a zero divisor.
let divideBy bottom top =
    match bottom with
    | 0 -> None
    | _ -> Some (top / bottom)

// Minimal builder for the Maybe workflow: let! goes through Option.bind, return wraps in Some.
type MaybeBuilder() =
    member __.Bind(m, f) = m |> Option.bind f
    member __.Return(x) = Some x

let maybe = MaybeBuilder()

// Halves 120 three times; any None along the way short-circuits the whole workflow.
let divideByWorkflow =
    maybe {
        let! half = divideBy 2 120
        let! quarter = divideBy 2 half
        let! eighth = divideBy 2 quarter
        return eighth
    }

To this kind of C#:

// Returns None instead of dividing by zero. The cast gives both branches the common type
// Option<int> (None<int> and Some<int> do not convert to each other).
Option<int> TryDivide (int up, int down)
	=> down == 0 ? (Option<int>)None<int>.Value : Some.Of (up / down);

class MaybeBuilder : IMonadExpressionBuilder
{
	// Bind may change the carried type from U to T, as declared by IMonadExpressionBuilder.
	// None (or anything that is not Some) short-circuits: f is never called.
	public IMonad<T> Bind<U, T> (IMonad<U> m, Func<U, IMonad<T>> f)
	{
		switch ((Option<U>)m) {
			case Some<U> some: return f (some.Item);
			default: return None<T>.Value;
		}
	}
	public IMonad<T> Return<T> (T v) => Some.Of (v);
	public IMonad<T> Zero<T> () => None<T>.Value;
	// Combine is only needed for yield-style expressions, which Maybe does not support.
	public IMonad<T> Combine<T> (IMonad<T> m, IMonad<T> n) => throw new NotSupportedException ();
}

// Each await goes through MaybeBuilder.Bind, so a None stops the remaining divisions.
ComputationExpression.Run<int, Option<int>> (new MaybeBuilder (), async () => {
	var half = await TryDivide (120, 2);
	var quarter = await TryDivide (half, 2);
	var eighth = await TryDivide (quarter, 2);
	return eighth;
})

The two versions are very similar, with C#'s await playing the role of F#'s let!/do!. To see it running, scroll down to the bottom of the workbook and execute the code.

Enjoy!

using System;
using System.Runtime.CompilerServices;

Plumbing

Implementation of the computation expression machinery

/// <summary>
/// Base type for every monad usable in a computation expression. The attribute routes any
/// <c>async</c> method/lambda returning <see cref="IMonad{T}"/> through
/// <see cref="MonadAsyncMethodBuilder{T}"/>, and <see cref="GetAwaiter"/> makes instances awaitable.
/// </summary>
[AsyncMethodBuilder (typeof (MonadAsyncMethodBuilder<>))]
public abstract class IMonad<T>
{
	/// <summary>Wraps this monad in an awaiter so it can be the operand of <c>await</c> (F# <c>let!</c>).</summary>
	public MonadAwaiter<T> GetAwaiter ()
	{
		return new MonadAwaiter<T> (this);
	}
}

/// <summary>
/// Subset of an F# computation expression builder for a given monad type.
/// Follows the type signatures of https://docs.microsoft.com/en-us/dotnet/fsharp/language-reference/computation-expressions
/// </summary>
public interface IMonadExpressionBuilder
{
	/// <summary>
	/// Feeds the value(s) carried by <paramref name="m"/> to the continuation <paramref name="f"/>
	/// (F# <c>let!</c>/<c>do!</c>, C# <c>await</c>).
	/// </summary>
	IMonad<T> Bind<U, T> (IMonad<U> m, Func<U, IMonad<T>> f);

	/// <summary>Wraps a plain value into the monad (F# <c>return</c>).</summary>
	IMonad<T> Return<T> (T v);

	/// <summary>Monad the async method builder starts from before the body produces a result.</summary>
	IMonad<T> Zero<T> ();

	/// <summary>
	/// Sequences two computations; used when the body awaits
	/// <see cref="ComputationExpression.Yield{T}"/>.
	/// </summary>
	IMonad<T> Combine<T> (IMonad<T> m, IMonad<T> n);
}

/// <summary>Entry points for running a computation expression written as an async lambda.</summary>
public static class ComputationExpression
{
	// Builder handed to MonadAsyncMethodBuilder<T>.Create for async bodies started on this thread.
	// No initializer: a [ThreadStatic] initializer only runs for the first thread (CA2019), and the
	// default value is already null.
	[ThreadStatic]
	internal static IMonadExpressionBuilder CurrentBuilder;

	/// <summary>
	/// Runs <paramref name="body"/> with <paramref name="builder"/> installed as the current builder
	/// and returns the monad it produces.
	/// </summary>
	/// <param name="builder">Builder supplying Bind/Return/Zero/Combine for the monad type.</param>
	/// <param name="body">Async lambda whose awaits are translated into <c>Bind</c> calls.</param>
	/// <returns>The monad produced by <paramref name="body"/>, cast to <typeparamref name="TMonad"/>.</returns>
	/// <exception cref="ArgumentNullException"><paramref name="builder"/> or <paramref name="body"/> is null.</exception>
	/// <exception cref="InvalidCastException">The produced monad is not a <typeparamref name="TMonad"/>.</exception>
	public static TMonad Run<T, TMonad> (IMonadExpressionBuilder builder, Func<IMonad<T>> body) where TMonad : IMonad<T>
	{
		if (builder == null)
			throw new ArgumentNullException (nameof (builder));
		if (body == null)
			throw new ArgumentNullException (nameof (body));

		// Restore (rather than clear) the previous builder so that a Run nested inside another
		// body does not leave the outer computation without a builder for later async calls.
		var previousBuilder = CurrentBuilder;
		try {
			CurrentBuilder = builder;
			return (TMonad)body ();
		} finally {
			CurrentBuilder = previousBuilder;
		}
	}

	/// <summary>
	/// Produces an awaitable that yields <paramref name="value"/>; the method builder wraps it
	/// with <c>Return</c> and merges it with the rest of the body through <c>Combine</c>.
	/// </summary>
	public static CombineAwaitable<T> Yield<T> (T value)
	{
		return new CombineAwaitable<T> (value);
	}
}

/// <summary>
/// Awaiter returned by <see cref="IMonad{T}.GetAwaiter"/>. It never completes by itself: it only
/// holds the awaited monad and the value that the builder's Bind feeds back before resuming.
/// </summary>
public class MonadAwaiter<T> : INotifyCompletion
{
	readonly IMonad<T> awaitedMonad;
	T boundValue;

	public MonadAwaiter (IMonad<T> m)
	{
		awaitedMonad = m;
	}

	// The monad being awaited, handed to IMonadExpressionBuilder.Bind.
	internal IMonad<T> CurrentMonad {
		get { return awaitedMonad; }
	}

	// Stores the value Bind passes to its continuation; returned by GetResult on resumption.
	internal void SetNextStep (T value)
	{
		boundValue = value;
	}

	public T GetResult ()
	{
		return boundValue;
	}

	// Always false so that every await goes through the method builder's AwaitOnCompleted
	// instead of continuing synchronously.
	public bool IsCompleted {
		get { return false; }
	}

	public void OnCompleted (Action continuation)
	{
		// Intentionally empty: the method builder resumes the state machine itself.
	}
}

/// <summary>Awaitable returned by <see cref="ComputationExpression.Yield{T}"/>.</summary>
public struct CombineAwaitable<T>
{
	readonly T value;

	public CombineAwaitable (T yieldedValue)
	{
		value = yieldedValue;
	}

	public CombineAwaiter<T> GetAwaiter ()
	{
		return new CombineAwaiter<T> (value);
	}
}

/// <summary>
/// Awaiter for a yielded value. Like <see cref="MonadAwaiter{T}"/> it never completes on its own;
/// the method builder reads <see cref="YieldedValue"/> and drives the state machine.
/// </summary>
public class CombineAwaiter<T> : INotifyCompletion
{
	readonly T value;

	public CombineAwaiter (T yieldedValue)
	{
		value = yieldedValue;
	}

	internal T YieldedValue {
		get { return value; }
	}

	public T GetResult ()
	{
		return YieldedValue;
	}

	public bool IsCompleted {
		get { return false; }
	}

	public void OnCompleted (Action continuation)
	{
		// Intentionally empty: the continuation is never scheduled from here.
	}
}

/* Custom async method builder attached to IMonad<T> through the [AsyncMethodBuilder]
 * attribute, so every async method/lambda returning an IMonad<T> is driven by it.
 * Instead of scheduling continuations, it hands each awaited monad to the current
 * IMonadExpressionBuilder.Bind and resumes the state machine from inside the bind
 * callback. The monad the method produces is whatever ends up in finalResult.
 */
public class MonadAsyncMethodBuilder<T>
{
	// Expression builder captured from ComputationExpression.CurrentBuilder at creation.
	IMonadExpressionBuilder builder;
	// Monad exposed through Task. It starts as builder.Zero and is then overwritten by
	// SetResult (Return), ProcessBind (Bind's result) and the yield path (Combine).
	IMonad<T> finalResult;

	// Called for every invocation of an IMonad<T>-returning async method. Only works while
	// ComputationExpression.Run has installed a builder on the current thread.
	public static MonadAsyncMethodBuilder<T> Create ()
	{
		var builder = ComputationExpression.CurrentBuilder;
		if (builder == null)
			throw new NotSupportedException($"Computation expression can only be run from {nameof(ComputationExpression)}.{nameof(ComputationExpression.Run)}");
		return new MonadAsyncMethodBuilder<T> (builder);
	}

	public MonadAsyncMethodBuilder (IMonadExpressionBuilder builder)
	{
		this.builder = builder;
		finalResult = builder.Zero<T> ();
	}

	// Runs the body synchronously. MonadAwaiter/CombineAwaiter report IsCompleted == false and
	// ignore OnCompleted, so the body only ever resumes through the MoveNext calls made below.
	public void Start<TStateMachine> (ref TStateMachine stateMachine) where TStateMachine : IAsyncStateMachine
	{
		stateMachine.MoveNext ();
	}

	// Result handed back to the caller of the async method (presumably read by the
	// compiler-generated stub right after Start returns).
	public IMonad<T> Task => finalResult;

	public void SetStateMachine (IAsyncStateMachine stateMachine) { }
	// Body reached `return value;`: wrap the value with the builder's Return.
	public void SetResult (T result)
	{
		finalResult = builder.Return<T> (result);
	}

	// NOTE(review): an exception thrown by the body is only printed to the console. It is not
	// rethrown and finalResult keeps its previous value, so callers of Run get no error signal.
	// Consider rethrowing (e.g. via ExceptionDispatchInfo).
	public void SetException (Exception ex) { Console.WriteLine (ex); }

	// Invoked by the state machine at each await. Dispatches on the awaiter's generic type
	// definition: MonadAwaiter<U> means a bind, CombineAwaiter<T> means a yield.
	public void AwaitOnCompleted<TAwaiter, TStateMachine> (ref TAwaiter awaiter, ref TStateMachine stateMachine)
		where TAwaiter : INotifyCompletion
		where TStateMachine : IAsyncStateMachine
	{
		if (!typeof (TAwaiter).IsGenericType)
			throw new InvalidOperationException ("Invalid awaiter given");
		
		if (typeof(TAwaiter).GetGenericTypeDefinition () == typeof (MonadAwaiter<>)) {
			/* Unfortunately we can't infer the U of MonadAwaiter<U>
			* from the constructed TAwaiter type that's given to us
			* so we have to resort to good old reflection.
			*/
			var monadUType = typeof (TAwaiter).GetGenericArguments ()[0];
			var monadAwaiter = awaiter;
			// Converting to the interface boxes a copy when the state machine is a struct;
			// the Bind callback in ProcessBind drives that copy.
			var stateMachineCopy = (IAsyncStateMachine)stateMachine;
			GetType ().GetMethod ("ProcessBind", System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic)
				.MakeGenericMethod (monadUType)
				.Invoke (this, new object[] { monadAwaiter, stateMachineCopy});
			return;
		}
		if (typeof(TAwaiter).GetGenericTypeDefinition () == typeof (CombineAwaiter<>)) {
			// Yield: wrap the yielded value with Return, run the rest of the body synchronously
			// (it may replace finalResult via further yields, binds or SetResult), then combine
			// the yielded monad (first) with whatever finalResult holds afterwards.
			// The cast throws InvalidCastException if the yielded type is not T.
			var yieldAwaiter = (CombineAwaiter<T>)(object)awaiter;
			var m = builder.Return (yieldAwaiter.YieldedValue);
			stateMachine.MoveNext ();
			this.finalResult = builder.Combine (m, finalResult);
			return;
		}

		throw new InvalidOperationException ("Invalid awaiter given");
	}

	// Calls builder.Bind on the awaited monad. The continuation given to Bind:
	//  1. stores the bound value in the awaiter (GetResult returns it on resume);
	//  2. rewinds the state machine to this await if it already moved past it, since Bind may
	//     call the continuation several times or only later (e.g. lazily during enumeration);
	//  3. runs the rest of the body with MoveNext and returns finalResult as it stands then.
	void ProcessBind<U> (MonadAwaiter<U> monadAwaiter, IAsyncStateMachine stateMachine)
	{
		var monad = monadAwaiter.CurrentMonad;
		var machineState = GetMachineState (stateMachine);
		var userMonad = builder.Bind<U, T> (monad, value =>
		{
			/* If we are called that means we keep the
			 * control of the execution flow, no need
			 * to produce a monad instance of our own
			 * at that stage since it will be fed in
			 * later.
			 */
			monadAwaiter.SetNextStep (value);
			if (GetMachineState (stateMachine) != machineState)
				ResetStateMachine (stateMachine, machineState, monadAwaiter);
			stateMachine.MoveNext ();
			return finalResult;
		});
		// A non-null Bind result becomes the method's result (e.g. the None returned by the
		// Maybe builder when it short-circuits without calling the continuation).
		if (userMonad != null)
			this.finalResult = userMonad;
	}

	// Reads the compiler-generated `<>1__state` field through reflection.
	// NOTE(review): relies on an undocumented C# compiler naming convention and repeats the
	// reflection lookup on every call.
	int GetMachineState (IAsyncStateMachine stateMachine)
	{
		var stateField = stateMachine
			.GetType ()
			.GetField ("<>1__state", System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.Public);
		return (int)stateField.GetValue (stateMachine);
	}

	// Writes the saved state and awaiter back into the compiler-generated `<>1__state` and
	// `<>u__1` fields so that the next MoveNext resumes right after the matching await.
	// NOTE(review): `<>u__1` assumes this awaiter lives in the state machine's first awaiter
	// field; confirm for bodies that await several awaiter types.
	void ResetStateMachine (IAsyncStateMachine stateMachine, int state, object awaiter)
	{
		var stateField = stateMachine
			.GetType ()
			.GetField ("<>1__state", System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.Public);
		stateField.SetValue (stateMachine, state);
		var awaiterField = stateMachine
			.GetType ()
			.GetField ("<>u__1", System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic);
		awaiterField.SetValue (stateMachine, awaiter);
	}

	// Not supported. MonadAwaiter and CombineAwaiter only implement INotifyCompletion, so
	// presumably the compiler never routes their awaits here.
	public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine> (ref TAwaiter awaiter, ref TStateMachine stateMachine)
		where TAwaiter : ICriticalNotifyCompletion
		where TStateMachine : IAsyncStateMachine
	{
		throw new NotSupportedException ();
	}
}

The Maybe workflow with Option<T>

Boilerplate implementation of an Option monad type and its expression builder, based on the code above

/// <summary>Base type of the Maybe/Option monad: a value is either a Some or a None.</summary>
abstract class Option<T> : IMonad<T>
{
}

/// <summary>The empty case of the Option monad; a shared instance is exposed as <see cref="Value"/>.</summary>
sealed class None<T> : Option<T>
{
	public static readonly None<T> Value = new None<T> ();
}
/// <summary>The populated case of the Option monad, carrying <see cref="Item"/>.</summary>
sealed class Some<T> : Option<T>
{
	public readonly T Item;

	public Some (T item)
	{
		Item = item;
	}

	/// <summary>Unwraps the carried value.</summary>
	public static explicit operator T (Some<T> option)
	{
		return option.Item;
	}
}

/// <summary>Factory that lets the compiler infer <c>T</c>: <c>Some.Of (42)</c>.</summary>
static class Some
{
	public static Some<T> Of<T> (T value)
	{
		return new Some<T> (value);
	}
}

/// <summary>Expression builder for the Option monad (the F# "maybe" workflow).</summary>
class MaybeBuilder : IMonadExpressionBuilder
{
	// Continue with the carried value for Some; anything else short-circuits to None.
	IMonad<T> IMonadExpressionBuilder.Bind<U, T> (IMonad<U> m, Func<U, IMonad<T>> f)
	{
		var option = (Option<U>)m;
		if (option is Some<U> present)
			return f (present.Item);
		return None<T>.Value;
	}

	IMonad<T> IMonadExpressionBuilder.Return<T> (T v)
	{
		return Some.Of (v);
	}

	IMonad<T> IMonadExpressionBuilder.Zero<T> ()
	{
		return None<T>.Value;
	}

	// Yield-style expressions are not meaningful for Option.
	IMonad<T> IMonadExpressionBuilder.Combine<T> (IMonad<T> m, IMonad<T> n)
	{
		throw new NotSupportedException ();
	}
}

Some extra helpers to work with Option<T>

/// <summary>
/// Logs the attempted division and returns its quotient, or None when it cannot be computed.
/// </summary>
static Option<int> TryDivide (int up, int down)
{
	Console.WriteLine ($"Trying to execute division {up}/{down}");

	// Division by zero has no result, and int.MinValue / -1 overflows int (which can throw an
	// ArithmeticException at runtime), so both map to None instead of crashing.
	if (down == 0 || (up == int.MinValue && down == -1))
		return None<int>.Value;

	return Some.Of (up / down);
}

/// <summary>Prints "None" or "Some &lt;value&gt;"; a null option prints nothing.</summary>
static void PrintResult<T> (Option<T> maybe)
{
	if (maybe is None<T>) {
		Console.WriteLine ("None");
	} else if (maybe is Some<T> present) {
		Console.WriteLine ($"Some {(T)present}");
	}
}

In this example we simulate the maybe computation expression, which short-circuits a series of statements as soon as one of them produces None.

// The awaited Option<int> values are consumed by the computation expression machinery,
// so silence CS4014 ("call is not awaited") for this cell.
#pragma warning disable 4014

Console.WriteLine ("## Good example");
var good = ComputationExpression.Run<int, Option<int>> (
	builder: new MaybeBuilder (),
	body: async () => {
		int half = await TryDivide (120, 2);
		int quarter = await TryDivide (half, 2);
		int eighth = await TryDivide (quarter, 2);

		return eighth;
	});
PrintResult (good);

Console.WriteLine ();
Console.WriteLine ("## Bad example");
// The division by zero yields None, so the third division never runs.
var bad = ComputationExpression.Run<int, Option<int>> (
	builder: new MaybeBuilder (),
	body: async () => {
		int first = await TryDivide (120, 2);
		int second = await TryDivide (first, 0);
		int third = await TryDivide (second, 2);

		return third;
	});
PrintResult (bad);

If you run that final block of code, your output should look like this:

## Good example

Trying to execute division 120/2
Trying to execute division 60/2
Trying to execute division 30/2
Some 15

## Bad example

Trying to execute division 120/2
Trying to execute division 60/0
None

Haskell State monad

In Haskell's pure world, state cannot be mutated. Instead, the State monad threads a piece of state through a computation alongside its intermediate results.

Now let's borrow some State monad fun from the Haskell tutorial at https://en.wikibooks.org/wiki/Haskell/Understanding_monads/State

This is somewhat of a cheat in our case. In Haskell it makes sense to pass the random generator along as state so it can serve as the next seed, but in C# this is not really necessary since the generator's state is already encapsulated in the Random instance.

/// <summary>
/// State monad: wraps a function that takes the incoming state and returns a value together
/// with the outgoing state.
/// </summary>
public class State<TState, TValue> : IMonad<TValue>
{
	readonly Func<TState, (TValue, TState)> run;

	public State (Func<TState, (TValue, TState)> stateProcessor) => run = stateProcessor;

	/// <summary>Runs the wrapped function from <paramref name="state"/>.</summary>
	public (TValue value, TState state) RunState (TState state)
	{
		return run (state);
	}
}

// Replaces the current state with `state`; the produced value is default (TValue).
public static State<TState, TValue> Put<TState, TValue> (TState state)
{
	return new State<TState, TValue> (ignored => (default (TValue), state));
}
// Exposes the current state as the value, leaving the state untouched.
public static State<TState, TState> Get<TState> ()
{
	return new State<TState, TState> (current => (current, current));
}

// Runs the computation from `state` and keeps only the produced value.
public static TValue EvalState<TState, TValue> (State<TState, TValue> stateMonad, TState state)
{
	var (value, _) = stateMonad.RunState (state);
	return value;
}
// Runs the computation from `state` and keeps only the final state.
public static TState ExecState<TState, TValue> (State<TState, TValue> stateMonad, TState state)
{
	var (_, finalState) = stateMonad.RunState (state);
	return finalState;
}

/// <summary>Expression builder for <see cref="State{TState, TValue}"/> with a fixed state type.</summary>
public class StateBuilder<TState> : IMonadExpressionBuilder
{
	// Chains two stateful steps: run `m`, feed its value to `f`, then run the resulting step
	// from the state `m` left behind. Nothing runs until RunState is called.
	IMonad<T> IMonadExpressionBuilder.Bind<U, T> (IMonad<U> m, Func<U, IMonad<T>> f)
	{
		var source = (State<TState, U>)m;
		Func<TState, (T, TState)> chained = state => {
			var (intermediate, nextState) = source.RunState (state);
			var continuation = (State<TState, T>)f (intermediate);
			return continuation.RunState (nextState);
		};
		return new State<TState, T> (chained);
	}

	IMonad<T> IMonadExpressionBuilder.Return<T> (T v)
	{
		return new State<TState, T> (state => (v, state));
	}

	IMonad<T> IMonadExpressionBuilder.Zero<T> ()
	{
		return new State<TState, T> (state => (default (T), state));
	}

	IMonad<T> IMonadExpressionBuilder.Combine<T> (IMonad<T> m, IMonad<T> n)
	{
		throw new NotSupportedException ();
	}
}
#pragma warning disable 4014

/// <summary>
/// C# take on Haskell's <c>randomR</c>: draws a uniform integer from the inclusive interval
/// [low, high] and returns it together with the generator to use for the next draw.
/// </summary>
/// <exception cref="ArgumentOutOfRangeException">interval.low is greater than interval.high.</exception>
static (int random, Random generator) RandomR ((int low, int high) interval, Random initialGenerator)
{
	if (interval.low > interval.high)
		throw new ArgumentOutOfRangeException (nameof (interval), interval, "low must not exceed high.");

	// Haskell's randomR is inclusive on both ends, while Random.Next (min, max) excludes max.
	// Passing the bounds straight through meant rolling (1, 6) could never produce a 6.
	int value;
	if (interval.high < int.MaxValue) {
		value = initialGenerator.Next (interval.low, interval.high + 1);
	} else if (interval.low > int.MinValue) {
		// high + 1 would overflow: shift the interval down by one and back up instead.
		value = initialGenerator.Next (interval.low - 1, interval.high) + 1;
	} else {
		// The whole int range: any 32 random bits are a uniform draw.
		var bytes = new byte[4];
		initialGenerator.NextBytes (bytes);
		value = BitConverter.ToInt32 (bytes, 0);
	}

	// System.Random advances in place, so the "next" generator is the same instance. This
	// threads one generator through the State monad instead of allocating a new one per draw.
	return (value, initialGenerator);
}

// Reads the generator from the state, draws a die value and stores the next generator back.
var rollDie = ComputationExpression.Run<int, State<Random, int>> (
	builder: new StateBuilder<Random> (),
	body: async () => {
		var currentGenerator = await Get<Random> ();
		var (roll, nextGenerator) = RandomR ((1, 6), currentGenerator);
		await Put<Random, int> (nextGenerator);
		return roll;
	});
EvalState<Random, int> (rollDie, new Random ());

Re-creating yield state machine

We can also re-create our good old yield return with async/await, with some help from the extra Combine operation of our computation expression builder. The result is somewhat more verbose, but it works.

/// <summary>List monad: a thin <see cref="IMonad{T}"/> wrapper that enumerates an underlying sequence.</summary>
public class EnumerableMonad<T> : IMonad<T>, IEnumerable<T>
{
	readonly IEnumerable<T> source;

	public EnumerableMonad (IEnumerable<T> seed)
	{
		source = seed;
	}

	public IEnumerator<T> GetEnumerator ()
	{
		return source.GetEnumerator ();
	}

	System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator ()
	{
		return GetEnumerator ();
	}
}

/// <summary>Expression builder for <see cref="EnumerableMonad{T}"/> (the list monad).</summary>
public class EnumerableBuilder : IMonadExpressionBuilder
{
	// Runs the continuation once per element and flattens the results, lazily.
	IMonad<T> IMonadExpressionBuilder.Bind<U, T> (IMonad<U> m, Func<U, IMonad<T>> f)
	{
		var source = (EnumerableMonad<U>)m;
		IEnumerable<T> flattened = source.SelectMany (element => (EnumerableMonad<T>)f (element));
		return new EnumerableMonad<T> (flattened);
	}

	IMonad<T> IMonadExpressionBuilder.Return<T> (T v)
	{
		return new EnumerableMonad<T> (Enumerable.Repeat (v, 1));
	}

	IMonad<T> IMonadExpressionBuilder.Zero<T> ()
	{
		return new EnumerableMonad<T> (Enumerable.Empty<T> ());
	}

	// Sequences two enumerations: every element of m, then every element of n.
	IMonad<T> IMonadExpressionBuilder.Combine<T> (IMonad<T> m, IMonad<T> n)
	{
		var first = (EnumerableMonad<T>)m;
		var second = (EnumerableMonad<T>)n;
		return new EnumerableMonad<T> (first.Concat (second));
	}
}
#pragma warning disable 4014

var result = ComputationExpression.Run<int, EnumerableMonad<int>> (
	builder: new EnumerableBuilder (),
	body: async () => {
		var left = await new EnumerableMonad<int> (new [] { 1, 2, 3 });
		var right = await new EnumerableMonad<int> (new [] { 100, 200 });
		// For every combination of left and right, emit left, then right, then left * right.
		await ComputationExpression.Yield (left);
		await ComputationExpression.Yield (right);
		return left * right;
	});
Console.WriteLine (string.Join (", ", result.Select (value => value.ToString ())));

Probabilities as monads

Probability distributions can be represented as monads and can therefore be chained together. The example used here comes courtesy of https://www.chrisstucchio.com/blog/2016/probability_the_monad.html

Note that, to keep execution time manageable, the sampling count has been drastically reduced so that the histogram at the end can be generated in a reasonable amount of time (it can still take a minute or two). As a consequence, the computed probabilities are only rough approximations.

#r "MathNet.Numerics.Core"

using MathNet.Numerics.Distributions;
using MathNet.Numerics.Random;

/// <summary>
/// Base type of the probability monad: a distribution over <typeparamref name="T"/> described
/// by the probability it assigns to each value.
/// </summary>
public abstract class Probability<T> : IMonad<T>
{
	/// <summary>Probability assigned to the value <paramref name="t"/>.</summary>
	public abstract double Prob (T t);
}

/// <summary>
/// Estimates point probabilities of a continuous distribution by drawing samples and counting
/// how many land within 0.001 of the queried value.
/// </summary>
public abstract class RandomSamplingProbablity : Probability<double>
{
	public abstract double Draw ();

	public override double Prob (double t)
	{
		const int NumSamples = 5000;
		int hits = 0;
		for (int sample = 0; sample < NumSamples; sample++) {
			// Shallow equality: a draw "matches" when it is within 0.001 of t.
			if (Math.Abs (Draw () - t) < 0.001)
				hits++;
		}
		return (double)hits / NumSamples;
	}
}

/// <summary>Adapts a MathNet discrete distribution, using its exact probability mass.</summary>
public class DiscreteDistributionProbability : Probability<int>
{
	readonly IDiscreteDistribution source;

	public DiscreteDistributionProbability (IDiscreteDistribution d)
	{
		source = d;
	}

	public override double Prob (int t)
	{
		return source.Probability (t);
	}
}

/// <summary>Adapts a MathNet continuous distribution; point probabilities are estimated by sampling.</summary>
public class ContinuousDistributionProbability : RandomSamplingProbablity
{
	readonly IContinuousDistribution source;

	public ContinuousDistributionProbability (IContinuousDistribution d)
	{
		source = d;
	}

	public override double Draw ()
	{
		return source.Sample ();
	}
}

/// <summary>Probability defined directly by a function from value to probability.</summary>
public class ComposedProbability<T> : Probability<T>
{
	readonly Func<T, double> probabilityOf;

	public ComposedProbability (Func<T, double> prob)
	{
		probabilityOf = prob;
	}

	public override double Prob (T t)
	{
		return probabilityOf (t);
	}
}

// Finite stand-in for "all" possible values of a type: integers 0..99, or doubles
// 0, 0.05, ... below 1. Any other type is rejected with NotSupportedException.
// NOTE(review): Binomial (l, 100) can also produce 100, which the int range omits.
static IEnumerable<T> SpaceOf<T> ()
{
	var type = typeof (T);
	if (type == typeof (int)) {
		IEnumerable<int> integers = Enumerable.Range (0, 100);
		// Double cast through object: the universal cast "operator".
		return (IEnumerable<T>)(object)integers;
	}
	if (type == typeof (double)) {
		IEnumerable<double> grid = DoubleRange (0, 1, 0.05);
		return (IEnumerable<T>)(object)grid;
	}
	throw new NotSupportedException ();
}

/// <summary>
/// Yields <paramref name="from"/>, from + step, from + 2 * step, ... while the value is below
/// <paramref name="to"/>.
/// </summary>
/// <exception cref="ArgumentOutOfRangeException">
/// <paramref name="step"/> is not positive (raised on first enumeration).
/// </exception>
static IEnumerable<double> DoubleRange (double from, double to, double step)
{
	if (!(step > 0))
		throw new ArgumentOutOfRangeException (nameof (step), step, "Step must be a positive number.");
	if (!(from < to))
		yield break;

	// Each point is rebuilt from its index in decimal instead of accumulating `from += step` in
	// binary floating point. Accumulation drifts (0.1 + 0.05 == 0.15000000000000002), and such
	// values never compare equal to e.g. 15 / 100.0, which the exact-equality Return of the
	// probability builder relies on. Converting back from decimal gives the double nearest to
	// the intended value. NOTE: from and step must fit in System.Decimal (about ±7.9e28).
	var start = (decimal)from;
	var increment = (decimal)step;
	var current = from;
	for (var index = 1; current < to; index++) {
		yield return current;
		current = (double)(start + index * increment);
	}
}

/* Builder for the probability monad. Bind marginalizes over the values returned by
 * SpaceOf<U> (a finite stand-in for "every possible U"), Return is a point mass and
 * Combine is unsupported.
 */
public class ProbabilityBuilder : IMonadExpressionBuilder
{
	// P(t) = sum over u of P_m(u) * P_f(u)(t), recomputed each time Prob (t) is called.
	IMonad<T> IMonadExpressionBuilder.Bind<U, T> (IMonad<U> m, Func<U, IMonad<T>> f)
	{
		var prior = (Probability<U>)m;
		Func<T, double> marginal = t => {
			double total = 0;
			foreach (var outcome in SpaceOf<U> ()) {
				var weight = prior.Prob (outcome);
				var conditional = (Probability<T>)f (outcome);
				total += weight * conditional.Prob (t);
			}
			return total;
		};
		return new ComposedProbability<T> (marginal);
	}

	// Point mass: probability 1 for values equal to v, 0 for everything else.
	IMonad<T> IMonadExpressionBuilder.Return<T> (T v)
	{
		var comparer = EqualityComparer<T>.Default;
		return new ComposedProbability<T> (t => comparer.Equals (t, v) ? 1 : 0);
	}

	// Assigns probability 1 to every value.
	IMonad<T> IMonadExpressionBuilder.Zero<T> ()
	{
		return new ComposedProbability<T> (ignored => 1);
	}

	IMonad<T> IMonadExpressionBuilder.Combine<T> (IMonad<T> m, IMonad<T> n)
	{
		throw new NotSupportedException ();
	}
}
#r "OxyPlot"
using OxyPlot;

var result = ComputationExpression.Run<double, Probability<double>> (
	builder: new ProbabilityBuilder (),
	body: async () => {
		// Rate drawn from Beta (51, 151), then a count out of 100 trials at that rate.
		var rate = await new ContinuousDistributionProbability (new Beta (51, 151));
		var conversions = await new DiscreteDistributionProbability (new Binomial (rate, 100));
		return conversions / 100.0;
	});

var plotModel = new PlotModel {
	Title = "Empirical conversion rate",
	PlotType = PlotType.XY
};
// Evaluating result.Prob for each bar is the slow part of this cell.
plotModel.Series.Add (new OxyPlot.Series.LinearBarSeries {
	ItemsSource = new List<DataPoint> (DoubleRange (0, 1, 0.05).Select (x => new DataPoint (x, result.Prob (x))))
});

plotModel;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment