Last active
January 6, 2021 08:55
-
-
Save moonsub-kim/1e75666ae1a1678ecb68467276d24060 to your computer and use it in GitHub Desktop.
Refresh credentials provider
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"context" | |
"sync" | |
"time" | |
"github.com/aws/aws-sdk-go/aws/credentials" | |
) | |
// tickDuration is the interval at which the background refresher re-fetches
// credentials from the wrapped provider.
const tickDuration = 5 * time.Minute
// Expirer is implemented by anything that can be told to discard its cached
// credentials.Value so the next lookup fetches a fresh one
// (e.g. *credentials.Credentials).
type Expirer interface{ Expire() }
// RefreshProvider definition | |
// refer: https://github.com/aws/aws-sdk-go/issues/561#issuecomment-185974563 | |
type RefreshProvider struct { | |
credentials.ProviderWithContext | |
Ticker *time.Ticker | |
creds credentials.Value | |
err error | |
mux sync.RWMutex | |
initRunner sync.Once | |
Expirer Expirer | |
} | |
// Retrieve returns credentials | |
func (p *RefreshProvider) Retrieve() (credentials.Value, error) { | |
return p.RetrieveWithContext(context.Background()) | |
} | |
// RetrieveWithContext returns credentials | |
func (p *RefreshProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) { | |
p.initRunner.Do(func() { | |
p.creds, p.err = p.ProviderWithContext.RetrieveWithContext(ctx) | |
go func() { | |
defer func() { | |
if v := recover(); v != nil { | |
// p.logger.Error( | |
// "periodicRefresh recovered", | |
// zap.Any("obj", v), | |
// ) | |
} | |
p.Ticker.Stop() | |
}() | |
p.periodicRefresh() | |
}() | |
}) | |
p.mux.RLock() | |
defer p.mux.RUnlock() | |
return p.creds, p.err | |
} | |
// IsExpired returns whether the credentials are no longer valid | |
func (p *RefreshProvider) IsExpired() bool { | |
p.mux.RLock() | |
defer p.mux.RUnlock() | |
return p.ProviderWithContext.IsExpired() | |
} | |
func (p *RefreshProvider) periodicRefresh() { | |
for { | |
_, ok := <-p.Ticker.C | |
if !ok { | |
break | |
} | |
if p.refresh() { | |
// Expire() must be called on mutex unlocked. | |
// Calling Expire() with locking mutex cause of deadlock | |
// because credentials call Retrieve() after locking its mutex, | |
p.Expirer.Expire() | |
} | |
} | |
} | |
func (p *RefreshProvider) refresh() bool { | |
p.mux.Lock() | |
defer p.mux.Unlock() | |
// Probably want to log the returned error | |
creds, err := p.ProviderWithContext.Retrieve() | |
if err != nil { | |
if p.ProviderWithContext.IsExpired() { | |
p.err = err | |
return false | |
} | |
} | |
p.err = nil | |
p.creds = creds | |
return true | |
} | |
// NewRefreshCredentials returns automatically refreshed credentials | |
func NewRefreshCredentials( | |
provider credentials.ProviderWithContext, | |
expirer credentials.Expirer, | |
) *credentials.Credentials { | |
ticker := time.NewTicker(tickDuration) | |
rp := &RefreshProvider{ | |
ProviderWithContext: provider, | |
Ticker: ticker, | |
} | |
rp.Retrieve() // To run goroutine | |
creds := credentials.NewCredentials(rp) | |
rp.Expirer = creds | |
return creds | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment