Last active
March 10, 2022 05:33
-
-
Save moonsub-kim/9fb3eb3b72f2cafff51e6fa94ccaaaf3 to your computer and use it in GitHub Desktop.
Refresh the STS token in the background
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
// Names of the environment variables that supply the AWS endpoint and region
// read by main.
const envAWSEndpoint = "AWS_ENDPOINT"
const envAWSRegion = "AWS_REGION"
func main() { | |
endPoint := os.Getenv(envAWSEndpoint) | |
region := os.Getenv(envAWSRegion) | |
// https://gist.github.com/moonsub-kim/6246802d85e16d56609da06af4083065#file-clienttrace_injector-go | |
c := httptrace.WrapClient( | |
&http.Client{}, | |
serviceName, | |
ddhttptrace.WithBefore(func(r *http.Request, span ddtrace.Span) { | |
span.SetTag("url", r.URL) | |
}), | |
) | |
config := &aws.Config{ | |
MaxRetries: aws.Int(dynamoMaxRetry), | |
Endpoint: aws.String(endPoint), | |
Region: aws.String(region), | |
LogLevel: aws.LogLevel(aws.LogOff), | |
HTTPClient: c, | |
} | |
session, err := session.NewSession(config) | |
if err != nil { | |
d.logger.Fatal("Failed to create aws session", zap.Error(err)) | |
} | |
p, err := credentials.NewWebIdentityRoleProvider(session) | |
if err != nil { | |
d.logger.Warn("Use default credentials provider", zap.Error(err)) // local environment | |
} else { | |
cred := credentials.NewRefreshCredentials(d.logger, p, p, serviceName) | |
session.Config.WithCredentials(cred) | |
} | |
session = awstrace.WrapSession(session, awstrace.WithServiceName(serviceName)) | |
dynamoDB := dynamodb.New(session, session.Config) | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"context" | |
"sync" | |
"time" | |
"github.com/aws/aws-sdk-go/aws/credentials" | |
"go.uber.org/zap" | |
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" | |
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" | |
) | |
// tickDuration is the interval between background credential refreshes.
const tickDuration = 10 * time.Minute
// Expirer is implemented by anything whose cached credentials can be
// invalidated on demand. NewRefreshCredentials uses the
// *credentials.Credentials it creates.
type Expirer interface {
	// Expire marks the cached credentials as stale.
	Expire()
}
// RefreshProvider definition | |
// refer: https://github.com/aws/aws-sdk-go/issues/561#issuecomment-185974563 | |
type RefreshProvider struct { | |
credentials.ProviderWithContext | |
Ticker *time.Ticker | |
creds credentials.Value | |
err error | |
mux sync.RWMutex | |
initRunner sync.Once | |
Expirer Expirer | |
logger *zap.Logger | |
serviceName string | |
} | |
// Retrieve returns credentials | |
func (p *RefreshProvider) Retrieve() (credentials.Value, error) { | |
return p.RetrieveWithContext(context.Background()) | |
} | |
// RetrieveWithContext returns credentials | |
func (p *RefreshProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) { | |
span, ctx := tracer.StartSpanFromContext(ctx, "RefreshProvider.RetrieveWithContext") | |
span.SetTag(ext.ServiceName, p.serviceName) | |
defer span.Finish() | |
p.initRunner.Do(func() { | |
p.creds, p.err = p.ProviderWithContext.RetrieveWithContext(ctx) | |
go p.periodicRefresh() | |
}) | |
p.mux.RLock() | |
defer p.mux.RUnlock() | |
return p.creds, p.err | |
} | |
// IsExpired returns whether the credentials are no longer valid | |
func (p *RefreshProvider) IsExpired() bool { | |
p.mux.RLock() | |
defer p.mux.RUnlock() | |
return p.ProviderWithContext.IsExpired() | |
} | |
func (p *RefreshProvider) periodicRefresh() { | |
defer func() { | |
if v := recover(); v != nil { | |
p.logger.Error("periodicRefresh recovered", zap.Any("obj", v)) | |
} | |
p.Ticker.Stop() | |
}() | |
for { | |
_, ok := <-p.Ticker.C | |
if !ok { | |
break | |
} | |
if p.refresh() { | |
// Expire() must be called on mutex unlocked. | |
// Calling Expire() with locking mutex cause of deadlock | |
// because credentials call Retrieve() after locking its mutex, | |
p.Expirer.Expire() | |
} | |
} | |
} | |
func (p *RefreshProvider) refresh() bool { | |
p.mux.Lock() | |
defer p.mux.Unlock() | |
// Probably want to log the returned error | |
creds, err := p.ProviderWithContext.Retrieve() | |
if err != nil { | |
if p.ProviderWithContext.IsExpired() { | |
p.err = err | |
p.logger.Error("Failed to refresh credentials", zap.Error(p.err)) | |
return false | |
} | |
p.logger.Warn("Failed to refresh credentials. But current credentials are not yet expired", zap.Error(p.err)) | |
} | |
p.err = nil | |
p.creds = creds | |
return true | |
} | |
// NewRefreshCredentials returns automatically refreshed credentials | |
func NewRefreshCredentials(logger *zap.Logger, provider credentials.ProviderWithContext, expirer credentials.Expirer, serviceName string) *credentials.Credentials { | |
ticker := time.NewTicker(tickDuration) | |
rp := &RefreshProvider{ | |
ProviderWithContext: provider, | |
Ticker: ticker, | |
logger: logger, | |
serviceName: serviceName, | |
} | |
rp.Retrieve() // To run goroutine | |
creds := credentials.NewCredentials(rp) | |
rp.Expirer = creds | |
return creds | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"fmt" | |
"os" | |
"time" | |
"github.com/aws/aws-sdk-go/aws/awserr" | |
"github.com/aws/aws-sdk-go/aws/client" | |
"github.com/aws/aws-sdk-go/aws/credentials/stscreds" | |
"github.com/aws/aws-sdk-go/service/sts" | |
) | |
// expiryWindow is assigned to the STS provider's ExpiryWindow, so its
// credentials are reported as expired this long before they actually expire.
const expiryWindow = 10 * time.Minute

// Environment variables read by NewWebIdentityRoleProvider.
const (
	envAWSRoleARN              = "AWS_ROLE_ARN"
	envAWSRoleSessionName      = "AWS_ROLE_SESSION_NAME"
	envAWSWebIdentityTokenFile = "AWS_WEB_IDENTITY_TOKEN_FILE"
)
// NewWebIdentityRoleProvider returns WebIdentityProvider | |
func NewWebIdentityRoleProvider(c client.ConfigProvider) (*stscreds.WebIdentityRoleProvider, error) { | |
awsRoleArn, ok := os.LookupEnv(envAWSRoleARN) | |
if !ok { | |
return nil, fmt.Errorf("emtpy env " + envAWSRoleARN) | |
} | |
awsWebIdentityTokenFile, ok := os.LookupEnv(envAWSWebIdentityTokenFile) | |
if !ok { | |
return nil, fmt.Errorf("emtpy env " + awsWebIdentityTokenFile) | |
} | |
awsRoleSessionName := os.Getenv(envAWSRoleSessionName) // optional field | |
webIdentityRoleProvider := stscreds.NewWebIdentityRoleProvider( | |
sts.New(c), | |
awsRoleArn, | |
awsRoleSessionName, | |
awsWebIdentityTokenFile, | |
) | |
webIdentityRoleProvider.ExpiryWindow = expiryWindow | |
_, err := webIdentityRoleProvider.Retrieve() | |
if err != nil { | |
// InvalidIdentityToken error is a temporary error that can occur | |
// when assuming an Role with a JWT web identity token. | |
awsErr, ok := err.(awserr.Error) | |
if ok && awsErr.Code() != sts.ErrCodeInvalidIdentityTokenException { | |
return nil, err | |
} | |
} | |
return webIdentityRoleProvider, nil | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment