Last active
September 26, 2024 12:09
-
-
Save nmvuong92/8bf7dc6962f3034da2c09a03be4ba19c to your computer and use it in GitHub Desktop.
Go: download an S3 object with a progress bar
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package awss3 | |
import ( | |
"context" | |
"crypto/tls" | |
"fmt" | |
"io" | |
"net/http" | |
"os" | |
"path/filepath" | |
"time" | |
"github.com/aws/aws-sdk-go/aws" | |
"github.com/aws/aws-sdk-go/aws/session" | |
"github.com/aws/aws-sdk-go/service/s3" | |
"github.com/aws/aws-sdk-go/service/s3/s3manager" | |
"github.com/schollz/progressbar/v3" | |
) | |
type S3Bucket struct { | |
BucketName string | |
PathPrefix string | |
Region string | |
session *session.Session | |
} | |
func NewS3Bucket(bucketName, pathPrefix, region string) (*S3Bucket, error) { | |
sess, err := session.NewSession(&aws.Config{ | |
Region: aws.String(region), | |
}) | |
if err != nil { | |
return nil, fmt.Errorf("failed to create session: %w", err) | |
} | |
return &S3Bucket{ | |
BucketName: bucketName, | |
PathPrefix: pathPrefix, | |
Region: region, | |
session: sess, | |
}, nil | |
} | |
func (s *S3Bucket) Download(ctx context.Context, s3filename, destPath string) error { | |
if err := os.RemoveAll(destPath); err != nil { | |
return fmt.Errorf("failed to remove existing file: %w", err) | |
} | |
destFile, err := os.Create(destPath) | |
if err != nil { | |
return fmt.Errorf("failed to create destination file: %w", err) | |
} | |
defer destFile.Close() | |
s3object, err := s.GetObject(ctx, filepath.Join(s.PathPrefix, s3filename)) | |
if err != nil { | |
return fmt.Errorf("failed to get S3 object: %w", err) | |
} | |
bar := progressbar.DefaultBytes( | |
*s3object.ContentLength, | |
"[info] downloading s3: "+s3filename, | |
) | |
downloader := s3manager.NewDownloader(s.session, func(d *s3manager.Downloader) { | |
d.PartSize = 64 * 1024 * 1024 | |
d.Concurrency = 6 | |
}) | |
_, err = downloader.DownloadWithContext(ctx, destFile, | |
&s3.GetObjectInput{ | |
Bucket: aws.String(s.BucketName), | |
Key: aws.String(filepath.Join(s.PathPrefix, s3filename)), | |
}) | |
if err != nil { | |
return fmt.Errorf("failed to download file: %w", err) | |
} | |
go s.updateProgressBar(ctx, bar, destPath) | |
return nil | |
} | |
func (s *S3Bucket) updateProgressBar(ctx context.Context, bar *progressbar.ProgressBar, filePath string) { | |
ticker := time.NewTicker(200 * time.Millisecond) | |
defer ticker.Stop() | |
for { | |
select { | |
case <-ctx.Done(): | |
return | |
case <-ticker.C: | |
info, err := os.Stat(filePath) | |
if err != nil { | |
continue | |
} | |
_ = bar.Set64(info.Size()) | |
} | |
} | |
} | |
func (s *S3Bucket) GetObject(ctx context.Context, objectKey string) (*s3.GetObjectOutput, error) { | |
client := &http.Client{ | |
Transport: &http.Transport{ | |
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, | |
}, | |
} | |
svc := s3.New(s.session, &aws.Config{ | |
HTTPClient: client, | |
}) | |
result, err := svc.GetObjectWithContext(ctx, &s3.GetObjectInput{ | |
Bucket: aws.String(s.BucketName), | |
Key: aws.String(objectKey), | |
}) | |
if err != nil { | |
return nil, fmt.Errorf("failed to get object: %w", err) | |
} | |
return result, nil | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment