Skip to content

Instantly share code, notes, and snippets.

@yottta
Last active January 13, 2019 16:46
Show Gist options
  • Save yottta/bb546a3719bc04ffe0072b68a83ec446 to your computer and use it in GitHub Desktop.
Save yottta/bb546a3719bc04ffe0072b68a83ec446 to your computer and use it in GitHub Desktop.
Stream files from an S3 bucket
package main
import (
"context"
"fmt"
"log"
"net"
"net/http"
"os"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
)
// s3Provider streams the content of the objects stored in an S3 bucket
// through a channel of byte chunks.
type s3Provider struct {
	// Bucket coordinates.
	bucketName   string
	bucketRegion string

	// AWS plumbing: the session used to build S3 clients/downloaders and the
	// HTTP client that session was configured with.
	s3Session  *session.Session
	httpClient *http.Client

	// dataStream carries the downloaded chunks to the consumer.
	dataStream chan []byte

	// downloadCancelFunc cancels the download currently in flight; it is
	// guarded by downloadCancelFuncUpdateMux.
	downloadCancelFunc          context.CancelFunc
	downloadCancelFuncUpdateMux *sync.Mutex

	// stopped is set once a stop was requested or Start has returned.
	stopped bool

	// stopChan and stopAckChan are created by the constructor and closed when
	// Start returns; they are not otherwise used in the code visible here.
	stopChan    chan bool
	stopAckChan chan bool
}
// newS3Provider builds a provider for the given bucket and region.
//
// Credentials are read from the environment (credentials.NewEnvCredentials)
// and all AWS traffic goes through a dedicated HTTP client with explicit dial,
// TLS-handshake and idle-connection timeouts.
//
// It panics if the AWS session cannot be created: a provider without a session
// cannot do anything useful and would otherwise fail later with a nil pointer
// dereference that hides the real cause.
func newS3Provider(bucketName string, bucketRegion string) *s3Provider {
	transport := &http.Transport{
		Proxy: http.ProxyFromEnvironment,
		DialContext: (&net.Dialer{
			Timeout:   30 * time.Second,
			KeepAlive: 30 * time.Second,
		}).DialContext,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   3 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}
	// No overall http.Client.Timeout on purpose: it would also bound the time
	// allowed to read a response body, cutting long streaming downloads short.
	hc := &http.Client{Transport: transport}

	// session.Must panics with the underlying error when the session cannot
	// be created, instead of silently discarding that error.
	sess := session.Must(session.NewSession(&aws.Config{
		Region:      aws.String(bucketRegion),
		Credentials: credentials.NewEnvCredentials(),
		HTTPClient:  hc,
	}))

	return &s3Provider{
		bucketName:   bucketName,
		bucketRegion: bucketRegion,
		dataStream:   make(chan []byte),
		// aws related
		httpClient: hc,
		s3Session:  sess,
		// for being able to stop the download if this is requested
		downloadCancelFuncUpdateMux: &sync.Mutex{},
		stopChan:                    make(chan bool),
		stopAckChan:                 make(chan bool),
	}
}
// Start lists the objects of the bucket and downloads them one after another,
// publishing every received chunk on the data stream.
//
// It blocks until all objects were processed, a stop was requested, or the
// listing failed (in which case the listing error is returned). A failure to
// download a single object is logged and the next object is tried. When Start
// returns, the data stream and the stop channels are closed and the provider
// is marked as stopped, so Start must be called at most once per provider.
func (s *s3Provider) Start() error {
	defer func() {
		close(s.dataStream)
		close(s.stopChan)
		// stopped is shared with callCancelFunc, which accesses it under the
		// same mutex.
		s.downloadCancelFuncUpdateMux.Lock()
		s.stopped = true
		s.downloadCancelFuncUpdateMux.Unlock()
		close(s.stopAckChan)
	}()

	// list s3 files
	files, err := s.getFiles()
	if err != nil {
		return err
	}

	downloader := s3manager.NewDownloader(s.s3Session)
	// This is the catch for streaming files' content. If you want to process as it comes, set Concurrency to 1.
	// Otherwise the multipart download will come into action removing the guarantee of the bytes order
	downloader.Concurrency = 1

	for _, currentFile := range files {
		ctx, cancel := context.WithCancel(aws.BackgroundContext())

		// Check the stop flag and publish the cancel function in a single
		// critical section: a concurrent stop request either is observed here
		// or finds the cancel function of this very download.
		s.downloadCancelFuncUpdateMux.Lock()
		if s.stopped {
			s.downloadCancelFuncUpdateMux.Unlock()
			cancel()
			return nil
		}
		s.downloadCancelFunc = cancel
		s.downloadCancelFuncUpdateMux.Unlock()

		inMemoryBuf := &inMemoryBuffer{
			dataStream: s.dataStream,
			mux:        &sync.Mutex{},
		}
		_, dlErr := downloader.DownloadWithContext(ctx,
			inMemoryBuf,
			&s3.GetObjectInput{
				Bucket: aws.String(s.bucketName),
				Key:    currentFile.Key,
			},
		)

		// The context can only have been cancelled by a stop request at this
		// point, because our own cancel() call comes after this check.
		stopRequested := ctx.Err() != nil
		// Always release the resources attached to the per-file context.
		cancel()
		if stopRequested {
			return nil
		}
		if dlErr != nil {
			log.Printf("downloading %q from bucket %q: %v", aws.StringValue(currentFile.Key), s.bucketName, dlErr)
			continue
		}
	}
	return nil
}
// setCancelFunc stores cFunc as the function that cancels the download
// currently in flight. The assignment happens while holding the cancel-func
// mutex.
func (s *s3Provider) setCancelFunc(cFunc context.CancelFunc) {
	mu := s.downloadCancelFuncUpdateMux
	mu.Lock()
	s.downloadCancelFunc = cFunc
	mu.Unlock()
}
// callCancelFunc invokes the stored cancel function, if one has been set, and
// then flags the provider as stopped. Both steps run while holding the
// cancel-func mutex.
func (s *s3Provider) callCancelFunc() {
	mu := s.downloadCancelFuncUpdateMux
	mu.Lock()
	defer mu.Unlock()

	if cancel := s.downloadCancelFunc; cancel != nil {
		cancel()
	}
	s.stopped = true
}
// getFiles returns every object of the bucket, following the pagination of
// ListObjectsV2 until the last page.
//
// The returned slice is never nil. On error it holds whatever was collected
// before the failure, together with the error reported by the SDK.
func (s *s3Provider) getFiles() ([]*s3.Object, error) {
	svc := s3.New(s.s3Session)
	results := []*s3.Object{}
	err := svc.ListObjectsV2Pages(&s3.ListObjectsV2Input{
		Bucket: aws.String(s.bucketName),
		// Prefix: aws.String(), // use this if the files are not in root of the bucket
	}, func(page *s3.ListObjectsV2Output, lastPage bool) bool {
		results = append(results, page.Contents...)
		// Keep paging until the SDK reports the last page. The size of a page
		// is not a reliable end marker: a truncated page may hold fewer than
		// 1000 keys.
		return !lastPage
	})
	return results, err
}
// Stop requests the provider to stop: it cancels the download in flight (if a
// cancel function has been registered) and marks the provider as stopped, a
// flag that Start checks before beginning each file.
func (s *s3Provider) Stop() {
s.callCancelFunc()
}
// DataStream returns the receive-only side of the channel on which the
// downloaded chunks are published. Start closes this channel when it returns,
// so ranging over it terminates once streaming is over.
func (s *s3Provider) DataStream() <-chan []byte {
return s.dataStream
}
// inMemoryBuffer is an io.WriterAt that, instead of storing what is written to
// it, forwards every written chunk to a channel.
type inMemoryBuffer struct {
	mux        *sync.Mutex   // serializes writes so chunks are forwarded one at a time
	dataStream chan []byte // destination of the forwarded chunks
}

// WriteAt forwards a copy of p to the data stream and reports len(p) bytes as
// written. It blocks until the chunk has been accepted by the channel.
//
// pos is ignored: chunks are forwarded in the order WriteAt is called, so the
// writer must issue its writes sequentially for the stream to be ordered.
//
// p is copied before being sent because an io.WriterAt must not retain p: the
// caller is free to reuse the buffer as soon as WriteAt returns, which would
// otherwise corrupt the chunk while the consumer is still reading it.
func (c *inMemoryBuffer) WriteAt(p []byte, pos int64) (n int, err error) {
	chunk := make([]byte, len(p))
	copy(chunk, p)

	c.mux.Lock()
	defer c.mux.Unlock()
	c.dataStream <- chunk
	return len(p), nil
}
// main streams the content of every object of an S3 bucket to stdout.
//
// Requirements:
// Env variables (read by credentials.NewEnvCredentials):
// AWS_SECRET_KEY (or AWS_SECRET_ACCESS_KEY)
// AWS_ACCESS_KEY (or AWS_ACCESS_KEY_ID)
// Arguments:
// 1) bucketName
// 2) bucketRegion
func main() {
	if len(os.Args) < 3 {
		// log.Fatal already exits with status 1.
		log.Fatal("You have to pass the required arguments. Usage: go run main.go <bucketName> <bucketRegion>")
	}
	bucketName := os.Args[1]
	bucketRegion := os.Args[2]
	service := newS3Provider(bucketName, bucketRegion)

	// Run the download in the background and keep its result, so a failure
	// does not go unnoticed. The channel is buffered so the goroutine can
	// always finish.
	startErr := make(chan error, 1)
	go func() {
		startErr <- service.Start()
	}()

	for data := range service.DataStream() {
		fmt.Printf("%s\n", data) // do whatever you want with the data as it comes
	}

	// The data stream is closed when Start returns, so its result arrives
	// right after the loop above ends.
	if err := <-startErr; err != nil {
		log.Fatalf("streaming bucket %q: %v", bucketName, err)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment