Skip to content

Instantly share code, notes, and snippets.

@nmvuong92
Last active September 26, 2024 12:09
Show Gist options
  • Save nmvuong92/8bf7dc6962f3034da2c09a03be4ba19c to your computer and use it in GitHub Desktop.
Save nmvuong92/8bf7dc6962f3034da2c09a03be4ba19c to your computer and use it in GitHub Desktop.
golang s3 download with progressbar
package awss3
import (
"context"
"crypto/tls"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/schollz/progressbar/v3"
)
// S3Bucket identifies one S3 bucket and holds the AWS session used to reach
// it. Download resolves object names relative to PathPrefix.
type S3Bucket struct {
	BucketName, PathPrefix, Region string

	session *session.Session // built by NewS3Bucket for Region
}
// NewS3Bucket returns an S3Bucket for bucketName backed by a new AWS session
// configured for region. pathPrefix is stored unchanged; Download joins it in
// front of the object names it is given.
func NewS3Bucket(bucketName, pathPrefix, region string) (*S3Bucket, error) {
	cfg := &aws.Config{Region: aws.String(region)}
	sess, err := session.NewSession(cfg)
	if err != nil {
		return nil, fmt.Errorf("failed to create session: %w", err)
	}
	b := &S3Bucket{session: sess}
	b.BucketName, b.PathPrefix, b.Region = bucketName, pathPrefix, region
	return b, nil
}
// Download fetches s3filename (joined onto PathPrefix to form the object key)
// from the bucket and writes it to destPath, replacing any file already there.
// A byte progress bar is rendered while the transfer runs.
//
// The object size is read with a HEAD request so the bar can show a total; if
// S3 reports no length the bar runs without one. The transfer uses s3manager
// with 64 MiB parts and a concurrency of 6. If any step after the file is
// created fails, the partially written destination file is removed and the
// error is returned.
func (s *S3Bucket) Download(ctx context.Context, s3filename, destPath string) (err error) {
	// S3 keys are always '/'-separated; ToSlash undoes the '\' separators
	// filepath.Join produces on Windows (it is a no-op elsewhere).
	key := filepath.ToSlash(filepath.Join(s.PathPrefix, s3filename))

	// HEAD is enough to learn the size; a GET here would open a response body
	// that must be drained and closed or it leaks the connection.
	head, err := s3.New(s.session).HeadObjectWithContext(ctx, &s3.HeadObjectInput{
		Bucket: aws.String(s.BucketName),
		Key:    aws.String(key),
	})
	if err != nil {
		return fmt.Errorf("failed to get S3 object: %w", err)
	}
	size := int64(-1) // -1 makes progressbar.DefaultBytes run without a total
	if head.ContentLength != nil {
		size = *head.ContentLength
	}

	// Clear any existing file. os.Remove, unlike os.RemoveAll, will not
	// recursively delete a directory tree that happens to live at destPath.
	if rmErr := os.Remove(destPath); rmErr != nil && !os.IsNotExist(rmErr) {
		return fmt.Errorf("failed to remove existing file: %w", rmErr)
	}
	destFile, err := os.Create(destPath)
	if err != nil {
		return fmt.Errorf("failed to create destination file: %w", err)
	}
	defer func() {
		// A failed Close on a written file can mean lost data, so surface it.
		if cerr := destFile.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("failed to close destination file: %w", cerr)
		}
		// Don't leave a truncated file that could pass for a complete download.
		if err != nil {
			_ = os.Remove(destPath)
		}
	}()

	bar := progressbar.DefaultBytes(size, "[info] downloading s3: "+s3filename)

	// Poll the file size while the transfer runs. The poller gets its own
	// cancellable context so it stops as soon as the download returns, and we
	// wait for it so no goroutine outlives this call.
	progressCtx, stopProgress := context.WithCancel(ctx)
	progressDone := make(chan struct{})
	go func() {
		defer close(progressDone)
		s.updateProgressBar(progressCtx, bar, destPath)
	}()

	downloader := s3manager.NewDownloader(s.session, func(d *s3manager.Downloader) {
		d.PartSize = 64 * 1024 * 1024
		d.Concurrency = 6
	})
	_, err = downloader.DownloadWithContext(ctx, destFile, &s3.GetObjectInput{
		Bucket: aws.String(s.BucketName),
		Key:    aws.String(key),
	})

	stopProgress()
	<-progressDone

	if err != nil {
		return fmt.Errorf("failed to download file: %w", err)
	}
	// The poller samples on a ticker and may have missed the last bytes.
	// A rendering failure is cosmetic and must not fail a finished download.
	_ = bar.Finish()
	return nil
}
// updateProgressBar sets bar to the current size of filePath every 200ms
// until ctx is done. Stat failures (e.g. the file not existing yet) are
// skipped and retried on the next tick, and Set64 errors are ignored: the
// bar is purely informational.
func (s *S3Bucket) updateProgressBar(ctx context.Context, bar *progressbar.ProgressBar, filePath string) {
	tick := time.NewTicker(200 * time.Millisecond)
	defer tick.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-tick.C:
		}
		if fi, statErr := os.Stat(filePath); statErr == nil {
			_ = bar.Set64(fi.Size())
		}
	}
}
// getObjectHTTPClient is shared by all GetObject calls so connections are
// pooled and reused, instead of each call building its own Transport whose
// idle connections are never reclaimed.
var getObjectHTTPClient = &http.Client{Transport: newGetObjectTransport()}

// newGetObjectTransport clones http.DefaultTransport (proxy from environment,
// dial/handshake timeouts, idle-connection limits) and sets a TLS config that
// keeps certificate verification on with a TLS 1.2 minimum.
func newGetObjectTransport() *http.Transport {
	var t *http.Transport
	if def, ok := http.DefaultTransport.(*http.Transport); ok {
		t = def.Clone()
	} else {
		// DefaultTransport was replaced by some other RoundTripper; build a
		// minimal equivalent rather than panicking on the type assertion.
		t = &http.Transport{Proxy: http.ProxyFromEnvironment}
	}
	t.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	return t
}

// GetObject issues a GET for objectKey (used as-is, without PathPrefix) in
// the bucket and returns the SDK response. The caller owns result.Body and
// must close it to release the underlying connection.
//
// Certificate verification is enforced: the previous InsecureSkipVerify
// setting allowed any man-in-the-middle to impersonate the S3 endpoint.
func (s *S3Bucket) GetObject(ctx context.Context, objectKey string) (*s3.GetObjectOutput, error) {
	svc := s3.New(s.session, &aws.Config{
		HTTPClient: getObjectHTTPClient,
	})
	result, err := svc.GetObjectWithContext(ctx, &s3.GetObjectInput{
		Bucket: aws.String(s.BucketName),
		Key:    aws.String(objectKey),
	})
	if err != nil {
		return nil, fmt.Errorf("failed to get object: %w", err)
	}
	return result, nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment