Last active
September 20, 2022 21:04
-
-
Save hasheddan/47601ec6f3e94a4e938800ad5e1dca62 to your computer and use it in GitHub Desktop.
Concurrent recursive processing of a tree in Go.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"context" | |
"fmt" | |
"sync" | |
"time" | |
"github.com/pkg/errors" | |
) | |
// node is a single element of the tree being processed.
type node struct {
	// name identifies the node; it is used as context when wrapping errors.
	// content is the payload that process returns for this node.
	name, content string
	// children are the direct descendants of this node; a leaf has none.
	children []node
}
// process is meant to simulate whatever processing is done on a given node. | |
func process(ctx context.Context, n node) (string, error) { | |
if ctx.Err() != nil { | |
return "", errors.Wrap(ctx.Err(), n.name) | |
} | |
return n.content, nil | |
} | |
// traverse takes a node and recursively processes its children. Depending on | |
// the status of the worker pool, this may or may not return before the entire | |
// tree is processed. | |
// Consider the following cases: | |
// 1. Worker pool is saturated: if there are no workers to delegate to, then | |
// we are just doing sequential depth-first traversal with no concurrency. | |
// This call to traverse will block until full tree is processed. | |
// 2. A worker is available for all children of root: if all children are | |
// delegated to workers, then this call to traverse processes the root node | |
// and returns. Processing of the rest of the tree may or may not be | |
// complete. | |
func traverse(ctx context.Context, n node, jobC chan<- node, resC chan<- string, errC chan<- error, rwg *sync.WaitGroup) { | |
defer rwg.Done() | |
// We check if context is done outside of child loop because we don't want | |
// add to the wait group then subsequently remove if context is done. | |
// | |
// Checking <- ctx.Done in the select statement is a common pattern, but if | |
// the worker pool has bandwidth, we are not guaranteed to select the | |
// ctx.Done() as both it and the job channel can progress and ordering is | |
// not honored in select. | |
if ctx.Err() != nil { | |
errC <- errors.Wrap(ctx.Err(), n.name) | |
return | |
} | |
for _, i := range n.children { | |
// Add to the recursion wait group so that we can track when all calls | |
// have returned. | |
rwg.Add(1) | |
select { | |
case jobC <- i: | |
// If worker pool has idle workers, give them work to do. | |
fmt.Printf("delegated node %s to worker pool.\n", i.name) | |
default: | |
// Otherwise we traverse the child ourself. | |
fmt.Printf("handling node %s ourself.\n", i.name) | |
traverse(ctx, i, jobC, resC, errC, rwg) | |
} | |
} | |
// Recursion Base Case: if we either... | |
// 1. Have no children. | |
// 2. Have traversed all children. | |
// 3. Have delegated traversal of all children to worker pool. | |
// 4. Have delegated traversal of some children to the worker pool, and | |
// traversed the rest ourself. | |
// | |
// Then we process the root of this tree or subtree, write response and | |
// error to channels, then return. | |
out, err := process(ctx, n) | |
errC <- err | |
resC <- out | |
} | |
// work is a worker process that reads from the job channel and calls traverse | |
// on nodes until the job channel is closed. | |
func work(ctx context.Context, jobC chan node, resC chan<- string, errC chan<- error, rwg *sync.WaitGroup) { | |
for j := range jobC { | |
traverse(ctx, j, jobC, resC, errC, rwg) | |
} | |
} | |
func main() { | |
root := node{ | |
name: "one", | |
content: "i am one", | |
children: []node{ | |
{ | |
name: "two", | |
content: "i am two", | |
children: []node{ | |
{ | |
name: "six", | |
content: "i am six", | |
}, | |
}, | |
}, | |
{ | |
name: "three", | |
content: "i am three", | |
}, | |
{ | |
name: "four", | |
content: "i am four", | |
children: []node{ | |
{ | |
name: "five", | |
content: "i am five", | |
}, | |
}, | |
}, | |
}, | |
} | |
res, err := handle(root) | |
if err != nil { | |
for _, e := range err { | |
fmt.Printf("found err: %v\n", e) | |
} | |
panic("hit errors") | |
} | |
fmt.Println(res) | |
} | |
func handle(root node) (results []string, errors []error) { | |
// The recursion wait group is responsible for tracking recursion "frames". | |
// It should be incremented every time we recurse, and decremented when we | |
// hit a base case. | |
var recursionWG sync.WaitGroup | |
// The jobs channel is responsible for sending nodes to the worker pool for | |
// processing. It is unbuffered because we only want to send to worker pool | |
// if there are idle workers. | |
jobC := make(chan node) | |
// The result channel is where all the results of processing a node are | |
// sent. It must be processed asynchronously. It should not be closed until | |
// both the recursion wait group and the worker wait group have been | |
// drained. | |
resC := make(chan string) | |
// The error channel is used to signal that an error has occurred. | |
errC := make(chan error) | |
// The context is used for bounding processing time. | |
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) | |
defer cancel() | |
// Create a pool of 10 workers, which will process nodes off of jobs channel | |
// until it is closed. | |
for i := 0; i < 10; i++ { | |
go work(ctx, jobC, resC, errC, &recursionWG) | |
} | |
// Start an error consumer. | |
go func() { | |
for e := range errC { | |
if e != nil { | |
errors = append(errors, e) | |
// cancel context to abort all recursion operations. | |
cancel() | |
} | |
} | |
}() | |
// Start a result consumer. Results are read until the result channel is closed. | |
go func() { | |
for r := range resC { | |
results = append(results, r) | |
} | |
}() | |
// Add one to the recursion wait group for the root node. | |
recursionWG.Add(1) | |
// Send the root to the worker pool to begin traversal. We could call | |
// traverse directly, but doing so can lead to not utilizing the worker pool | |
// if the number of children is small (i.e. we end up just doing depth-first | |
// traversal without delegation). | |
jobC <- root | |
// Wait for recursion to complete. | |
recursionWG.Wait() | |
// Close the response channel so that we don't leak our results consumer. | |
close(resC) | |
// Close the error channel so that we don't leak our error consumer. | |
close(errC) | |
// If we have finished recursing, there are no more jobs to be done. Close | |
// the channel to signal to the worker pool that there is no more work. | |
close(jobC) | |
return | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment