-
-
Save yrong/0402e4242ea82549c6790a999cad8a64 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// | |
// Companion code to https://medium.com/statuscode/pipeline-patterns-in-go-a37bb3a7e61d | |
// | |
// To run (module mode, the default since Go 1.16):
//	go mod init pipeline_demo
//	go get github.com/pkg/errors
//	go run -race pipeline_demo.go
// | |
package main | |
import ( | |
"context" | |
"fmt" | |
"math/rand" | |
"strconv" | |
"sync" | |
"time" | |
"github.com/pkg/errors" | |
) | |
// MergeErrors fans any number of error channels into a single channel.
// It is adapted from the fan-in pattern at https://blog.golang.org/pipelines.
//
// The returned channel has one buffer slot per input channel. Each stage in
// this file sends at most one error on its capacity-1 error channel. So the
// forwarding goroutines can always deposit their value and exit, even when
// the consumer (for example WaitForPipeline) stops reading early.
// The returned channel is closed once every input channel has been closed.
func MergeErrors(cs ...<-chan error) <-chan error {
	merged := make(chan error, len(cs))

	var pending sync.WaitGroup
	pending.Add(len(cs))

	// One forwarding goroutine per source channel.
	for _, src := range cs {
		go func(src <-chan error) {
			defer pending.Done()
			for e := range src {
				merged <- e
			}
		}(src)
	}

	// Close merged only after every forwarder has finished. Add has already
	// been called for all of them, so Wait cannot return prematurely.
	go func() {
		pending.Wait()
		close(merged)
	}()

	return merged
}
// WaitForPipeline waits for results from all error channels. | |
// It returns early on the first error. | |
func WaitForPipeline(errs ...<-chan error) error { | |
errc := MergeErrors(errs...) | |
for err := range errc { | |
if err != nil { | |
return err | |
} | |
} | |
return nil | |
} | |
// minimalPipelineStage is a template showing the shape every pipeline stage
// should follow:
//   - ctx comes first, so the stage can be cancelled;
//   - required inputs follow ctx, and required outputs precede the error
//     channel in the result list;
//   - problems detected before the function returns are reported through
//     the plain error result;
//   - problems detected afterwards, inside the stage's goroutine, are sent
//     on the returned error channel, which the goroutine closes on exit.
//
// Inputs can be ordinary values (e.g. a list of strings), channels of
// values, or gRPC input streams. Outputs can be ordinary values, channels of
// values, or gRPC output streams.
func minimalPipelineStage(ctx context.Context) (<-chan error, error) {
	errCh := make(chan error, 1)

	go func() {
		defer close(errCh)
		// Real work goes here.
	}()

	return errCh, nil
}
func lineListSource(ctx context.Context, lines ...string) ( | |
<-chan string, <-chan error, error) { | |
if len(lines) == 0 { | |
// Handle an error that occurs before the goroutine begins. | |
return nil, nil, errors.Errorf("no lines provided") | |
} | |
out := make(chan string) | |
errc := make(chan error, 1) | |
go func() { | |
defer close(out) | |
defer close(errc) | |
for lineIndex, line := range lines { | |
if line == "" { | |
// Handle an error that occurs during the goroutine. | |
errc <- errors.Errorf("line %v is empty", lineIndex+1) | |
return | |
} | |
// Send the data to the output channel but return early | |
// if the context has been cancelled. | |
select { | |
case out <- line: | |
case <-ctx.Done(): | |
return | |
} | |
} | |
}() | |
return out, errc, nil | |
} | |
func lineParser(ctx context.Context, base int, in <-chan string) ( | |
<-chan int64, <-chan error, error) { | |
if base < 2 { | |
// Handle an error that occurs before the goroutine begins. | |
return nil, nil, errors.Errorf("invalid base %v", base) | |
} | |
out := make(chan int64) | |
errc := make(chan error, 1) | |
go func() { | |
defer close(out) | |
defer close(errc) | |
for line := range in { | |
n, err := strconv.ParseInt(line, base, 64) | |
if err != nil { | |
// Handle an error that occurs during the goroutine. | |
errc <- err | |
return | |
} | |
// Send the data to the output channel but return early | |
// if the context has been cancelled. | |
select { | |
case out <- n: | |
case <-ctx.Done(): | |
return | |
} | |
} | |
}() | |
return out, errc, nil | |
} | |
// splitter is a transformer stage that copies every value read from in onto
// two output channels. Each value goes to the first output and then to the
// second. Both outputs are unbuffered, so the two consumers advance in
// lockstep and a slow reader on either side holds up the other.
// The stage stops when in is closed, or when ctx is cancelled during a send.
// It never sends an error; its error channel is only closed on exit.
func splitter(ctx context.Context, in <-chan int64) (
	<-chan int64, <-chan int64, <-chan error, error) {
	first := make(chan int64)
	second := make(chan int64)
	errCh := make(chan error, 1)

	// send delivers v on dst unless ctx is cancelled first. It reports
	// whether the value was delivered.
	send := func(dst chan<- int64, v int64) bool {
		select {
		case dst <- v:
			return true
		case <-ctx.Done():
			return false
		}
	}

	go func() {
		defer close(first)
		defer close(second)
		defer close(errCh)
		for v := range in {
			if !send(first, v) || !send(second, v) {
				return
			}
		}
	}()

	return first, second, errCh, nil
}
func squarer(ctx context.Context, in <-chan int64) (<-chan int64, <-chan error, error) { | |
out := make(chan int64) | |
errc := make(chan error, 1) | |
go func() { | |
defer close(out) | |
defer close(errc) | |
for n := range in { | |
// Send the data to the output channel but return early | |
// if the context has been cancelled. | |
select { | |
case out <- n * n: | |
case <-ctx.Done(): | |
return | |
} | |
} | |
}() | |
return out, errc, nil | |
} | |
func sink(ctx context.Context, in <-chan int64) ( | |
<-chan error, error) { | |
errc := make(chan error, 1) | |
go func() { | |
defer close(errc) | |
for n := range in { | |
if n >= 100 { | |
// Handle an error that occurs during the goroutine. | |
errc <- errors.Errorf("number %v is too large", n) | |
return | |
} | |
fmt.Printf("sink: %v\n", n) | |
} | |
}() | |
return errc, nil | |
} | |
func runSimplePipeline(base int, lines []string) error { | |
fmt.Printf("runSimplePipeline: base=%v, lines=%v\n", base, lines) | |
ctx, cancelFunc := context.WithCancel(context.Background()) | |
defer cancelFunc() | |
var errcList []<-chan error | |
// Source pipeline stage. | |
linec, errc, err := lineListSource(ctx, lines...) | |
if err != nil { | |
return err | |
} | |
errcList = append(errcList, errc) | |
// Transformer pipeline stage. | |
numberc, errc, err := lineParser(ctx, base, linec) | |
if err != nil { | |
return err | |
} | |
errcList = append(errcList, errc) | |
// Sink pipeline stage. | |
errc, err = sink(ctx, numberc) | |
if err != nil { | |
return err | |
} | |
errcList = append(errcList, errc) | |
fmt.Println("Pipeline started. Waiting for pipeline to complete.") | |
return WaitForPipeline(errcList...) | |
} | |
func runComplexPipeline(base int, lines []string) error { | |
fmt.Printf("runComplexPipeline: base=%v, lines=%v\n", base, lines) | |
ctx, cancelFunc := context.WithCancel(context.Background()) | |
defer cancelFunc() | |
var errcList []<-chan error | |
// Source pipeline stage. | |
linec, errc, err := lineListSource(ctx, lines...) | |
if err != nil { | |
return err | |
} | |
errcList = append(errcList, errc) | |
// Transformer pipeline stage 1. | |
numberc, errc, err := lineParser(ctx, base, linec) | |
if err != nil { | |
return err | |
} | |
errcList = append(errcList, errc) | |
// Transformer pipeline stage 2. | |
numberc1, numberc2, errc, err := splitter(ctx, numberc) | |
if err != nil { | |
return err | |
} | |
errcList = append(errcList, errc) | |
// Transformer pipeline stage 3. | |
numberc3, errc, err := squarer(ctx, numberc1) | |
if err != nil { | |
return err | |
} | |
errcList = append(errcList, errc) | |
// Sink pipeline stage 1. | |
errc, err = sink(ctx, numberc3) | |
if err != nil { | |
return err | |
} | |
errcList = append(errcList, errc) | |
// Sink pipeline stage 2. | |
errc, err = sink(ctx, numberc2) | |
if err != nil { | |
return err | |
} | |
errcList = append(errcList, errc) | |
fmt.Println("Pipeline started. Waiting for pipeline to complete.") | |
return WaitForPipeline(errcList...) | |
} | |
// randomNumberSource is an unbounded source stage. About once per second it
// emits a pseudo-random integer in [0, 100), formatted as a decimal string.
// The sequence is determined by seed.
//
// It runs until ctx is cancelled. Cancellation is observed both while it is
// waiting to send and during the one-second pause between values, so the
// stage shuts down promptly rather than finishing a blocking sleep first.
// It never sends an error; its error channel is closed on exit, together
// with the output channel.
func randomNumberSource(ctx context.Context, seed int64) (<-chan string, <-chan error, error) {
	out := make(chan string)
	errc := make(chan error, 1)
	random := rand.New(rand.NewSource(seed))
	go func() {
		defer close(out)
		defer close(errc)
		for {
			line := strconv.Itoa(random.Intn(100))
			// Send downstream, but give up if the pipeline is cancelled.
			select {
			case out <- line:
			case <-ctx.Done():
				return
			}
			// Pause between values. Unlike time.Sleep, this wakes up as soon
			// as ctx is cancelled.
			pause := time.NewTimer(time.Second)
			select {
			case <-pause.C:
			case <-ctx.Done():
				pause.Stop()
				return
			}
		}
	}()
	return out, errc, nil
}
func runPipelineWithTimeout() error { | |
fmt.Printf("runPipelineWithTimeout\n") | |
ctx, cancelFunc := context.WithCancel(context.Background()) | |
defer cancelFunc() | |
var errcList []<-chan error | |
// Source pipeline stage. | |
linec, errc, err := randomNumberSource(ctx, 3) | |
if err != nil { | |
return err | |
} | |
errcList = append(errcList, errc) | |
// Transformer pipeline stage. | |
numberc, errc, err := lineParser(ctx, 10, linec) | |
if err != nil { | |
return err | |
} | |
errcList = append(errcList, errc) | |
// Sink pipeline stage. | |
errc, err = sink(ctx, numberc) | |
if err != nil { | |
return err | |
} | |
errcList = append(errcList, errc) | |
fmt.Println("Pipeline started. Waiting for pipeline to complete.") | |
// Start a goroutine that will cancel this pipeline in 10 seconds. | |
go func() { | |
time.Sleep(10 * time.Second) | |
fmt.Println("Cancelling context.") | |
cancelFunc() | |
}() | |
return WaitForPipeline(errcList...) | |
} | |
func main() { | |
if err := runSimplePipeline(10, []string{"3", "2", "1"}); err != nil { | |
fmt.Println(err) | |
} | |
if err := runSimplePipeline(1, []string{"3", "2", "1"}); err != nil { | |
fmt.Println(err) | |
} | |
if err := runSimplePipeline(2, []string{"1010", "1100", "1000"}); err != nil { | |
fmt.Println(err) | |
} | |
if err := runSimplePipeline(2, []string{"1010", "1100", "2000", "1111"}); err != nil { | |
fmt.Println(err) | |
} | |
if err := runSimplePipeline(10, []string{"1", "10", "100", "1000"}); err != nil { | |
fmt.Println(err) | |
} | |
if err := runComplexPipeline(10, []string{"5", "4", "3"}); err != nil { | |
fmt.Println(err) | |
} | |
if err := runPipelineWithTimeout(); err != nil { | |
fmt.Println(err) | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment