Created
September 18, 2017 08:28
-
-
Save chrishowell/316057235d86d1f10a82a556136a4ddf to your computer and use it in GitHub Desktop.
Using STS with MFA and longer-lived session credentials
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import com.amazonaws.auth.profile.internal.{BasicProfile, ProfileStaticCredentialsProvider} | |
import com.amazonaws.auth.profile.{ProfileCredentialsProvider, ProfilesConfigFile} | |
import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, AWSStaticCredentialsProvider, BasicSessionCredentials, DefaultAWSCredentialsProviderChain} | |
import com.amazonaws.services.s3.AmazonS3ClientBuilder | |
import com.amazonaws.services.securitytoken.model.{AssumeRoleRequest, GetSessionTokenRequest} | |
import com.amazonaws.services.securitytoken.{AWSSecurityTokenServiceClient, AWSSecurityTokenServiceClientBuilder} | |
import scala.collection.JavaConversions._ | |
/** Manual demo: lists S3 buckets using credentials resolved for an AWS profile that may require MFA. */
object STSTesting {

  /** Entry point.
    *
    * @param args optional; the first element, if present, is the profile name to use
    *             (defaults to "prero", the profile the original demo hard-coded)
    */
  def main(args: Array[String]): Unit = {
    // Explicit converters instead of the deprecated implicit JavaConversions.
    import scala.collection.JavaConverters._

    val profileName = args.headOption.getOrElse("prero")
    val credentialsProvider = new ProfileMfaAwareCredentialsProvider(profileName)

    // Resolve credentials eagerly so any MFA prompt happens up front, before the S3 client is used.
    credentialsProvider.getCredentials

    val s3 = AmazonS3ClientBuilder.standard().withCredentials(credentialsProvider).build()
    // Release the client's connection pool even if the listing fails.
    try s3.listBuckets().asScala.foreach(println)
    finally s3.shutdown()
  }
}
/** Credentials provider for a role-based profile, optionally protected by MFA.
  *
  * Works in two stages:
  *  1. It calls STS `GetSessionToken` with the profile's source credentials, prompting on stdin
  *     for an MFA token when the profile has an `mfa_serial`. It requests a 36-hour lifetime so
  *     the MFA prompt is rare.
  *  2. It calls STS `AssumeRole` for the profile's role ARN, using those session credentials.
  *
  * Both sets of credentials are cached together with the expiry STS reports. Either set is
  * fetched again when it is missing or within [[ExpiryMarginMillis]] of expiring. All cache
  * access is synchronized, so concurrent callers trigger at most one STS round-trip or MFA
  * prompt at a time.
  *
  * @param profile the parsed AWS profile; its role ARN, role session name, source profile and
  *                `mfa_serial` property are read
  */
class SessionCredentialsAssumeRoleMfaAwareCredentialsProvider(profile: BasicProfile) extends AWSCredentialsProvider {
  import java.util.Date

  /** Requested GetSessionToken lifetime, in seconds (36 hours, the STS maximum). */
  val _36_HOURS = 129600

  /** Cached GetSessionToken credentials; these are only used to call AssumeRole. */
  var sessionCredentials: Option[AWSCredentials] = None

  /** Cached AssumeRole credentials; these are what [[getCredentials]] hands out. */
  var assumedRoleCredentials: Option[AWSCredentials] = None

  // Expiry times reported by STS for the two caches above. None means "unknown", which is
  // treated as not expiring.
  private var sessionExpiration: Option[Date] = None
  private var assumedRoleExpiration: Option[Date] = None

  // Renew this long before the reported expiry so a request signed just before it doesn't fail.
  private val ExpiryMarginMillis: Long = 60 * 1000L

  /** Forces a new AssumeRole call. Session credentials are reused unless they are expiring. */
  override def refresh(): Unit = synchronized {
    renewAssumedRoleCredentials()
  }

  /** Forces a new GetSessionToken call, which may prompt for an MFA token on stdin. */
  def refreshSessionCredentials(): Unit = synchronized {
    renewSessionCredentials()
  }

  /** Returns the cached assumed-role credentials, assuming the role again if they are missing or expiring. */
  override def getCredentials: AWSCredentials = synchronized {
    assumedRoleCredentials
      .filterNot(_ => isExpiring(assumedRoleExpiration))
      .getOrElse(renewAssumedRoleCredentials())
  }

  /** Performs one AssumeRole call and returns the result WITHOUT caching it as the assumed-role credentials.
    * Session credentials are still fetched or renewed (and cached) if needed.
    */
  def getAssumedRoleCredentials: BasicSessionCredentials = synchronized {
    assumeRole()._1
  }

  // Assumes the role and stores the result (and its expiry) in the cache.
  private def renewAssumedRoleCredentials(): AWSCredentials = {
    val (credentials, expiration) = assumeRole()
    assumedRoleCredentials = Some(credentials)
    assumedRoleExpiration = Some(expiration)
    credentials
  }

  // Returns usable session credentials, fetching new ones if absent or expiring.
  private def currentSessionCredentials: AWSCredentials =
    sessionCredentials
      .filterNot(_ => isExpiring(sessionExpiration))
      .getOrElse(renewSessionCredentials())

  // Calls GetSessionToken and stores the result (and its expiry) in the cache.
  private def renewSessionCredentials(): AWSCredentials = {
    val (credentials, expiration) = fetchTemporarySessionCredentials()
    sessionCredentials = Some(credentials)
    sessionExpiration = Some(expiration)
    credentials
  }

  // True when the given expiry is known and falls within the safety margin of now.
  private def isExpiring(expiration: Option[Date]): Boolean =
    expiration.exists(_.getTime - System.currentTimeMillis() < ExpiryMarginMillis)

  // Calls STS AssumeRole with the current session credentials; returns the credentials and their expiry.
  private def assumeRole(): (BasicSessionCredentials, Date) = {
    val assumeRoleReq = new AssumeRoleRequest()
      .withRoleSessionName(roleSessionName)
      .withRoleArn(profile.getRoleArn)
    val stsClient = AWSSecurityTokenServiceClientBuilder
      .standard()
      .withCredentials(new AWSStaticCredentialsProvider(currentSessionCredentials))
      .build()
    // Short-lived client: shut it down so its connection pool is released.
    val stsCredentials =
      try stsClient.assumeRole(assumeRoleReq).getCredentials
      finally stsClient.shutdown()
    (
      new BasicSessionCredentials(
        stsCredentials.getAccessKeyId,
        stsCredentials.getSecretAccessKey,
        stsCredentials.getSessionToken
      ),
      stsCredentials.getExpiration
    )
  }

  // Calls STS GetSessionToken with the source credentials; returns the credentials and their expiry.
  private def fetchTemporarySessionCredentials(): (BasicSessionCredentials, Date) = {
    val stsClient = AWSSecurityTokenServiceClientBuilder
      .standard()
      .withCredentials(sourceCredentialsProvider)
      .build()
    val stsCredentials =
      try stsClient.getSessionToken(sessionRequest).getCredentials
      finally stsClient.shutdown()
    (
      new BasicSessionCredentials(
        stsCredentials.getAccessKeyId,
        stsCredentials.getSecretAccessKey,
        stsCredentials.getSessionToken
      ),
      stsCredentials.getExpiration
    )
  }

  // Uses the profile's configured session name, else a timestamped one.
  private def roleSessionName: String =
    Option(profile.getRoleSessionName).getOrElse("session-" + System.currentTimeMillis())

  /** The profile's `mfa_serial` property, if set to a non-blank value. */
  def mfaDevice: Option[String] =
    Option(profile.getPropertyValue("mfa_serial")).map(_.trim).filter(_.nonEmpty)

  /** Prompts on stdin for an MFA token code.
    *
    * @throws IllegalArgumentException if the input is not exactly six digits (the format STS
    *                                  requires), including when stdin is closed
    */
  def mfaToken: String = {
    // readLine returns null at end of input.
    val token = Option(scala.io.StdIn.readLine("MFA Token: ")).map(_.trim).getOrElse("")
    require(token.matches("\\d{6}"), "MFA token must be exactly six digits")
    token
  }

  // Builds the GetSessionToken request. If an MFA device is configured, this prompts for its token.
  private def sessionRequest: GetSessionTokenRequest = {
    val request = new GetSessionTokenRequest
    mfaDevice.foreach { serial =>
      request.setSerialNumber(serial)
      request.setTokenCode(mfaToken)
    }
    request.setDurationSeconds(_36_HOURS)
    request
  }

  // Credentials used for GetSessionToken: the profile's source profile if named, else the default chain.
  private def sourceCredentialsProvider: AWSCredentialsProvider =
    Option(profile.getRoleSourceProfile)
      .map(sourceProfile => new ProfileCredentialsProvider(sourceProfile): AWSCredentialsProvider)
      .getOrElse(new DefaultAWSCredentialsProviderChain())
}
/** Credentials provider for an AWS profile that may be role-based and MFA-protected.
  *
  * If `profile.isRoleBasedProfile` is true, this delegates to
  * [[SessionCredentialsAssumeRoleMfaAwareCredentialsProvider]]. Otherwise it uses the profile's
  * static credentials through `ProfileStaticCredentialsProvider`.
  *
  * @param profile the parsed profile; must not be null
  * @throws IllegalArgumentException if `profile` is null
  */
class ProfileMfaAwareCredentialsProvider(profile: BasicProfile) extends AWSCredentialsProvider {
  require(profile != null, "profile must not be null")

  /** Looks up `profileName` among the profiles loaded by `new ProfilesConfigFile()`.
    *
    * @throws IllegalArgumentException if no profile with that name exists
    */
  def this(profileName: String) =
    this(
      // Map.get returns null for an unknown name. Fail here with a message naming the profile
      // rather than with an NPE later.
      Option(new ProfilesConfigFile().getAllBasicProfiles.get(profileName))
        .getOrElse(throw new IllegalArgumentException(s"AWS profile '$profileName' not found"))
    )

  /** The provider that actually supplies credentials for this profile. */
  val credentialsProvider: AWSCredentialsProvider =
    if (profile.isRoleBasedProfile) new SessionCredentialsAssumeRoleMfaAwareCredentialsProvider(profile)
    else new ProfileStaticCredentialsProvider(profile)

  override def refresh(): Unit = credentialsProvider.refresh()

  override def getCredentials: AWSCredentials = credentialsProvider.getCredentials
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment