Skip to content

Instantly share code, notes, and snippets.

@claudiofahey
Last active November 13, 2023 19:29
Show Gist options
  • Save claudiofahey/3afcf4f4fb3d8d3b35cadb100d4fb9b7 to your computer and use it in GitHub Desktop.
Save claudiofahey/3afcf4f4fb3d8d3b35cadb100d4fb9b7 to your computer and use it in GitHub Desktop.
//
// Companion code to https://medium.com/statuscode/pipeline-patterns-in-go-a37bb3a7e61d
//
// To run:
// go get github.com/pkg/errors
// go run -race pipeline_demo.go
//
package main
import (
"context"
"fmt"
"math/rand"
"strconv"
"sync"
"time"
"github.com/pkg/errors"
)
// MergeErrors fans in any number of error channels into a single channel.
// Every value received from an input channel is forwarded to the returned
// channel, which is closed once all input channels have been closed.
// Based on https://blog.golang.org/pipelines.
func MergeErrors(cs ...<-chan error) <-chan error {
	// One buffer slot per input channel: as long as each input delivers at
	// most one error, the forwarders below never block on the send, even if
	// the consumer (e.g. WaitForPipeline) has already stopped reading.
	merged := make(chan error, len(cs))

	// Register every forwarder before the closer goroutine can call Wait.
	var forwarders sync.WaitGroup
	forwarders.Add(len(cs))
	for _, c := range cs {
		// Each forwarder copies values from its input until that input is
		// closed, then signals completion.
		go func(src <-chan error) {
			defer forwarders.Done()
			for err := range src {
				merged <- err
			}
		}(c)
	}

	// Close the merged channel once all forwarders have finished.
	go func() {
		forwarders.Wait()
		close(merged)
	}()
	return merged
}
// WaitForPipeline merges the given error channels and blocks until either a
// non-nil error arrives, which is returned immediately, or every channel has
// been closed, in which case it returns nil.
func WaitForPipeline(errs ...<-chan error) error {
	merged := MergeErrors(errs...)
	for {
		err, open := <-merged
		if !open {
			// All stages finished without reporting an error.
			return nil
		}
		if err != nil {
			return err
		}
	}
}
// minimalPipelineStage is a template showing what every pipeline stage has:
//   - a context parameter, so the stage can be cancelled;
//   - a returned channel of errors, for errors produced after this function
//     returns;
//   - a returned error, for errors produced before this function returns.
//
// Required inputs follow ctx and required outputs precede the error channel.
// Inputs can be ordinary objects (e.g. a list of strings), channels of
// objects, or gRPC input streams. Outputs can be ordinary objects, channels
// of objects, or gRPC output streams.
func minimalPipelineStage(ctx context.Context) (<-chan error, error) {
	stageErrs := make(chan error, 1)
	work := func() {
		// Closing the error channel signals that the stage has finished.
		defer close(stageErrs)
		// Do something useful here.
	}
	go work()
	return stageErrs, nil
}
// lineListSource is a source stage that emits the given lines, in order, on
// the returned string channel.
//
// It returns an error immediately when no lines are supplied. An empty line
// encountered while emitting is reported on the error channel (with its
// 1-based position) and stops the stage. Both channels are closed when the
// stage finishes, fails, or ctx is cancelled.
func lineListSource(ctx context.Context, lines ...string) (
	<-chan string, <-chan error, error) {
	// Failure detected before any goroutine is started.
	if len(lines) == 0 {
		return nil, nil, errors.Errorf("no lines provided")
	}

	lineCh := make(chan string)
	errCh := make(chan error, 1)
	go func() {
		// Close the error channel first, then the data channel.
		defer func() {
			close(errCh)
			close(lineCh)
		}()
		for idx, text := range lines {
			// Failure detected while the goroutine is running.
			if len(text) == 0 {
				errCh <- errors.Errorf("line %v is empty", idx+1)
				return
			}
			// Hand the line downstream unless the pipeline was cancelled.
			select {
			case <-ctx.Done():
				return
			case lineCh <- text:
			}
		}
	}()
	return lineCh, errCh, nil
}
// lineParser is a transformer stage that converts every line received on in
// to an int64 with strconv.ParseInt, using the given base, and emits the
// result on the returned int64 channel.
//
// base must be in the range 2..36, the explicit bases strconv.ParseInt
// supports. Any other value is rejected immediately through the returned
// error, before a goroutine is started. A line that fails to parse is
// reported on the error channel and stops the stage. Both channels are closed
// when in is closed, a parse error occurs, or ctx is cancelled.
func lineParser(ctx context.Context, base int, in <-chan string) (
	<-chan int64, <-chan error, error) {
	// Validate the whole supported range up front. Checking only the lower
	// bound would let base > 36 through, to fail later on the first parsed
	// line, or never if the input is empty.
	if base < 2 || base > 36 {
		// Handle an error that occurs before the goroutine begins.
		return nil, nil, errors.Errorf("invalid base %v", base)
	}
	out := make(chan int64)
	errc := make(chan error, 1)
	go func() {
		defer close(out)
		defer close(errc)
		for line := range in {
			n, err := strconv.ParseInt(line, base, 64)
			if err != nil {
				// Handle an error that occurs during the goroutine.
				errc <- err
				return
			}
			// Send the data to the output channel but return early
			// if the context has been cancelled.
			select {
			case out <- n:
			case <-ctx.Done():
				return
			}
		}
	}()
	return out, errc, nil
}
// splitter is a transformer stage that duplicates its input: every value
// received on in is sent to the first output channel and then to the second.
//
// The stage never reports an error; the error channel exists only so that
// the stage has the same shape as the others. All three channels are closed
// when in is closed or ctx is cancelled.
func splitter(ctx context.Context, in <-chan int64) (
	<-chan int64, <-chan int64, <-chan error, error) {
	first := make(chan int64)
	second := make(chan int64)
	errCh := make(chan error, 1)
	go func() {
		// Close the error channel first, then the outputs in reverse order.
		defer func() {
			close(errCh)
			close(second)
			close(first)
		}()
		// Each value is delivered to the outputs in this fixed order.
		targets := [...]chan<- int64{first, second}
		for n := range in {
			for _, dst := range targets {
				// Give up on the pending send if the pipeline was cancelled.
				select {
				case <-ctx.Done():
					return
				case dst <- n:
				}
			}
		}
	}()
	return first, second, errCh, nil
}
// squarer is a transformer stage that emits n*n for every value n received
// on in.
//
// The stage never reports an error. Both returned channels are closed when
// in is closed or ctx is cancelled.
func squarer(ctx context.Context, in <-chan int64) (<-chan int64, <-chan error, error) {
	squares := make(chan int64)
	errCh := make(chan error, 1)
	go func() {
		// Close the error channel first, then the data channel.
		defer func() {
			close(errCh)
			close(squares)
		}()
		for {
			n, open := <-in
			if !open {
				return
			}
			sq := n * n
			// Hand the square downstream unless the pipeline was cancelled.
			select {
			case <-ctx.Done():
				return
			case squares <- sq:
			}
		}
	}()
	return squares, errCh, nil
}
// sink is the terminal stage: it prints every value received on in.
//
// A value of 100 or more is reported on the error channel and stops the
// stage. The error channel is closed when in is closed or after such an
// error has been reported. ctx is accepted for signature consistency with
// the other stages but is not consulted here.
func sink(ctx context.Context, in <-chan int64) (
	<-chan error, error) {
	errCh := make(chan error, 1)
	go func() {
		defer close(errCh)
		for {
			n, open := <-in
			if !open {
				return
			}
			if n < 100 {
				fmt.Printf("sink: %v\n", n)
				continue
			}
			// Failure detected while the goroutine is running.
			errCh <- errors.Errorf("number %v is too large", n)
			return
		}
	}()
	return errCh, nil
}
// runSimplePipeline wires up a three-stage pipeline (lineListSource ->
// lineParser -> sink) over the given lines, parsed in the given base, and
// waits for it to complete.
//
// It returns the first error reported, either while setting up a stage or
// by a running stage, or nil if the pipeline finishes cleanly. The
// pipeline's context is cancelled when this function returns.
func runSimplePipeline(base int, lines []string) error {
	fmt.Printf("runSimplePipeline: base=%v, lines=%v\n", base, lines)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Source stage.
	lineCh, sourceErrs, err := lineListSource(ctx, lines...)
	if err != nil {
		return err
	}

	// Transformer stage.
	numberCh, parserErrs, err := lineParser(ctx, base, lineCh)
	if err != nil {
		return err
	}

	// Sink stage.
	sinkErrs, err := sink(ctx, numberCh)
	if err != nil {
		return err
	}

	fmt.Println("Pipeline started. Waiting for pipeline to complete.")
	return WaitForPipeline(sourceErrs, parserErrs, sinkErrs)
}
// runComplexPipeline wires up a branching pipeline and waits for it to
// complete:
//
//	lineListSource -> lineParser -> splitter -+-> squarer -> sink
//	                                          +-> sink
//
// It returns the first error reported, either while setting up a stage or
// by a running stage, or nil if the pipeline finishes cleanly. The
// pipeline's context is cancelled when this function returns.
func runComplexPipeline(base int, lines []string) error {
	fmt.Printf("runComplexPipeline: base=%v, lines=%v\n", base, lines)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Source stage.
	lineCh, sourceErrs, err := lineListSource(ctx, lines...)
	if err != nil {
		return err
	}

	// Transformer stage 1: parse lines into numbers.
	numberCh, parserErrs, err := lineParser(ctx, base, lineCh)
	if err != nil {
		return err
	}

	// Transformer stage 2: duplicate the number stream.
	toSquare, plain, splitterErrs, err := splitter(ctx, numberCh)
	if err != nil {
		return err
	}

	// Transformer stage 3: square the first copy.
	squared, squarerErrs, err := squarer(ctx, toSquare)
	if err != nil {
		return err
	}

	// Sink stage 1: consumes the squared numbers.
	squaredSinkErrs, err := sink(ctx, squared)
	if err != nil {
		return err
	}

	// Sink stage 2: consumes the untouched copy.
	plainSinkErrs, err := sink(ctx, plain)
	if err != nil {
		return err
	}

	fmt.Println("Pipeline started. Waiting for pipeline to complete.")
	return WaitForPipeline(
		sourceErrs,
		parserErrs,
		splitterErrs,
		squarerErrs,
		squaredSinkErrs,
		plainSinkErrs,
	)
}
// randomNumberSource is a source stage that emits a pseudo-random integer in
// [0, 100), formatted as a decimal string, and then pauses for one second
// before producing the next one. The sequence is determined by seed.
//
// The stage runs until ctx is cancelled. It never reports an error: the
// error channel is only ever closed, together with the output channel, when
// the stage stops, and the returned error is always nil.
func randomNumberSource(ctx context.Context, seed int64) (<-chan string, <-chan error, error) {
	out := make(chan string)
	errc := make(chan error, 1)
	random := rand.New(rand.NewSource(seed))
	go func() {
		defer close(out)
		defer close(errc)
		for {
			n := random.Intn(100)
			line := fmt.Sprintf("%v", n)
			// Send the data to the output channel but return if the context has been cancelled.
			select {
			case out <- line:
			case <-ctx.Done():
				return
			}
			// Pause between numbers, but stay responsive to cancellation.
			// A plain time.Sleep would keep this goroutine, and both
			// channels, alive for up to a second after ctx is cancelled.
			select {
			case <-time.After(1 * time.Second):
			case <-ctx.Done():
				return
			}
		}
	}()
	return out, errc, nil
}
// runPipelineWithTimeout wires up an endless pipeline (randomNumberSource ->
// lineParser in base 10 -> sink) and cancels its context after 10 seconds.
//
// It returns the first error reported, either while setting up a stage or
// by a running stage, or nil once all stages have shut down without
// reporting one.
func runPipelineWithTimeout() error {
	fmt.Printf("runPipelineWithTimeout\n")

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Source stage.
	lineCh, sourceErrs, err := randomNumberSource(ctx, 3)
	if err != nil {
		return err
	}

	// Transformer stage.
	numberCh, parserErrs, err := lineParser(ctx, 10, lineCh)
	if err != nil {
		return err
	}

	// Sink stage.
	sinkErrs, err := sink(ctx, numberCh)
	if err != nil {
		return err
	}

	fmt.Println("Pipeline started. Waiting for pipeline to complete.")

	// Cancel the pipeline from a separate goroutine after 10 seconds.
	stopLater := func() {
		time.Sleep(10 * time.Second)
		fmt.Println("Cancelling context.")
		cancel()
	}
	go stopLater()

	return WaitForPipeline(sourceErrs, parserErrs, sinkErrs)
}
// main runs each demo pipeline in turn and prints any error it returns.
// The simple pipeline is exercised with a valid run, an invalid base, a
// binary run, an unparsable line, and a value the sink rejects; then the
// complex pipeline and the timed pipeline are run once each.
func main() {
	// report prints a pipeline's error, if it returned one.
	report := func(err error) {
		if err != nil {
			fmt.Println(err)
		}
	}

	simpleRuns := []struct {
		base  int
		lines []string
	}{
		{10, []string{"3", "2", "1"}},
		{1, []string{"3", "2", "1"}},
		{2, []string{"1010", "1100", "1000"}},
		{2, []string{"1010", "1100", "2000", "1111"}},
		{10, []string{"1", "10", "100", "1000"}},
	}
	for _, run := range simpleRuns {
		report(runSimplePipeline(run.base, run.lines))
	}

	report(runComplexPipeline(10, []string{"5", "4", "3"}))
	report(runPipelineWithTimeout())
}
@ymgyt
Copy link

ymgyt commented Oct 5, 2017

Thanks nice demo !

@claudiofahey
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment