Skip to content

Instantly share code, notes, and snippets.

@hasheddan
Last active September 20, 2022 21:04
Show Gist options
  • Save hasheddan/47601ec6f3e94a4e938800ad5e1dca62 to your computer and use it in GitHub Desktop.
Save hasheddan/47601ec6f3e94a4e938800ad5e1dca62 to your computer and use it in GitHub Desktop.
Concurrent recursive processing of a tree in Go.
package main
import (
"context"
"fmt"
"sync"
"time"
"github.com/pkg/errors"
)
// node is one vertex of the tree being processed. name labels the node in log
// lines and error messages, content is what processing the node yields, and
// children are the node's direct descendants.
type node struct {
	name, content string
	children      []node
}
// process stands in for the real per-node work. If ctx is still live it
// returns the node's content; otherwise it returns an empty string and the
// context's error wrapped with the node's name.
func process(ctx context.Context, n node) (string, error) {
	ctxErr := ctx.Err()
	if ctxErr == nil {
		return n.content, nil
	}
	return "", errors.Wrap(ctxErr, n.name)
}
// traverse processes the subtree rooted at n. Each child is either handed to
// an idle worker through jobC (a non-blocking send) or, when no worker is
// ready to receive, traversed recursively on the calling goroutine. Once all
// children have been dispatched, n itself is processed.
//
// Depending on the status of the worker pool, traverse may or may not return
// before the entire subtree is processed. Consider the following cases:
//  1. Worker pool is saturated: no child can be delegated, so this is plain
//     sequential depth-first traversal with no concurrency, and the call
//     blocks until the full subtree is processed.
//  2. A worker is available for every child: all children are delegated, so
//     this call processes only n and returns. Processing of the rest of the
//     subtree may or may not be complete.
//
// Contract: every call to traverse must be preceded by exactly one
// rwg.Add(1); traverse calls rwg.Done when it returns, and calls rwg.Add(1)
// for each child before dispatching it. For every node exactly one message is
// sent: a non-nil error on errC if the node could not be processed, otherwise
// its content on resC. The caller must drain resC and errC concurrently,
// because both sends block.
func traverse(ctx context.Context, n node, jobC chan<- node, resC chan<- string, errC chan<- error, rwg *sync.WaitGroup) {
	defer rwg.Done()

	// Check for cancellation outside the child loop so that we don't add
	// children to the wait group only to remove them again right away.
	//
	// Checking <-ctx.Done() in the select below is the common pattern, but
	// when a worker is idle both cases can proceed, and select chooses among
	// ready cases at random, so ctx.Done() would not reliably win.
	if err := ctx.Err(); err != nil {
		errC <- errors.Wrap(err, n.name)
		return
	}

	for _, child := range n.children {
		// Count the child's traversal before it can start (on a worker or
		// here), so the wait group can't reach zero while it is pending.
		rwg.Add(1)
		select {
		case jobC <- child:
			// An idle worker accepted the child.
			fmt.Printf("delegated node %s to worker pool.\n", child.name)
		default:
			// No idle worker: traverse the child ourselves.
			fmt.Printf("handling node %s ourself.\n", child.name)
			traverse(ctx, child, jobC, resC, errC, rwg)
		}
	}

	// Base case: every child has been delegated or traversed (or there were
	// none), so process n itself. Report either the error or the content,
	// never both. A failed node therefore does not leave an empty-string
	// placeholder in the results, and successful nodes send no nil errors.
	out, err := process(ctx, n)
	if err != nil {
		errC <- err
		return
	}
	resC <- out
}
// work is one member of the worker pool. It repeatedly receives a subtree
// root from jobC and traverses it, and returns once jobC has been closed and
// no jobs remain.
func work(ctx context.Context, jobC chan node, resC chan<- string, errC chan<- error, rwg *sync.WaitGroup) {
	for {
		job, open := <-jobC
		if !open {
			return
		}
		traverse(ctx, job, jobC, resC, errC, rwg)
	}
}
func main() {
	// Assemble the sample tree from the leaves up:
	//
	//	one
	//	|-- two
	//	|   `-- six
	//	|-- three
	//	`-- four
	//	    `-- five
	six := node{name: "six", content: "i am six"}
	five := node{name: "five", content: "i am five"}
	two := node{name: "two", content: "i am two", children: []node{six}}
	three := node{name: "three", content: "i am three"}
	four := node{name: "four", content: "i am four", children: []node{five}}
	root := node{name: "one", content: "i am one", children: []node{two, three, four}}

	res, errs := handle(root)
	if errs != nil {
		// Report every collected error before aborting.
		for _, problem := range errs {
			fmt.Printf("found err: %v\n", problem)
		}
		panic("hit errors")
	}
	fmt.Println(res)
}
// handle processes every node of the tree rooted at root using a pool of
// concurrent workers. The whole traversal is bounded by a timeout, and the
// first error reported cancels the remaining work.
//
// results holds the content of each processed node, in no particular order
// because nodes are processed concurrently. errs holds every non-nil error
// reported during traversal and is nil if there were none.
func handle(root node) (results []string, errs []error) {
	const (
		// numWorkers is the size of the worker pool.
		numWorkers = 10
		// timeout bounds the total processing time.
		timeout = 30 * time.Second
	)

	// The recursion wait group tracks recursion "frames". It is incremented
	// before every traverse call and decremented when that call returns.
	var recursionWG sync.WaitGroup

	// The consumer wait group tracks the result and error consumers. It lets
	// us be sure that they have finished appending to results/errs before we
	// return them. Without it the final appends race with the return.
	var consumerWG sync.WaitGroup

	// The jobs channel sends nodes to the worker pool. It is unbuffered so
	// that a send only succeeds when there is an idle worker.
	jobC := make(chan node)

	// The result channel receives the content of every processed node. The
	// error channel receives errors. Both are drained by dedicated consumer
	// goroutines. They are closed only after the recursion wait group has
	// drained, when no traverse call can send on them anymore.
	resC := make(chan string)
	errC := make(chan error)

	// The context bounds processing time and lets the error consumer abort
	// the remaining traversal.
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// Start the worker pool. Workers process nodes from jobC until it is
	// closed.
	for i := 0; i < numWorkers; i++ {
		go work(ctx, jobC, resC, errC, &recursionWG)
	}

	consumerWG.Add(2)

	// Error consumer: collect non-nil errors and cancel the context so that
	// in-flight traversal stops early.
	go func() {
		defer consumerWG.Done()
		for e := range errC {
			if e != nil {
				errs = append(errs, e)
				cancel()
			}
		}
	}()

	// Result consumer: collect results until resC is closed.
	go func() {
		defer consumerWG.Done()
		for r := range resC {
			results = append(results, r)
		}
	}()

	// Count the root's traversal, then hand it to the pool. We could call
	// traverse directly, but with few children that tends to degenerate into
	// sequential depth-first traversal that never uses the pool.
	recursionWG.Add(1)
	jobC <- root

	// Wait for all recursion to complete. After this no goroutine sends on
	// any of the channels.
	recursionWG.Wait()

	// There is no more work, so release the workers.
	close(jobC)

	// Stop the consumers and wait until they have finished appending, so
	// the slices we return are complete and no longer being written.
	close(resC)
	close(errC)
	consumerWG.Wait()

	return results, errs
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment