Created
July 2, 2011 13:11
-
-
Save ericacm/1060126 to your computer and use it in GitHub Desktop.
Security OAuth 2 Provider and JPA Domain Classes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * OAuth 2 provider token store that persists access tokens through the JPA entity
 * [[OauthAccessToken]] (via `domainManager`) rather than the parent's JDBC statements for the
 * overridden operations, and registers a scheduled job that deletes expired tokens.
 *
 * @param dataSource handed to [[JdbcOAuth2ProviderTokenServices]]; still used by inherited
 *                   operations that merely delegate to the parent (e.g. `removeAccessToken`).
 */
class ProviderTokenServices(dataSource: DataSource) extends JdbcOAuth2ProviderTokenServices(dataSource) with InitializingBean {
  var log: Logger = LoggerFactory.getLogger(this.getClass.getName)

  @Autowired var domainManager: DomainManager = _
  @Autowired var schedulerService: SchedulerService = _

  /**
   * Schedule expression for the prune job (property `providerTokenServices.pruneSchedule`).
   * NOTE(review): the default "0 0 12 * * *" looks like a seconds-first cron (daily at 12:00) —
   * confirm the format expected by SchedulerService.
   */
  @Value("${providerTokenServices.pruneSchedule:0 0 12 * * *}")
  var pruneSchedule: String = _

  /**
   * Persists `token` with its serialized form, the serialized `authentication`, and the owning
   * user's id / SSO group id.
   *
   * @throws IllegalStateException if `authentication` has no user authentication (client-only)
   *         or its principal is not an [[AuthUserDetails]] — the stored row requires user ids.
   */
  override def storeAccessToken(token: OAuth2AccessToken, authentication: OAuth2Authentication[_ <: Authentication, _ <: Authentication]) {
    val refreshToken = Option(token.getRefreshToken).map(_.getValue).orNull
    // getUserAuthentication can be null for client-only authentications; wrap it so that case
    // reaches the descriptive IllegalStateException below instead of a NullPointerException.
    Option(authentication.getUserAuthentication).map(_.getPrincipal) match {
      case Some(userDetails: AuthUserDetails) =>
        domainManager.withTransaction { txStatus =>
          OauthAccessToken.create(token.getValue, token.getExpiration, userDetails.userId, userDetails.ssoGroupId,
            SerializationUtils.serialize(token), SerializationUtils.serialize(authentication), refreshToken)
        }
      case _ => throw new IllegalStateException("User Authentication Details is not an AuthUserDetails")
    }
  }

  /**
   * Returns the deserialized authentication stored with `token`, or null (logged at info) when
   * no row exists for the token value.
   */
  override def readAuthentication(token: OAuth2AccessToken): OAuth2Authentication[_ <: Authentication, _ <: Authentication] =
    lookupStoredToken(token.getValue) match {
      case Some(accessToken) =>
        SerializationUtils.deserialize(accessToken.authentication)
      case None =>
        log.info("No OauthAccessToken found for: "+token.getValue)
        null
    }

  /**
   * Returns the deserialized access token stored under `tokenValue`, or null (logged at info)
   * when no row exists.
   */
  override def readAccessToken(tokenValue: String): OAuth2AccessToken =
    lookupStoredToken(tokenValue) match {
      case Some(accessToken) =>
        SerializationUtils.deserialize(accessToken.token)
      case None =>
        log.info("No OauthAccessToken found for: "+tokenValue)
        null
    }

  /** Delegates to the parent's JDBC implementation. */
  override def removeAccessToken(tokenValue: String) = super.removeAccessToken(tokenValue)

  /** Deletes every stored access token belonging to `userId`, inside a transaction. */
  def deleteAllAccessTokensByUserId(userId: Long) {
    domainManager.withTransaction { txStatus => OauthAccessToken.deleteAllByUserId(userId) }
  }

  /** Registers the "pruneTokens" job with the scheduler using `pruneSchedule`. */
  override def afterPropertiesSet() {
    log.debug("Adding token prune job to scheduler")
    schedulerService.addJob(new SchedulerJob("pruneTokens", pruneTokens _, pruneSchedule))
  }

  /** Deletes all access tokens whose expiration is at or before the current time. */
  def pruneTokens() {
    log.debug("Pruning expired tokens")
    domainManager.withTransaction { txStatus => OauthAccessToken.deleteExpired(new Date) }
  }

  /**
   * Looks up the stored row for `tokenValue`. The second `withTransaction` argument `true`
   * presumably requests a read-only transaction — confirm against DomainManager.
   */
  private def lookupStoredToken(tokenValue: String): Option[OauthAccessToken] =
    domainManager.withTransaction({ txStatus => OauthAccessToken.lookup(tokenValue) }, true).asInstanceOf[Option[OauthAccessToken]]
}
/**
 * JPA entity for a stored OAuth 2 verification code, mapped to table `oauth_code`.
 *
 * Both named queries select from / delete in this entity itself; pointing them at
 * OauthAccessToken would make "OauthVerificationCode.deleteExpired" prune access tokens and
 * make "OauthVerificationCode.findAllByCode" filter on a property that entity does not have.
 */
@Entity
@Table(name="oauth_code")
@NamedQueries(Array(
  new NamedQuery(name="OauthVerificationCode.findAllByCode", query="from OauthVerificationCode where code = :code"),
  new NamedQuery(name="OauthVerificationCode.deleteExpired", query="delete from OauthVerificationCode where expiration <= :asOf")
))
class OauthVerificationCode {
  /** The verification code value; primary key. */
  @Id
  var code: String = _

  /** Expiry time; rows with `expiration <= :asOf` are removed by the deleteExpired query. */
  // NOTE(review): OauthAccessToken also declares an index named "expiration"; databases with
  // schema-wide index names (e.g. PostgreSQL, Oracle) may reject the duplicate — consider renaming.
  @Index(name="expiration")
  var expiration: Date = _

  /** Serialized authentication bytes associated with this code. */
  @Lob
  var authentication: Array[Byte] = _
}
/** Data-access helpers for [[OauthVerificationCode]], backed by an injected DomainManager. */
object OauthVerificationCode {
  @Autowired var domainManager: DomainManager = null
  // Presumably fills the @Autowired field above from the application context — confirm AppConfig.inject.
  AppConfig.inject(this)

  /** Creates and saves a verification code, storing a copy of `authBytes`. */
  def create(code: String, expiration: Date, authBytes: Array[Byte]): OauthVerificationCode = {
    val newCode = new OauthVerificationCode
    newCode.code = code
    newCode.expiration = expiration
    newCode.authentication = authBytes.clone
    domainManager.save(newCode)
    newCode
  }

  /** Finds a verification code by its primary key. */
  def lookup(code: String): Option[OauthVerificationCode] =
    domainManager.get[OauthVerificationCode](code)

  /** Deletes the given verification code. */
  def delete(ovcode: OauthVerificationCode) {
    domainManager.delete(ovcode)
  }

  /**
   * Deletes every verification code whose expiration is at or before `expiration`.
   * Must run this entity's own named query; "OauthAccessToken.deleteExpired" would delete
   * access tokens instead.
   */
  def deleteExpired(expiration: Date) =
    domainManager.executeUpdate("OauthVerificationCode.deleteExpired", Map("asOf"->expiration))
}
/**
 * JPA entity for an issued OAuth 2 access token, mapped to table `oauth_access_token`.
 * The HQL named queries refer to the snake_case property names (e.g. `user_id`) verbatim.
 */
@Entity
@Table(name = "oauth_access_token")
@NamedQueries(Array(
  new NamedQuery(name = "OauthAccessToken.findAllByUserId", query = "from OauthAccessToken where user_id = :userId"),
  new NamedQuery(name = "OauthAccessToken.deleteAllByUserId", query = "delete from OauthAccessToken where user_id = :userId"),
  new NamedQuery(name = "OauthAccessToken.deleteExpired", query = "delete from OauthAccessToken where expiration <= :asOf")
))
class OauthAccessToken {
  /** The access token value; primary key. */
  @Id var token_id: String = _

  /** Expiry time, indexed for the deleteExpired query. */
  // NOTE(review): index name "expiration" is also used on OauthVerificationCode; may clash on
  // databases with schema-wide index names — confirm.
  @Index(name = "expiration") var expiration: Date = _

  /** Owning user's id, indexed for the per-user queries. */
  @Index(name = "user_id") var user_id: Long = _

  /** Owning user's SSO group id. */
  var sso_group_id: Long = _

  /** Serialized access token bytes. */
  @Lob var token: Array[Byte] = _

  /** Serialized authentication bytes. */
  @Lob var authentication: Array[Byte] = _

  /** Refresh token value, if one was stored. */
  var refresh_token: String = _
}
/** Data-access helpers for [[OauthAccessToken]], backed by an injected DomainManager. */
object OauthAccessToken {
  @Autowired var domainManager: DomainManager = null
  // Presumably fills the @Autowired field above from the application context — confirm AppConfig.inject.
  AppConfig.inject(this)

  /**
   * Creates and saves an access-token row, storing copies of `tokenBytes` and `authBytes`.
   *
   * @param refreshToken refresh token value to persist; callers may pass null when none exists.
   */
  def create(tokenId: String, expiration: Date, userId: Long, ssoGroupId: Long, tokenBytes: Array[Byte], authBytes: Array[Byte], refreshToken: String): OauthAccessToken = {
    val newToken = new OauthAccessToken
    newToken.token_id = tokenId
    newToken.expiration = expiration
    newToken.user_id = userId
    newToken.sso_group_id = ssoGroupId
    newToken.token = tokenBytes.clone
    newToken.authentication = authBytes.clone
    // The refresh token was accepted as a parameter but never stored; persist it.
    newToken.refresh_token = refreshToken
    domainManager.save(newToken)
    newToken
  }

  /** Finds an access token by its primary key (the token value). */
  def lookup(tokenId: String): Option[OauthAccessToken] =
    domainManager.get[OauthAccessToken](tokenId)

  /** Returns all access tokens belonging to `userId`. */
  def findAllByUserId(userId: Long): Buffer[OauthAccessToken] =
    domainManager.findAll[OauthAccessToken]("OauthAccessToken.findAllByUserId", Map("userId"->userId))

  /** Deletes all access tokens belonging to `userId`. */
  def deleteAllByUserId(userId: Long) {
    domainManager.executeUpdate("OauthAccessToken.deleteAllByUserId", Map("userId"->userId))
  }

  /** Deletes all access tokens whose expiration is at or before `expiration`. */
  def deleteExpired(expiration: Date) =
    domainManager.executeUpdate("OauthAccessToken.deleteExpired", Map("asOf"->expiration))
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is in Scala, but how do I port it to Grails properly? I'm new to Grails and Java, and I need to implement an OAuth provider that can handle the two-legged flow.
Can you help me with that? Thanks.