Stream files from an S3 bucket
package main

import (
	"context"
	"fmt"
	"log"
	"net"
	"net/http"
	"os"
	"sync"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3manager"
)
// s3Provider lists the objects in a bucket and streams their bytes, in order,
// onto dataStream.
type s3Provider struct {
	bucketRegion string
	bucketName   string
	dataStream   chan []byte

	s3Session  *session.Session
	httpClient *http.Client

	// downloadCancelFunc cancels the download that is currently in flight.
	downloadCancelFunc          context.CancelFunc
	downloadCancelFuncUpdateMux *sync.Mutex

	stopChan    chan bool
	stopAckChan chan bool
	stopped     bool
}
func newS3Provider(bucketName string, bucketRegion string) *s3Provider {
	hc := &http.Client{
		Transport: &http.Transport{
			Proxy: http.ProxyFromEnvironment,
			DialContext: (&net.Dialer{
				Timeout:   30 * time.Second,
				KeepAlive: 30 * time.Second,
			}).DialContext,
			IdleConnTimeout:       90 * time.Second,
			TLSHandshakeTimeout:   3 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
		},
	}
	creds := credentials.NewEnvCredentials()
	sess, err := session.NewSession(&aws.Config{
		Region:      aws.String(bucketRegion),
		Credentials: creds,
		HTTPClient:  hc,
	})
	if err != nil {
		log.Fatalf("could not create AWS session: %v", err)
	}
	return &s3Provider{
		bucketName:   bucketName,
		bucketRegion: bucketRegion,
		dataStream:   make(chan []byte),
		// aws related
		httpClient: hc,
		s3Session:  sess,
		// for being able to stop the download if this is requested
		downloadCancelFuncUpdateMux: &sync.Mutex{},
		stopChan:                    make(chan bool),
		stopAckChan:                 make(chan bool),
	}
}
func (s *s3Provider) Start() error {
	defer func() {
		close(s.dataStream)
		close(s.stopChan)
		s.stopped = true
		close(s.stopAckChan)
	}()
	// list the objects in the bucket
	files, err := s.getFiles()
	if err != nil {
		return err
	}
	for _, currentFile := range files {
		if s.stopped {
			return nil
		}
		downloader := s3manager.NewDownloader(s.s3Session)
		// This is the catch for streaming the files' content. If you want to process
		// the bytes as they arrive, set Concurrency to 1; otherwise the multipart
		// download kicks in and the byte order is no longer guaranteed.
		downloader.Concurrency = 1
		// prepare for cancellation if it is requested
		ctx, cancelFunc := context.WithCancel(aws.BackgroundContext())
		s.setCancelFunc(cancelFunc)
		inMemoryBuf := &inMemoryBuffer{
			dataStream: s.dataStream,
			mux:        &sync.Mutex{},
		}
		_, err := downloader.DownloadWithContext(ctx,
			inMemoryBuf,
			&s3.GetObjectInput{
				Bucket: aws.String(s.bucketName),
				Key:    currentFile.Key,
			},
		)
		if err != nil {
			log.Printf("could not download %s: %v", aws.StringValue(currentFile.Key), err)
			continue
		}
	}
	return nil
}
func (s *s3Provider) setCancelFunc(cFunc context.CancelFunc) {
	s.downloadCancelFuncUpdateMux.Lock()
	defer s.downloadCancelFuncUpdateMux.Unlock()
	s.downloadCancelFunc = cFunc
}

func (s *s3Provider) callCancelFunc() {
	s.downloadCancelFuncUpdateMux.Lock()
	defer s.downloadCancelFuncUpdateMux.Unlock()
	if s.downloadCancelFunc != nil {
		s.downloadCancelFunc()
	}
	s.stopped = true
}
func (s *s3Provider) getFiles() ([]*s3.Object, error) {
	svc := s3.New(s.s3Session)
	results := []*s3.Object{}
	err := svc.ListObjectsV2Pages(&s3.ListObjectsV2Input{
		Bucket: aws.String(s.bucketName),
		// Prefix: aws.String(), // use this if the files are not in the root of the bucket
	}, func(res *s3.ListObjectsV2Output, lastPage bool) bool {
		results = append(results, res.Contents...)
		// keep paginating until the last page; each call returns at most 1k entries
		return !lastPage
	})
	return results, err
}
func (s *s3Provider) Stop() {
	s.callCancelFunc()
}

func (s *s3Provider) DataStream() <-chan []byte {
	return s.dataStream
}
// inMemoryBuffer satisfies io.WriterAt so it can be handed to the s3manager
// downloader; instead of buffering, it forwards every chunk to dataStream.
type inMemoryBuffer struct {
	mux        *sync.Mutex
	dataStream chan []byte
}

func (c *inMemoryBuffer) WriteAt(p []byte, pos int64) (n int, err error) {
	c.mux.Lock()
	defer c.mux.Unlock()
	// With Concurrency set to 1 the chunks arrive in order, so pos can be ignored.
	// Copy the chunk before sending it: the downloader may reuse p after WriteAt returns.
	chunk := make([]byte, len(p))
	copy(chunk, p)
	c.dataStream <- chunk
	return len(p), nil
}
// Requirements:
// Env variables:
//   AWS_ACCESS_KEY_ID (or AWS_ACCESS_KEY)
//   AWS_SECRET_ACCESS_KEY (or AWS_SECRET_KEY)
// Arguments:
//   1) bucketName
//   2) bucketRegion
func main() {
	if len(os.Args) < 3 {
		log.Fatal("You have to pass the required arguments. Usage: go run main.go <bucketName> <bucketRegion>")
	}
	bucketName := os.Args[1]
	bucketRegion := os.Args[2]
	service := newS3Provider(bucketName, bucketRegion)
	go func() {
		if err := service.Start(); err != nil {
			log.Printf("streaming stopped with an error: %v", err)
		}
	}()
	for data := range service.DataStream() {
		fmt.Printf("%s\n", data) // do whatever you want with the data as it comes
	}
}
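Usage sketch (not part of the original gist): the same DataStream can feed a file on disk instead of stdout. The helper below is a hypothetical addition; its name and the output path are illustrative, it relies only on packages the file already imports, and it keeps draining the channel after a write error so the producer goroutine is never left blocked on an unread send.

// streamToFile is an illustrative consumer: it copies everything coming out of
// the provider's DataStream into a local file.
func streamToFile(service *s3Provider, path string) error {
	out, err := os.Create(path)
	if err != nil {
		return err
	}
	defer out.Close()
	go func() {
		if err := service.Start(); err != nil {
			log.Printf("streaming stopped with an error: %v", err)
		}
	}()
	var writeErr error
	for chunk := range service.DataStream() {
		if writeErr != nil {
			continue // keep draining so the producer never blocks on an unread send
		}
		if _, err := out.Write(chunk); err != nil {
			writeErr = err
			service.Stop() // cancel the remaining downloads
		}
	}
	return writeErr
}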