Last active
May 15, 2017 00:59
-
-
Save pmunin/e45692df0b0deee81008aa360e9c612c to your computer and use it in GitHub Desktop.
GraphAggregateUtils - Aggregates node values of directed acyclic graphs without double-counting shared child nodes
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using DictionaryUtils; | |
using System; | |
using System.Collections.Generic; | |
using System.Linq; | |
using System.Text; | |
using System.Threading.Tasks; | |
using TestGraphUtils; | |
using Xunit; | |
namespace GraphAggregateUtils
{
    /// <summary>
    /// Tests for the eager (<c>GraphAggregateUtils.Aggregate</c>) and the lazy
    /// (<c>GraphNodeAggregationLazy</c>) graph aggregations.
    /// </summary>
    public class GraphAggregateTests
    {
        /// <summary>
        /// Sum of every node cost in the graph built by <see cref="GenerateTestGraph"/>, each node
        /// counted exactly once (node "0" has no explicit cost and reads as 0).
        /// Node "3" and its children "5" and "6" are reachable through both "1" and "2";
        /// summing along every path would double-count them and give 35 instead of 21.
        /// </summary>
        const decimal ExpectedTotalCost = 0m + 1 + 2 + 3 + 4 + 5 + 6;

        /// <summary>
        /// Reads the node's "Cost" entry from <c>Data</c>, defaulting to 0 via <c>GetOrAdd</c>.
        /// </summary>
        decimal Cost(TestGraphNode node)
        {
            return node.Data.GetOrAdd("Cost", _ => 0m);
        }

        /// <summary>
        /// Stores <paramref name="value"/> as the node's "Cost" entry and returns the node for chaining.
        /// </summary>
        TestGraphNode Cost(TestGraphNode node, decimal value)
        {
            node.Data["Cost"] = value;
            return node;
        }

        /// <summary>
        /// Builds the test DAG (costs equal the node names):
        /// <code>
        ///       0
        ///      / \
        ///     1   2
        ///      \ / \
        ///       3   4
        ///      / \
        ///     5   6
        /// </code>
        /// </summary>
        /// <returns>The root node "0".</returns>
        TestGraphNode GenerateTestGraph()
        {
            var v0 = new TestGraphNode() { Name = "0" };
            // Captured so "2" can link to the node created under "1".
            // NOTE(review): assumes Link(callback) runs the callback immediately; otherwise v3 is still null below.
            TestGraphNode v3 = null;
            v0
                .Link(v1 =>
                {
                    v1.Name = "1";
                    Cost(v1, 1);
                    v1.Link(linked =>
                    {
                        v3 = linked;
                        v3.Name = "3";
                        Cost(v3, 3);
                        v3
                            .Link(v5 => { v5.Name = "5"; Cost(v5, 5); })
                            .Link(v6 => { v6.Name = "6"; Cost(v6, 6); });
                    });
                })
                .Link(v2 =>
                {
                    v2.Name = "2";
                    Cost(v2, 2);
                    v2
                        .Link(v3) // shared child: "3" is also a child of "1"
                        .Link(v4 => { v4.Name = "4"; Cost(v4, 4); });
                });
            return v0;
        }

        [Fact]
        public void Test1()
        {
            var root = GenerateTestGraph();

            var aggregate = GraphAggregateUtils.Aggregate(
                root,
                n => Cost(n),
                (a, b) => a + b,
                n => n.Links);

            Assert.Equal(ExpectedTotalCost, aggregate.AggregatedValue);
        }

        [Fact]
        public void TestLazy1()
        {
            var root = GenerateTestGraph();

            var aggregate = TestLazyDFS(root);

            Assert.Equal(ExpectedTotalCost, aggregate.GetTotalValue());
        }

        /// <summary>
        /// Builds (depth-first) a lazy aggregation tree mirroring the graph under <paramref name="node"/>.
        /// The aggregation is memoized in the node's "aggregation" Data entry, so a shared node
        /// (like "3") maps to a single aggregation instance. The lazy algorithm relies on that
        /// instance identity to avoid double counting.
        /// </summary>
        private GraphNodeAggregationLazy<decimal> TestLazyDFS(TestGraphNode node)
        {
            return node.Data.GetOrAdd("aggregation", _ =>
            {
                var agg = new GraphNodeAggregationLazy<decimal>((v1, v2) => v1 + v2);
                agg.LocalValue = Cost(node);
                foreach (var childNode in node.Links)
                    agg.AddChild(TestLazyDFS(childNode));
                return agg;
            });
        }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//Latest version here: https://gist.github.com/e45692df0b0deee81008aa360e9c612c.git | |
using DictionaryUtils; | |
using System; | |
using System.Collections.Generic; | |
using System.Linq; | |
using System.Text; | |
using System.Threading.Tasks; | |
namespace GraphAggregateUtils
{
    /// <summary>
    /// Aggregates node values of directed acyclic graphs without double-counting nodes
    /// that are shared by several parents.
    /// </summary>
    public static class GraphAggregateUtils
    {
        /// <summary>
        /// Aggregates the value of <paramref name="rootNode"/> together with every node reachable from it,
        /// counting each reachable node once even when it is linked from several parents.
        /// </summary>
        /// <param name="rootNode">Node the aggregation starts from.</param>
        /// <param name="getNodeValue">Returns a node's own (local) value.</param>
        /// <param name="aggregateValue">Combines two values (e.g. addition).</param>
        /// <param name="getNodeLinks">Returns a node's direct children; a null result is treated as "no children".</param>
        /// <param name="preAggregatedNodes">
        /// Optional node-to-aggregation map to reuse; new aggregations are added to it.
        /// When null, a fresh dictionary is created.
        /// </param>
        /// <returns>The aggregation of <paramref name="rootNode"/>; its <c>AggregatedValue</c> is the total.</returns>
        /// <exception cref="ArgumentNullException">A required argument is null.</exception>
        public static GraphNodeAggregation<TNode, TValue> Aggregate<TNode, TValue>(
            TNode rootNode
            , Func<TNode, TValue> getNodeValue
            , Func<TValue, TValue, TValue> aggregateValue
            , Func<TNode, IEnumerable<TNode>> getNodeLinks
            , IDictionary<TNode, GraphNodeAggregation<TNode, TValue>> preAggregatedNodes = null
            )
        {
            // Fail fast here instead of with a NullReferenceException deep inside the recursion.
            if (rootNode == null) throw new ArgumentNullException(nameof(rootNode));
            if (getNodeValue == null) throw new ArgumentNullException(nameof(getNodeValue));
            if (aggregateValue == null) throw new ArgumentNullException(nameof(aggregateValue));
            if (getNodeLinks == null) throw new ArgumentNullException(nameof(getNodeLinks));

            var args = new GraphNodeAggregation<TNode, TValue>.Args()
            {
                Node = rootNode,
                AggregateValue = aggregateValue,
                GetLinkedNodes = getNodeLinks,
                GetNodeValue = getNodeValue,
                AggregationByNode = preAggregatedNodes
            };
            return GetOrAddAggregation(args);
        }

        /// <summary>
        /// Returns the aggregation already stored for <c>args.Node</c> in <c>args.AggregationByNode</c>,
        /// or builds and stores a new one.
        /// Missing <c>AggregationByNode</c> / <c>VisitedNodes</c> collections are created on
        /// <paramref name="args"/> itself, so the caller's object is mutated.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="args"/> is null.</exception>
        public static GraphNodeAggregation<TNode, TValue> GetOrAddAggregation<TNode, TValue>(GraphNodeAggregation<TNode, TValue>.Args args)
        {
            if (args == null) throw new ArgumentNullException(nameof(args));

            if (args.AggregationByNode == null)
                args.AggregationByNode = new Dictionary<TNode, GraphNodeAggregation<TNode, TValue>>();
            if (args.VisitedNodes == null)
                args.VisitedNodes = new HashSet<TNode>();
            return args.AggregationByNode.GetOrAdd(args.Node, node => new GraphNodeAggregation<TNode, TValue>(args));
        }

        /// <summary>
        /// Adds every item of <paramref name="itemsToAdd"/> to <paramref name="hashSet"/>; duplicates are ignored.
        /// </summary>
        internal static void AddRange<T>(HashSet<T> hashSet, IEnumerable<T> itemsToAdd)
        {
            hashSet.UnionWith(itemsToAdd);
        }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Linq; | |
using System.Text; | |
using System.Threading.Tasks; | |
namespace GraphAggregateUtils
{
    /// <summary>
    /// Eagerly computed aggregation of one node of a directed acyclic graph: the node's own value
    /// plus the values of all nodes reachable from it, each reachable node counted once.
    /// </summary>
    public class GraphNodeAggregation<TNode, TValue>
    {
        /// <summary>
        /// Builds the aggregation for <c>args.Node</c>.
        /// Each linked node's aggregation is obtained (or reused) through
        /// <c>GraphAggregateUtils.GetOrAddAggregation</c> and then folded in with <see cref="AggregateChild"/>.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// The node already has an entry in <c>args.AggregationByNode</c>, or a cycle is reachable from it.
        /// </exception>
        public GraphNodeAggregation(Args args)
        {
            var node = args.Node;
            if (args.AggregationByNode != null && args.AggregationByNode.ContainsKey(node))
                throw new InvalidOperationException("This node is already aggregated");

            Node = node;
            LocalValue = args.GetNodeValue(node);
            AggregatedValue = LocalValue;

            // Nodes on the current depth-first path. A node seen again on its own path means a cycle,
            // which would otherwise recurse until the stack overflows. Shared (diamond) nodes are not
            // affected: once built they are found in AggregationByNode and never re-entered here.
            var path = args.VisitedNodes ?? new HashSet<TNode>();
            if (!path.Add(node))
                throw new InvalidOperationException("The graph contains a cycle; only acyclic graphs can be aggregated");
            try
            {
                foreach (var childNode in args.GetLinkedNodes(node) ?? Enumerable.Empty<TNode>())
                {
                    var childArgs = new Args()
                    {
                        AggregateValue = args.AggregateValue,
                        AggregationByNode = args.AggregationByNode,
                        GetLinkedNodes = args.GetLinkedNodes,
                        GetNodeValue = args.GetNodeValue,
                        VisitedNodes = path,
                        Node = childNode
                    };
                    var childAggregation = GraphAggregateUtils.GetOrAddAggregation(childArgs);
                    AggregateChild(childAggregation, args, true);
                }
            }
            finally
            {
                // Leaving this node: it is no longer on the active path.
                path.Remove(node);
            }
        }

        /// <summary>
        /// Folds <paramref name="childAggregation"/> into <see cref="AggregatedValue"/> without counting any
        /// node twice, and records it (and its descendants) in <see cref="AllChildrenAggregates"/>.
        /// </summary>
        /// <param name="childAggregation">Aggregation of a child or deeper descendant.</param>
        /// <param name="args">Supplies the <c>AggregateValue</c> combiner.</param>
        /// <param name="isDirectChild">When true, the child is also recorded in <see cref="DirectChildrenAggregates"/>.</param>
        public void AggregateChild(GraphNodeAggregation<TNode, TValue> childAggregation, Args args, bool isDirectChild)
        {
            // Already counted, either directly or as a descendant of an earlier child.
            if (AllChildrenAggregates.Contains(childAggregation))
                return;

            if (!AllChildrenAggregates.Overlaps(childAggregation.AllChildrenAggregates))
            {
                // Disjoint subtree: the child's precomputed total can be added wholesale.
                AggregatedValue = args.AggregateValue(AggregatedValue, childAggregation.AggregatedValue);
            }
            else
            {
                // Some descendants are already counted: add only the child's own value,
                // then fold its children in one by one so the shared ones are skipped.
                AggregatedValue = args.AggregateValue(AggregatedValue, childAggregation.LocalValue);
                foreach (var grandChildAggregation in childAggregation.DirectChildrenAggregates)
                    AggregateChild(grandChildAggregation, args, false);
            }

            AllChildrenAggregates.Add(childAggregation);
            AllChildrenAggregates.UnionWith(childAggregation.AllChildrenAggregates);
            if (isDirectChild)
                DirectChildrenAggregates.Add(childAggregation);
        }

        /// <summary>The aggregated node.</summary>
        public TNode Node;
        /// <summary>The node's own value, as returned by <c>GetNodeValue</c>.</summary>
        public TValue LocalValue;
        /// <summary>Local value combined with the values of all descendants, each counted once.</summary>
        public TValue AggregatedValue;
        /// <summary>Aggregations of the node's direct children that contributed to the total.</summary>
        public HashSet<GraphNodeAggregation<TNode, TValue>> DirectChildrenAggregates = new HashSet<GraphNodeAggregation<TNode, TValue>>();
        /// <summary>Aggregations of every descendant folded into <see cref="AggregatedValue"/>.</summary>
        public HashSet<GraphNodeAggregation<TNode, TValue>> AllChildrenAggregates = new HashSet<GraphNodeAggregation<TNode, TValue>>();

        /// <summary>Inputs for building a <see cref="GraphNodeAggregation{TNode, TValue}"/>.</summary>
        public class Args
        {
            /// <summary>Already-built aggregations, shared across the whole traversal.</summary>
            public IDictionary<TNode, GraphNodeAggregation<TNode, TValue>> AggregationByNode;
            /// <summary>Node to aggregate.</summary>
            public TNode Node;
            /// <summary>Returns a node's direct children; null means none.</summary>
            public Func<TNode, IEnumerable<TNode>> GetLinkedNodes;
            /// <summary>Combines two values.</summary>
            public Func<TValue, TValue, TValue> AggregateValue;
            /// <summary>Returns a node's own value.</summary>
            public Func<TNode, TValue> GetNodeValue;
            /// <summary>
            /// Nodes on the current depth-first path; required for the acyclic check.
            /// </summary>
            public HashSet<TNode> VisitedNodes;
        }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Linq; | |
using System.Text; | |
using System.Threading.Tasks; | |
namespace GraphAggregateUtils
{
    /// <summary>
    /// Node of a mutable aggregation DAG whose total (own value plus all descendants, each counted once)
    /// is computed lazily and cached. Changing <see cref="LocalValue"/> or the children invalidates the
    /// cache of this node and of every ancestor, so no ancestor keeps reporting a stale total.
    /// </summary>
    public abstract partial class GraphNodeAggregationLazyBase<TAggregatableValue>
    {
        /// <summary>Combines two values (e.g. addition).</summary>
        protected abstract TAggregatableValue AggregateValues(TAggregatableValue value1, TAggregatableValue value2);

        Dictionary<object, object> data = null;

        /// <summary>Arbitrary per-node user data; the dictionary is created on first access.</summary>
        public IDictionary<object, object> Data
        {
            get
            {
                return data ?? (data = new Dictionary<object, object>());
            }
        }

        /// <summary>Direct children of this node.</summary>
        public IEnumerable<GraphNodeAggregationLazyBase<TAggregatableValue>> Children => directChildren;

        TAggregatableValue localValue;

        /// <summary>This node's own value; setting it invalidates the cached totals of this node and its ancestors.</summary>
        public TAggregatableValue LocalValue { get { return localValue; } set { localValue = value; OnTotalChanged(); } }

        /// <summary>
        /// Drops the cached total of this node and of every ancestor.
        /// Each ancestor is visited once, even when reachable through several paths.
        /// </summary>
        protected void OnTotalChanged()
        {
            var pending = new Stack<GraphNodeAggregationLazyBase<TAggregatableValue>>();
            var invalidated = new HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>>();
            pending.Push(this);
            while (pending.Count > 0)
            {
                var current = pending.Pop();
                if (!invalidated.Add(current))
                    continue;
                current.totalAggregatedLazy = null;
                foreach (var parent in current.parents)
                    pending.Push(parent);
            }
        }

        HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>> directChildren
            = new HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>>();

        /// <summary>Nodes that list this node as a direct child; used to propagate cache invalidation upward.</summary>
        readonly HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>> parents
            = new HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>>();

        /// <summary>Adds a direct child and invalidates the cached totals.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="child"/> is null.</exception>
        public void AddChild(GraphNodeAggregationLazyBase<TAggregatableValue> child)
        {
            // A null child would otherwise only fail later, inside GetTotalValue().
            if (child == null) throw new ArgumentNullException(nameof(child));
            if (directChildren.Add(child))
                child.parents.Add(this);
            OnTotalChanged();
        }

        /// <summary>Removes a direct child (no-op if absent) and invalidates the cached totals.</summary>
        public void RemoveChild(GraphNodeAggregationLazyBase<TAggregatableValue> child)
        {
            if (directChildren.Remove(child))
                child.parents.Remove(this);
            OnTotalChanged();
        }

        Lazy<(TAggregatableValue totalValue, HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>> allChildren)> totalAggregatedLazy = null;

        /// <summary>Cached aggregation result; a new Lazy is created after each invalidation.</summary>
        protected Lazy<(TAggregatableValue totalValue, HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>> allChildren)> AggregatedLazy
        {
            get
            {
                return totalAggregatedLazy ??
                    (
                        totalAggregatedLazy
                        = new Lazy<(TAggregatableValue totalValue, HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>> allChildren)>(CalculateAggregated)
                    );
            }
        }

        /// <summary>Every descendant counted in the total (computed on demand).</summary>
        public IEnumerable<GraphNodeAggregationLazyBase<TAggregatableValue>> GetAllChildren()
        {
            return AggregatedLazy.Value.allChildren;
        }

        /// <summary>Own value combined with all descendants' values, each counted once (computed on demand).</summary>
        public TAggregatableValue GetTotalValue()
        {
            return AggregatedLazy.Value.totalValue;
        }

        /// <summary>Computes the total and descendant set from scratch, starting from <see cref="LocalValue"/>.</summary>
        protected (TAggregatableValue totalValue, HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>> allChildren) CalculateAggregated()
        {
            var res = (totalValue: this.LocalValue, allChildren: new HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>>());
            foreach (var child in directChildren)
            {
                AggregateAppend(child, ref res);
            }
            return res;
        }

        /// <summary>
        /// Folds <paramref name="childToAggregate"/> into <paramref name="aggregateResult"/> without counting any node twice.
        /// </summary>
        protected void AggregateAppend(GraphNodeAggregationLazyBase<TAggregatableValue> childToAggregate, ref (TAggregatableValue totalValue, HashSet<GraphNodeAggregationLazyBase<TAggregatableValue>> allChildren) aggregateResult)
        {
            // Already counted, directly or as a descendant of an earlier child.
            if (aggregateResult.allChildren.Contains(childToAggregate))
                return;

            var hasOverlaps = aggregateResult.allChildren.Overlaps(childToAggregate.GetAllChildren());
            if (!hasOverlaps)
            {
                // Disjoint subtree: take the child's cached total as a whole.
                aggregateResult.totalValue = AggregateValues(aggregateResult.totalValue, childToAggregate.GetTotalValue());
            }
            else
            {
                // Some descendants already counted: add the child's own value, then recurse into its children.
                aggregateResult.totalValue = AggregateValues(aggregateResult.totalValue, childToAggregate.LocalValue);
                foreach (var grandChildAgg in childToAggregate.Children)
                    AggregateAppend(grandChildAgg, ref aggregateResult);
            }

            aggregateResult.allChildren.Add(childToAggregate);
            aggregateResult.allChildren.UnionWith(childToAggregate.GetAllChildren());
        }
    }

    /// <summary>
    /// <see cref="GraphNodeAggregationLazyBase{TAggregatableValue}"/> whose values are combined by a supplied delegate.
    /// </summary>
    public partial class GraphNodeAggregationLazy<TAggregatableValue>
        : GraphNodeAggregationLazyBase<TAggregatableValue>
    {
        /// <param name="aggregate">Combines two values (e.g. addition).</param>
        public GraphNodeAggregationLazy(Func<TAggregatableValue, TAggregatableValue, TAggregatableValue> aggregate)
        {
            this.AggregateDelegate = aggregate;
        }

        /// <summary>The combiner supplied at construction.</summary>
        public Func<TAggregatableValue, TAggregatableValue, TAggregatableValue> AggregateDelegate { get; private set; }

        /// <inheritdoc/>
        protected override TAggregatableValue AggregateValues(TAggregatableValue value1, TAggregatableValue value2)
        {
            return AggregateDelegate(value1, value2);
        }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Linq; | |
using System.Text; | |
using System.Threading.Tasks; | |
namespace GraphAggregateUtils
{
    /// <summary>
    /// Lazily cached DAG aggregation in which values of type <typeparamref name="TAggregatableValue"/>
    /// are folded into an accumulator of type <typeparamref name="TAccumulate"/>. Each descendant is
    /// counted once. Changing <see cref="LocalValue"/> or the children invalidates the cache of this
    /// node and of every ancestor.
    /// </summary>
    public abstract partial class GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>
    {
        /// <summary>Folds a single value into <paramref name="accumulatorMutable"/> in place.</summary>
        protected abstract void AggregateValues(ref TAccumulate accumulatorMutable, TAggregatableValue valueImmutable);

        /// <summary>Folds another (child's) accumulator into <paramref name="accumulatorMutable"/> in place.</summary>
        protected abstract void AggregateAccumulators(ref TAccumulate accumulatorMutable, TAccumulate accumulatorImmutable);

        /// <summary>Initial accumulator for a node's calculation; <c>default(TAccumulate)</c> unless overridden.</summary>
        protected virtual TAccumulate AggregateSeed()
        {
            return default(TAccumulate);
        }

        /// <summary>Direct children of this node.</summary>
        public IEnumerable<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>> Children => directChildren;

        TAggregatableValue localValue;

        /// <summary>This node's own value; setting it invalidates the cached totals of this node and its ancestors.</summary>
        public TAggregatableValue LocalValue { get { return localValue; } set { localValue = value; OnTotalChanged(); } }

        /// <summary>
        /// Drops the cached total of this node and of every ancestor.
        /// Each ancestor is visited once, even when reachable through several paths.
        /// </summary>
        protected void OnTotalChanged()
        {
            var pending = new Stack<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>>();
            var invalidated = new HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>>();
            pending.Push(this);
            while (pending.Count > 0)
            {
                var current = pending.Pop();
                if (!invalidated.Add(current))
                    continue;
                current.totalAggregatedLazy = null;
                foreach (var parent in current.parents)
                    pending.Push(parent);
            }
        }

        HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>> directChildren
            = new HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>>();

        /// <summary>Nodes that list this node as a direct child; used to propagate cache invalidation upward.</summary>
        readonly HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>> parents
            = new HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>>();

        /// <summary>Adds a direct child and invalidates the cached totals.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="child"/> is null.</exception>
        public void AddChild(GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate> child)
        {
            // A null child would otherwise only fail later, inside GetTotalValue().
            if (child == null) throw new ArgumentNullException(nameof(child));
            if (directChildren.Add(child))
                child.parents.Add(this);
            OnTotalChanged();
        }

        /// <summary>Removes a direct child (no-op if absent) and invalidates the cached totals.</summary>
        public void RemoveChild(GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate> child)
        {
            if (directChildren.Remove(child))
                child.parents.Remove(this);
            OnTotalChanged();
        }

        Lazy<(TAccumulate totalValue, HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>> allChildren)> totalAggregatedLazy = null;

        /// <summary>Cached aggregation result; a new Lazy is created after each invalidation.</summary>
        protected Lazy<(TAccumulate totalValue, HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>> allChildren)> AggregatedLazy
        {
            get
            {
                return totalAggregatedLazy ??
                    (
                        totalAggregatedLazy
                        = new Lazy<(TAccumulate totalValue, HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>> allChildren)>(CalculateAggregated)
                    );
            }
        }

        /// <summary>Every descendant counted in the total (computed on demand).</summary>
        public IEnumerable<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>> GetAllChildren()
        {
            return AggregatedLazy.Value.allChildren;
        }

        /// <summary>Accumulated own value plus all descendants' values, each counted once (computed on demand).</summary>
        public TAccumulate GetTotalValue()
        {
            return AggregatedLazy.Value.totalValue;
        }

        /// <summary>Computes the accumulator from <see cref="AggregateSeed"/>, this node's value, then its children.</summary>
        protected (TAccumulate totalValue, HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>> allChildren) CalculateAggregated()
        {
            var res =
                (
                    totalValue: AggregateSeed()
                    , allChildren: new HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>>()
                );
            AggregateValues(ref res.totalValue, this.LocalValue);
            foreach (var child in directChildren)
            {
                AggregateAppend(child, ref res);
            }
            return res;
        }

        /// <summary>
        /// Folds <paramref name="childToAggregate"/> into <paramref name="aggregateResult"/> without counting any node twice.
        /// </summary>
        protected void AggregateAppend(GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate> childToAggregate
            , ref (TAccumulate totalValue, HashSet<GraphNodeAggregationLazyBase<TAggregatableValue, TAccumulate>> allChildren) aggregateResult
            )
        {
            // Already counted, directly or as a descendant of an earlier child.
            if (aggregateResult.allChildren.Contains(childToAggregate))
                return;

            var hasOverlaps = aggregateResult.allChildren.Overlaps(childToAggregate.GetAllChildren());
            if (!hasOverlaps)
            {
                // Disjoint subtree: merge the child's cached accumulator as a whole.
                AggregateAccumulators(ref aggregateResult.totalValue, childToAggregate.GetTotalValue());
            }
            else
            {
                // Some descendants already counted: add the child's own value, then recurse into its children.
                AggregateValues(ref aggregateResult.totalValue, childToAggregate.LocalValue);
                foreach (var grandChildAgg in childToAggregate.Children)
                    AggregateAppend(grandChildAgg, ref aggregateResult);
            }

            aggregateResult.allChildren.Add(childToAggregate);
            aggregateResult.allChildren.UnionWith(childToAggregate.GetAllChildren());
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment