Created
October 4, 2023 14:50
-
-
Save keymon/e9b91d3644254bbc8dc587ca57d9eece to your computer and use it in GitHub Desktop.
A Go function that keeps a temporary pgpass file up to date with RDS IAM credentials and sets the PGPASSFILE environment variable to point at it. It can be updated to use pgpassfile.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package pgfile | |
import ( | |
"context" | |
"fmt" | |
"io/ioutil" | |
"log" | |
"os" | |
"strings" | |
"time" | |
"github.com/aws/aws-sdk-go-v2/aws" | |
"github.com/aws/aws-sdk-go-v2/credentials/stscreds" | |
"github.com/aws/aws-sdk-go-v2/feature/rds/auth" | |
"github.com/aws/aws-sdk-go-v2/service/sts" | |
"github.com/google/renameio" | |
) | |
// This function does export the variable PGPASSFILE and starts a goroutine | |
// that will update it every 5 minutes with new creds from RDS IAM of AWS, using | |
// IRSA (webidentitytoken) to auth in AWS | |
// | |
// WARNING: set an env var, and can be only use one per program | |
func StartPgPassFileRdsIamUpdater( | |
ctx context.Context, | |
rdsHost string, | |
rdsPort string, | |
rdsRegion string, | |
rdsUser string, | |
rdsRoleArn string, | |
rdsWebIdentityTokenFile string, | |
) error { | |
file, err := ioutil.TempFile("tmp", "prefix") | |
if err != nil { | |
return err | |
} | |
cfg := aws.Config{ | |
Region: rdsRegion, | |
} | |
stsClient := sts.NewFromConfig(cfg) | |
credsProvider := stscreds.NewWebIdentityRoleProvider( | |
stsClient, rdsRoleArn, stscreds.IdentityTokenFile(rdsWebIdentityTokenFile), | |
) | |
updateCreds := func() error { | |
password, err := auth.BuildAuthToken( | |
ctx, | |
fmt.Sprintf("%s:%s", rdsHost, rdsPort), // database endpoint (with port) | |
rdsRegion, | |
rdsUser, | |
credsProvider, | |
) | |
if err != nil { | |
return err | |
} | |
escapedPassword := strings.Replace(password, ":", "\\:", -1) | |
pgPassFileContent := fmt.Sprintf("%s:%s:%s:%s:%s", rdsHost, rdsPort, "*", rdsUser, escapedPassword) | |
// Atomically write the file using https://pkg.go.dev/github.com/google/renameio#WriteFile | |
// that uses rename | |
err = renameio.WriteFile(file.Name(), []byte(pgPassFileContent), 0600) | |
if err != nil { | |
return err | |
} | |
fmt.Println(pgPassFileContent) | |
return err | |
} | |
// First update | |
fmt.Println("first cred update") | |
err = updateCreds() | |
if err != nil { | |
return err | |
} | |
// Export the PGPASSFILE variable | |
// WARNING!!! this is global!!! | |
os.Setenv("PGPASSFILE", file.Name()) | |
// Update the file every 5 minutes | |
ticker := time.NewTicker(300 * time.Second) | |
go func() { | |
for { | |
select { | |
case <-ctx.Done(): | |
os.Remove(file.Name()) | |
return | |
case _ = <-ticker.C: | |
err := updateCreds() | |
// TODO: Would be cool to add better logging and also some retry logic | |
if err != nil { | |
log.Printf("Warning: %s", err) | |
} | |
} | |
} | |
}() | |
return nil | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment