Skip to content

Instantly share code, notes, and snippets.

@moonsub-kim
Last active March 10, 2022 05:33
Show Gist options
  • Save moonsub-kim/9fb3eb3b72f2cafff51e6fa94ccaaaf3 to your computer and use it in GitHub Desktop.
Save moonsub-kim/9fb3eb3b72f2cafff51e6fa94ccaaaf3 to your computer and use it in GitHub Desktop.
Refresh an AWS STS token in the background
package main
// Environment variables read by main to configure the AWS session.
const (
	// envAWSEndpoint names the variable whose value is passed to aws.Config.Endpoint.
	envAWSEndpoint = "AWS_ENDPOINT"
	// envAWSRegion names the variable whose value is passed to aws.Config.Region.
	envAWSRegion = "AWS_REGION"
)
// main builds a traced DynamoDB client. When the web-identity provider can be
// created, its credentials are wrapped in background-refreshing credentials;
// otherwise (e.g. a local environment without the web-identity env vars) the
// session keeps its default credentials.
//
// NOTE(review): illustrative snippet — serviceName, dynamoMaxRetry, d.logger
// and the httptrace/credentials helper packages are defined elsewhere.
func main() {
	endPoint := os.Getenv(envAWSEndpoint)
	region := os.Getenv(envAWSRegion)

	// https://gist.github.com/moonsub-kim/6246802d85e16d56609da06af4083065#file-clienttrace_injector-go
	c := httptrace.WrapClient(
		&http.Client{},
		serviceName,
		ddhttptrace.WithBefore(func(r *http.Request, span ddtrace.Span) {
			span.SetTag("url", r.URL)
		}),
	)
	config := &aws.Config{
		MaxRetries: aws.Int(dynamoMaxRetry),
		Endpoint:   aws.String(endPoint),
		Region:     aws.String(region),
		LogLevel:   aws.LogLevel(aws.LogOff),
		HTTPClient: c,
	}

	// Named sess (not session) so the variable does not shadow the session package.
	sess, err := session.NewSession(config)
	if err != nil {
		d.logger.Fatal("Failed to create aws session", zap.Error(err))
	}

	p, err := credentials.NewWebIdentityRoleProvider(sess)
	if err != nil {
		// Web-identity provider unavailable (e.g. env vars missing locally):
		// fall back to the session's default credentials.
		d.logger.Warn("Use default credentials provider", zap.Error(err)) // local environment
	} else {
		cred := credentials.NewRefreshCredentials(d.logger, p, p, serviceName)
		sess.Config.WithCredentials(cred)
	}

	sess = awstrace.WrapSession(sess, awstrace.WithServiceName(serviceName))
	dynamoDB := dynamodb.New(sess, sess.Config)
	_ = dynamoDB // placeholder: hand the client to the rest of the application
}
package main
import (
"context"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
"go.uber.org/zap"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
)
// tickDuration is the period of the ticker that drives
// RefreshProvider.periodicRefresh.
const tickDuration = 10 * time.Minute
// Expirer is implemented by anything holding a cached credentials.Value that
// can be invalidated on demand; *credentials.Credentials satisfies it and is
// what NewRefreshCredentials plugs in.
type Expirer interface {
	// Expire marks the cached value as stale.
	Expire()
}
// RefreshProvider wraps a credentials.ProviderWithContext and re-fetches its
// credentials on every Ticker tick in a background goroutine, so Retrieve
// normally returns a cached value instead of blocking on the wrapped provider.
//
// Based on https://github.com/aws/aws-sdk-go/issues/561#issuecomment-185974563
type RefreshProvider struct {
	// The wrapped provider that actually retrieves credentials.
	credentials.ProviderWithContext

	// Ticker drives periodicRefresh.
	Ticker *time.Ticker

	// creds and err hold the latest retrieval result. After the first
	// retrieval they are read under mux.RLock and written under mux.Lock.
	creds credentials.Value
	err   error
	mux   sync.RWMutex

	// initRunner guards the one-time initial retrieval and goroutine start.
	initRunner sync.Once

	// Expirer is expired after each successful background refresh.
	Expirer Expirer

	logger      *zap.Logger
	serviceName string // tagged on the tracing span in RetrieveWithContext
}
// Retrieve implements credentials.Provider by delegating to
// RetrieveWithContext with a background context, so it also performs the
// one-time initialization on first use.
func (p *RefreshProvider) Retrieve() (credentials.Value, error) {
	ctx := context.Background()
	return p.RetrieveWithContext(ctx)
}
// RetrieveWithContext implements credentials.ProviderWithContext.
//
// The first call fetches credentials from the wrapped provider and starts the
// background refresh goroutine; every call returns the most recently cached
// credentials and error. Each call is recorded as a tracing span tagged with
// the provider's service name.
func (p *RefreshProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
	span, spanCtx := tracer.StartSpanFromContext(ctx, "RefreshProvider.RetrieveWithContext")
	defer span.Finish()
	span.SetTag(ext.ServiceName, p.serviceName)

	p.initRunner.Do(func() {
		// The refresh goroutine is started only after this write, and other
		// callers wait on the Once, so no lock is needed here.
		p.creds, p.err = p.ProviderWithContext.RetrieveWithContext(spanCtx)
		go p.periodicRefresh()
	})

	p.mux.RLock()
	defer p.mux.RUnlock()
	creds, err := p.creds, p.err
	return creds, err
}
// IsExpired reports the wrapped provider's own expiry state. It takes the
// read lock so it never overlaps refresh, which calls the wrapped provider's
// Retrieve under the write lock (presumably updating its expiry).
func (p *RefreshProvider) IsExpired() bool {
	p.mux.RLock()
	defer p.mux.RUnlock()
	expired := p.ProviderWithContext.IsExpired()
	return expired
}
// periodicRefresh runs in the goroutine started by RetrieveWithContext and
// refreshes credentials on every tick of p.Ticker.
//
// A time.Ticker's channel is never closed (Stop does not close it), so this
// loop runs for the life of the process. Panics are recovered per tick in
// refreshTick, so one failing refresh does not stop all future refreshes.
func (p *RefreshProvider) periodicRefresh() {
	defer p.Ticker.Stop()
	for range p.Ticker.C {
		p.refreshTick()
	}
}

// refreshTick performs a single refresh and, if it replaced the cached
// credentials, expires p.Expirer. A panic is recovered and logged.
func (p *RefreshProvider) refreshTick() {
	defer func() {
		if v := recover(); v != nil {
			p.logger.Error("periodicRefresh recovered", zap.Any("obj", v))
		}
	}()
	if !p.refresh() {
		return
	}
	// Expire() must be called with p.mux unlocked (refresh has already
	// released it): credentials.Credentials locks its own mutex and then calls
	// Retrieve(), which takes p.mux, so holding p.mux here would deadlock.
	if p.Expirer != nil {
		p.Expirer.Expire()
	}
}
// refresh fetches credentials from the wrapped provider while holding the
// write lock. It reports whether the cached credentials were replaced, which
// tells the caller whether downstream caches should be expired.
//
// On failure the previously cached credentials are always kept:
//   - if the wrapped provider reports IsExpired, the error is cached so
//     RetrieveWithContext returns it, and false is returned;
//   - otherwise a warning is logged and false is returned, so callers keep
//     using the still-valid cached credentials.
func (p *RefreshProvider) refresh() bool {
	p.mux.Lock()
	defer p.mux.Unlock()

	creds, err := p.ProviderWithContext.Retrieve()
	if err != nil {
		if p.ProviderWithContext.IsExpired() {
			p.err = err
			p.logger.Error("Failed to refresh credentials", zap.Error(err))
			return false
		}
		// The creds returned alongside err are not usable; do not overwrite
		// the cached ones, which are still valid.
		p.logger.Warn("Failed to refresh credentials. But current credentials are not yet expired", zap.Error(err))
		return false
	}
	p.err = nil
	p.creds = creds
	return true
}
// NewRefreshCredentials wraps provider in a RefreshProvider and returns
// credentials.Credentials backed by it.
//
// Credentials are retrieved once immediately. They are then re-retrieved
// every tickDuration in a background goroutine. After each successful refresh
// the returned Credentials is expired, so it picks up the new value.
// serviceName tags the tracing spans created by RetrieveWithContext.
//
// NOTE(review): expirer is currently unused; it is kept for API compatibility.
func NewRefreshCredentials(logger *zap.Logger, provider credentials.ProviderWithContext, expirer credentials.Expirer, serviceName string) *credentials.Credentials {
	rp := &RefreshProvider{
		ProviderWithContext: provider,
		Ticker:              time.NewTicker(tickDuration),
		logger:              logger,
		serviceName:         serviceName,
	}
	creds := credentials.NewCredentials(rp)
	// Assign Expirer before the first Retrieve: Retrieve starts the refresh
	// goroutine, which reads rp.Expirer without synchronization.
	rp.Expirer = creds

	// Prime the cache and start the refresh goroutine. A failure is cached in
	// rp and returned from Credentials.Get, so it is only logged here.
	if _, err := rp.Retrieve(); err != nil {
		logger.Warn("Initial credentials retrieval failed", zap.Error(err))
	}
	return creds
}
package main
import (
"fmt"
"os"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/service/sts"
)
// Settings used by NewWebIdentityRoleProvider.
const (
	// expiryWindow is assigned to WebIdentityRoleProvider.ExpiryWindow.
	expiryWindow = 10 * time.Minute

	// Environment variables naming the role, its session name (optional) and
	// the web identity token file.
	envAWSRoleARN              = "AWS_ROLE_ARN"
	envAWSRoleSessionName      = "AWS_ROLE_SESSION_NAME"
	envAWSWebIdentityTokenFile = "AWS_WEB_IDENTITY_TOKEN_FILE"
)
// NewWebIdentityRoleProvider builds an stscreds.WebIdentityRoleProvider from
// the AWS_ROLE_ARN, AWS_WEB_IDENTITY_TOKEN_FILE and (optional)
// AWS_ROLE_SESSION_NAME environment variables, creating the STS client from c.
//
// It performs one Retrieve up front so misconfiguration is reported
// immediately. An InvalidIdentityToken error from STS is tolerated because it
// can be temporary; any other retrieval error is returned.
//
// It returns an error when a required environment variable is unset or empty.
func NewWebIdentityRoleProvider(c client.ConfigProvider) (*stscreds.WebIdentityRoleProvider, error) {
	awsRoleArn := os.Getenv(envAWSRoleARN)
	if awsRoleArn == "" {
		return nil, fmt.Errorf("empty env %s", envAWSRoleARN)
	}
	awsWebIdentityTokenFile := os.Getenv(envAWSWebIdentityTokenFile)
	if awsWebIdentityTokenFile == "" {
		return nil, fmt.Errorf("empty env %s", envAWSWebIdentityTokenFile)
	}
	awsRoleSessionName := os.Getenv(envAWSRoleSessionName) // optional field

	webIdentityRoleProvider := stscreds.NewWebIdentityRoleProvider(
		sts.New(c),
		awsRoleArn,
		awsRoleSessionName,
		awsWebIdentityTokenFile,
	)
	webIdentityRoleProvider.ExpiryWindow = expiryWindow

	if _, err := webIdentityRoleProvider.Retrieve(); err != nil && !isInvalidIdentityToken(err) {
		return nil, err
	}
	return webIdentityRoleProvider, nil
}

// isInvalidIdentityToken reports whether err, or any awserr.Error reachable
// through OrigErr, carries the STS InvalidIdentityToken code. The SDK's web
// identity provider may wrap the STS error in its own awserr.Error, so the
// top-level code alone is not sufficient.
func isInvalidIdentityToken(err error) bool {
	for err != nil {
		awsErr, ok := err.(awserr.Error)
		if !ok {
			return false
		}
		if awsErr.Code() == sts.ErrCodeInvalidIdentityTokenException {
			return true
		}
		err = awsErr.OrigErr()
	}
	return false
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment